본문 바로가기
ML&DATA/핸즈온 머신러닝

3 - 에러분석

by sun__ 2020. 8. 27.

3.5 에러분석

적절한 모델을 찾았다면 이 모델의 성능을 향상시킬 방법을 찾아야 한다. 에러를 분석하여 통찰을 얻을 수 있다.

 

 

<오차행렬>

cross_val_predict()함수로 모든 클래스에 대한 예측값을 만들어서 오차 행렬을 만들어 본다.

(이진 분류기에서와 마찬가지로 행이 실제 값이고 열이 예측값이다.)

(i행 j열 즉 $C_i,j$는 실제 클래스가 i인 것을 j로 예측한 것의 개수를 의미한다.)

y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
#array([[5576,    0,   21,    6,    9,   43,   37,    6,  224,    1],
#       [   0, 6398,   38,   23,    4,   44,    4,    8,  213,   10],
#       [  26,   27, 5242,   90,   71,   26,   62,   36,  371,    7],
#       [  24,   17,  117, 5220,    2,  208,   28,   40,  405,   70],
#       [  12,   14,   48,   10, 5192,   10,   36,   26,  330,  164],
#       [  28,   15,   33,  166,   55, 4437,   76,   14,  538,   59],
#       [  30,   14,   41,    2,   43,   95, 5560,    4,  128,    1],
#       [  21,    9,   52,   27,   51,   12,    3, 5693,  188,  209],
#       [  17,   63,   46,   90,    3,  125,   25,   10, 5429,   43],
#       [  23,   18,   31,   66,  116,   32,    1,  179,  377, 5106]],
#      dtype=int64)

 

오차행렬을 맷플롯립의 matshow함수로 표현하면 보기에 편하다.

90퍼센트에 달한느 분류기답게 대체로 잘 예측하곤 있다. (클래스 i를 i라고 분류한 것의 개수가 월등히 많다)

클래스 5를 5라고 예측한 것의 개수를 의미하는$C_5,5$가 비교적 어두운 것을 확인할 수 있다.

 

mnist 셋은 클래스마다 같은 개수의 셋이 있는것이 아니라 불균형하므로 5가 다른 클래스보다 개수가 많다면 상대적으로 나쁘게 보일 수 있다. 따라서 에러 비율을 비교해야 한다. 

5,5의 색이 옅어졌다.

 

제대로 예측한 것을 제외한 에러를 뚜렷하게 보기위해 대각행렬을 0으로 바꿔서 시각화 해본다.

 

숫자들을 8로 잘못 예측한 경우가 굉장히 많는 것을 알 수 있다. 특히 6를 8로 잘못 예측한 것이 가장 많다. 

(동심원의 수와 같은 특성을 뽑아낼 수도 있지만, 범위에서 벗어난다.)

 

또한 5를 3으로, 3을 5로 잘못 예측한 것이 많다. 3과 5의 샘플을 그려본다.

def plot_digits(instances, images_per_row):
    images_per_row = min(len(instances), images_per_row)
    images = [instance.reshape(28,28) for instance in instances]
    n_rows = (len(instances) - 1) // images_per_row + 1
    row_images = []
    n_empty = n_rows * images_per_row - len(instances)
    images.append(np.zeros((28, 28 * n_empty)))
    for row in range(n_rows):
        rimages = images[row * images_per_row : (row + 1) * images_per_row]
        row_images.append(np.concatenate(rimages, axis=1))
    image = np.concatenate(row_images, axis=0)
    plt.imshow(image, cmap = mpl.cm.binary)
    plt.axis("off")

cl_a, cl_b = 3,5
#X_ij : i를 j로 예상한 것
X_aa = X_train[(y_train==cl_a) & (y_train_pred==cl_a)]
X_ab = X_train[(y_train==cl_a) & (y_train_pred==cl_b)]
X_ba = X_train[(y_train==cl_b) & (y_train_pred==cl_a)]
X_bb = X_train[(y_train==cl_b) & (y_train_pred==cl_b)]

plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()

SGDClassifier은 선형 분류기이기 때문에 픽셀에 가중치를 할당하고 새로운 이미지에 대해 단순히 픽셀 강도의 가중치 합을 클래스의 점수로 계산한다. 따라서 3과 5는 몇 개의 픽셀만 다르기 때문에 모델이 쉽게 혼동하는 것이다.

 

 이미지를 중앙에 위치시키고 회전되지 않도록 전처히한다면 에러를 줄일 수 있다고 한다.

'ML&DATA > 핸즈온 머신러닝' 카테고리의 다른 글

4 - 선형 회귀 (정규방정식)  (0) 2020.09.04
3 - 다중 레이블 분류, 다중 출력 분류  (2) 2020.08.27
3 - 다중분류  (0) 2020.08.27
3 - 분류기의 성능 측정  (0) 2020.08.25
개요  (2) 2020.07.16