www.youtube.com/watch?v=6omvN1nuZMc&list=PLJN246lAkhQjX3LOdLVnfdFaCbGouEBeb&index=13
박해선 교수님의 유튜브 강의로 공부했음을 밝힙니다.
선형 회귀 모델을 훈련시키는 두 가지 방법
- 정규방정식
- 경사하강법 GD(배치, 미니배치, 확률적(stochastic))
정규방정식
<실제 모듈 사용>
sklearn.linear_model의 LinearRegression에 해당함.
from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(X,y)
lin_reg.intercept_, lin_reg.coef_ #bias와 weight
#(array([4.41766218]), array([[2.74273432]]))
lin_reg.predict(X_new)
#array([[4.41766218],
# [9.90313081]])
<정규방정식>
m : 샘플의 개수
n : 샘플하나의 feature 개수
$x_0$ : 더미 feature임. 행렬곱셈 시 편리함을 위함. 1값을 갖는다.
$\theta$ : 가중치벡터
$\hat{\theta }$ : 비용함수를 최소화하는 가중치벡터
RMSE(root mean square error) : 평균 제곱근 오차. 회귀에서 가장 널리 사용되는 성능 측정 지표.
MSE(mean square error) : 평균 제곱 오차. RMSE가 최소일때는 곧 MSE가 최소일 때이다. MSE가 최소일 때 $\hat{\theta }$를 구하는 것이 더 쉬우므로 사용.
RMSE가 최소가 되는 지점은 MSE의 미분식이 0일 때 이다. 시그마를 없애고 행렬곱으로 표현하여 미분하고 정리하면 정규방정식을 끌어낼 수 있다.
정규방정식이란 다음을 말한다.
$$\hat{\theta } = (X^{T}X)^{-1}X^{T}y$$
<정규방정식으로 회귀 구현>
데이터 X와 라벨 y가 다음과 같이 주어질 때, 정규방정식을 통해서 $\theta_0, \theta_1$을 구해보자
주어진 데이터 X에 더미feature를 추가한 X행렬로 $\hat{\theta}$을 찾는다.
X_b = np.c_[np.ones((100,1)),X] #모든 샘플에 x0=1 추가
# (2,100)(100,2) -> (2,2)
# (2,2)(2,100) -> (2,100)
# (2,100)(100,1) -> (2,1)
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
#array([[4.41766218],
# [2.74273432]])
++
x=0,2일때 $\hat{\theta}$를 사용한 결과값 y_pred 두 점을 이어서 시각화.
++
LinearRegression은 실제역행렬을 구하지 않고 scipy.linalg.lstsq()함수를 기반으로 한다. 역행렬이 없는 경우는 구할 수 없기 때문.
특잇값 분해(SVD)라 부르는 표준 행렬 분해 기법을 사용해서 유사역행렬을 구한다. 이 때 정규방정식은 다음과 같다.
$$\hat{\theta} = X^{+}y$$
residual은 $(\hat{y}-y)^{2}$을 의미한다. rank는 차원을 의미한다.
theta_best_svd, residuals, rank, s = np.linalg.lstsq(X_b, y , rcond=1e-6)
theta_best_svd
#array([[4.41766218],
# [2.74273432]])
np.linalg.pinv(X_b).dot(y) #유사 역행렬
#array([[4.41766218],
# [2.74273432]])
<시간복잡도>
n은 feature수라는 것을 잊지 말자.
역행렬을 사용하는 방법은 행렬곱의 시간복잡도 $O(n^{2.4})$ 만큼 시간이 걸린다.
SVD방법은 $O(n^{2})$ 시간이 걸린다. SVD방법이 훨씬 빠르다.
정규방정식으로 학습된 선형회귀 모델은 샘플 수에 큰 영향을 받지 않는다. (샘플 수의 선형시간)
'ML&DATA > 핸즈온 머신러닝' 카테고리의 다른 글
4 - 다항회귀, 규제 (2) | 2020.09.09 |
---|---|
4 - 선형회귀 (경사하강법) (0) | 2020.09.04 |
3 - 다중 레이블 분류, 다중 출력 분류 (2) | 2020.08.27 |
3 - 에러분석 (0) | 2020.08.27 |
3 - 다중분류 (0) | 2020.08.27 |