3.5 에러분석
적절한 모델을 찾았다면 이 모델의 성능을 향상시킬 방법을 찾아야 한다. 에러를 분석하여 통찰을 얻을 수 있다.
<오차행렬>
cross_val_predict()함수로 모든 클래스에 대한 예측값을 만들어서 오차 행렬을 만들어 본다.
(이진 분류기에서와 마찬가지로 행이 실제 값이고 열이 예측값이다.)
(i행 j열 즉 $C_i,j$는 실제 클래스가 i인 것을 j로 예측한 것의 개수를 의미한다.)
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
#array([[5576, 0, 21, 6, 9, 43, 37, 6, 224, 1],
# [ 0, 6398, 38, 23, 4, 44, 4, 8, 213, 10],
# [ 26, 27, 5242, 90, 71, 26, 62, 36, 371, 7],
# [ 24, 17, 117, 5220, 2, 208, 28, 40, 405, 70],
# [ 12, 14, 48, 10, 5192, 10, 36, 26, 330, 164],
# [ 28, 15, 33, 166, 55, 4437, 76, 14, 538, 59],
# [ 30, 14, 41, 2, 43, 95, 5560, 4, 128, 1],
# [ 21, 9, 52, 27, 51, 12, 3, 5693, 188, 209],
# [ 17, 63, 46, 90, 3, 125, 25, 10, 5429, 43],
# [ 23, 18, 31, 66, 116, 32, 1, 179, 377, 5106]],
# dtype=int64)
오차행렬을 맷플롯립의 matshow함수로 표현하면 보기에 편하다.
90퍼센트에 달한느 분류기답게 대체로 잘 예측하곤 있다. (클래스 i를 i라고 분류한 것의 개수가 월등히 많다)
클래스 5를 5라고 예측한 것의 개수를 의미하는$C_5,5$가 비교적 어두운 것을 확인할 수 있다.
mnist 셋은 클래스마다 같은 개수의 셋이 있는것이 아니라 불균형하므로 5가 다른 클래스보다 개수가 많다면 상대적으로 나쁘게 보일 수 있다. 따라서 에러 비율을 비교해야 한다.
제대로 예측한 것을 제외한 에러를 뚜렷하게 보기위해 대각행렬을 0으로 바꿔서 시각화 해본다.
숫자들을 8로 잘못 예측한 경우가 굉장히 많는 것을 알 수 있다. 특히 6를 8로 잘못 예측한 것이 가장 많다.
(동심원의 수와 같은 특성을 뽑아낼 수도 있지만, 범위에서 벗어난다.)
또한 5를 3으로, 3을 5로 잘못 예측한 것이 많다. 3과 5의 샘플을 그려본다.
def plot_digits(instances, images_per_row):
images_per_row = min(len(instances), images_per_row)
images = [instance.reshape(28,28) for instance in instances]
n_rows = (len(instances) - 1) // images_per_row + 1
row_images = []
n_empty = n_rows * images_per_row - len(instances)
images.append(np.zeros((28, 28 * n_empty)))
for row in range(n_rows):
rimages = images[row * images_per_row : (row + 1) * images_per_row]
row_images.append(np.concatenate(rimages, axis=1))
image = np.concatenate(row_images, axis=0)
plt.imshow(image, cmap = mpl.cm.binary)
plt.axis("off")
cl_a, cl_b = 3,5
#X_ij : i를 j로 예상한 것
X_aa = X_train[(y_train==cl_a) & (y_train_pred==cl_a)]
X_ab = X_train[(y_train==cl_a) & (y_train_pred==cl_b)]
X_ba = X_train[(y_train==cl_b) & (y_train_pred==cl_a)]
X_bb = X_train[(y_train==cl_b) & (y_train_pred==cl_b)]
plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()
SGDClassifier은 선형 분류기이기 때문에 픽셀에 가중치를 할당하고 새로운 이미지에 대해 단순히 픽셀 강도의 가중치 합을 클래스의 점수로 계산한다. 따라서 3과 5는 몇 개의 픽셀만 다르기 때문에 모델이 쉽게 혼동하는 것이다.
이미지를 중앙에 위치시키고 회전되지 않도록 전처히한다면 에러를 줄일 수 있다고 한다.
'ML&DATA > 핸즈온 머신러닝' 카테고리의 다른 글
4 - 선형 회귀 (정규방정식) (0) | 2020.09.04 |
---|---|
3 - 다중 레이블 분류, 다중 출력 분류 (2) | 2020.08.27 |
3 - 다중분류 (0) | 2020.08.27 |
3 - 분류기의 성능 측정 (0) | 2020.08.25 |
개요 (2) | 2020.07.16 |