본문 바로가기
머신러닝/Linear Regression

Cost function - Boston 집값 예측

by 미생22 2024. 5. 5.
728x90

Iris와 마찬가지로 sklearn의 datasets에는 Boston 집값예측 데이터가 들어있습니다. 이 데이터 세트는 Barnegie Mellon University에서 유지관리 중이며 1978년에 만들어졌습니다.

 

보스턴 주택 가격 데이터는 회귀문제를 다루는 많은 머신러닝 논문에서 사용하고 있습니다.

from sklearn.datasets import load_boston

boston = load_boston()
print(boston.DESCR)

DESCR를 살펴보면 다음과 같이 컬럼이 나옵니다.

대체 boston이 어떻게 생긴걸까요?

 

 

아 dict형태로 data에 데이터가 들어있고 target에 label데이터인 price가 들어있고, feature_names에 data의 컬럼이 들어있네요.

 

dict 형태므로 keys()를 확인해보겠습니다.

 

이제 데이터프레임으로 만들어보겠습니다.

데이터 파악을 위해 pandas로 정리해보겠습니다.

 

import pandas as pd

boston_pd = pd.DataFrame(boston.data, columns=boston.feature_names)
boston_pd['PRICE'] = boston.target #label이 됩니다.

boston_pd.head()

 

이후 label이 될 price를 컬럼을 따로 만들어줍니다.

 

일단 각 컬럼의 의미를 알아야 하는데요

 

논문에 의하면 다음과 같습니다.

이제 histogram으로 price 분포를 살펴보겠습니다.

histogram은 plotly.express의 모듈을 사용해보겠습니다.

 

import plotly.express as px

fig = px.histogram(boston_pd, x='PRICE')
fig.show()

대체로 정규분포의 모양으로 보이긴 하는데... 맨 뒤에 특이한 데이터 몇개가 있네요.

어떤 feature가 price에 큰 영향을 주는지 한번 상관계수를 확인해보겠습니다.

 

pandas dataframd이 갖고있는 corr()함수를 써서 알아볼 것이고, 소수점 첫번째자리까지 반올림하겠습니다.

import matplotlib.pyplot as plt
import seaborn as sns

corr_mat = boston_pd.corr().round(1)
corr_mat

알아보기가 힘드네요.. 아까 plt와 sns를 import한 이유입니다. heatmap이 생각나시죠? 그려보겠습니다.

 

sns.heatmap(data = corr_mat, annot=True, cmap='bwr')

annot으로 수치를 넣는다는 뜻이고 cmap은 colormap 종류입니다.

price와 방의 수 (RM), 저소득층 인구 (LSTAT)와 높은 상관관계가 보입니다.

방의 수가 많을수록, 저소득층 인구가 적을수록 가격이 높을 수 있다고 하네요.

RM과 LSTAT과 PRICE의 관계에 대해 좀 더 관찰해보겠습니다.
sns의 regplot을 통해 알아보겠습니다. lmplot이랑 뭐가 다른지는 모르겠네요...

(https://jehyunlee.github.io/2022/06/06/Python-DS-103-snsreglmplot/) 이 글을 추천합니다.

 

sns.set_style('darkgrid')
sns.set(rc={'figure.figsize':(12, 6)})
fig, ax = plt.subplots(ncols=2) #(1,2)말고 이 방법도 씁니다
sns.regplot(x='RM', y='PRICE', data=boston_pd, ax=ax[0])
sns.regplot(x='LSTAT', y='PRICE', data=boston_pd, ax=ax[1])

저소득층 인구가 낮을수록, 방의 갯수가 많을수록 집값이 높아진다?

이 가설은 문제가 없을까요? 그러나 왼쪽 그래프를 보면 맨 위에 모여있는 몇개의 데이터가 있습니다. 히스토그램의 오른쪽에 모여있던 그래프네요. 

 

그런데 저소득층이 모여살아서 집값이 낮은게 아닐까요?

집값이 낮은 요인이 그 집주변에 소득수준이 낮은 사람들이 살아서가 아닐까요?

집값을 예측하는 모델에 있어서 필요한 특성인지 아닌지에 대해서 생각해봐야 할 것 같다는 겁니다.

 

이런 특성을 계속 이해하려고 노력해야합니다...

 

일단 머신러닝을 위해 데이터를 나눠보겠습니다.

from sklearn.model_selection import train_test_split

X = boston_pd.drop('PRICE', axis=1)
y = boston_pd['PRICE']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=13)

이제 Linear Regression을 사용하겠습니다.

 

from sklearn.linear_model import LinearRegression

reg = LinearRegression()
reg.fit(X_train, y_train)

모델이 제대로 만들어졌는지 평가하기 위해 RMS(제곱평균제곱근; root mean square)를 사용할겁니다.

 

import numpy as np
from sklearn.metrics import mean_squared_error

pred_tr = reg.predict(X_train) #X_train 데이터에 대해 예측해라
pred_test = reg.predict(X_test) #X_test 데이터에 대해 예측해라

rmse_tr = (np.sqrt(mean_squared_error(y_train, pred_tr))) #MSE가 아닌 RMS를 사용할 것이므로 sqrt를 시켜준다.
rmse_test = (np.sqrt(mean_squared_error(y_test, pred_test)))
print('RMSE train : ', rmse_tr)
print('RMSE test : ', rmse_test)

성능을 확인해보겠습니다.
회귀 문제에서 성능을 확인하는 방법은 대체로 scatter가 있습니다.

plt.scatter(y_test, pred_test) #참값과 예측값에 대한 scatter 그래프를 그립니다.
만약 일치할수록 1:1 대각선상에 데이터가 모여있어야합니다.

plt.xlabel('Real ($1000)')
plt.ylabel('Predict Prices')
plt.plot([0,50], [0,50], 'r') #기준선을 그려줍니다.
plt.show()

오른쪽 끝에 제대로 모델에 의한 것보다 actual 값이 높았던 부분이 있다고합니다. 
아까 저소득층에 의한 데이터는 치트키 같은 점이 있었기 때문에 빼고 테스트를 진행해보고자 합니다.

 

X = boston_pd.drop(['PRICE', 'LSTAT'], axis=1)
y = boston_pd['PRICE']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=13)
from sklearn.linear_model import LinearRegression

reg = LinearRegression()
reg.fit(X_train, y_train)
pred_tr = reg.predict(X_train)
pred_test = reg.predict(X_test)

from sklearn.metrics import mean_squared_error

rmse_tr = np.sqrt(mean_squared_error(y_train, pred_tr)) #참값, 예측값 순서로 넣는다
rmse_test = np.sqrt(mean_squared_error(y_test, pred_test))

print('RMSE train : ', rmse_tr)
print('RMSE test : ', rmse_test)

 

오잉 아까보다 에러가 커졌네요, 그래도 그래프로 그려봐야죠?

plt.scatter(x=y_test, y=pred_test)
plt.plot([0,50], [0,50], 'r')
plt.show()

네 저 컬럼을 빼도 되는군요.

728x90