프로그래밍/머신러닝

SVM을 사용한 MNIST 데이터 분류

RAVIN 2022. 12. 19. 17:08
  • 과제 목표: 서포트벡터머신 모델을 이용해 MNIST 데이터셋 분류 (0~9까지 멀티클래스 분류)
  • Google Colaboratory 사용

1. 패키지 가져오기

import numpy as np
from keras.datasets import mnist
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from matplotlib.colors import ListedColormap

from sklearn.metrics import plot_confusion_matrix
import matplotlib.pyplot as plt

 

2. MNIST 데이터셋 로드 및 데이터 분할/표준화

(train_images, train_labels), (test_images, test_labels) = mnist.load_data() #train:학습데이터, test: 테스트 데이터

class_index = 0 #클래스 인덱스
X = train_images.reshape(-128*28) / 255 
#reshape 함수를 통해 각 데이터 샘플을 행벡터로, 255로 나누어 줌으로서 픽셀값을 [0,1] 단위로 변환한다.
#-1 은 2차원 데이터(데이터샘플)의 개수를 모르기 때문에 넣어준 것.
#하나의 이미지가 28 by 28 행렬로 이루어져 있기 때문에 28*28 열을 가지게 함. 하나의 행벡터 크기가 1 by (28*28)

X_set = X[train_labels == class_index] # class_index에 해당하는 이미지 부분 가져오기
Y_set = (class_index) * np.ones(X_set.shape[0]) #class_index 클래스의 정답을 class_index로 표기


for i in range(1,10): #클래스 1 ~ 9 도 X_set, Y_set에 추가해준다.
  X_set = np.append(X_set, X[train_labels == i], axis=0)           # 클래스2 숫자를 추가
  Y_set = np.append(Y_set, i * np.ones(X_set.shape[0] - Y_set.shape[0]))  # 클래스2 숫자 정답은 +1로 표기


Y_set = Y_set.astype('int64'#정답값(레이블) 형변환

#학습데이터-테스트데이터 쪼개는 함수
X_train, X_test, y_train, y_test = train_test_split(
    X_set, Y_set, test_size=0.3, random_state=1, stratify=Y_set)
# test_size : 전체 샘플에서 지정한 비율 만큼을 테스트 셋으로 분할
# stratify  : 각 분할 셋의 클래스 분포 비가 지정한 셋과 동일하도록 함(옵션값. 이거 안 넣으면 그냥 무작위 비율로 쪼갠다)


print('y의 레이블 카운트 : ', np.bincount(Y_set)) #bincount: 어떤 종류의 값이 있는지 각각 세준다.
print('y_train의 레이블 카운트 : ', np.bincount(y_train)) 
print('y_test의 레이블 카운트 : ', np.bincount(y_test))

 

3. SVM 객체 생성 및 학습

svm = SVC(kernel='rbf', gamma = 0.3, C=1.0, random_state=1, max_iter = 500#객체 생성. 커널을 선형으로 설정 즉 결정경계가 평면으로 이루어지는 SVM
 
svm.fit(X_train_std, y_train) #SVM 학습함수

 

4. Confusion Matrix 그리기

# import some data to play with
class_names = ['0''1''2''3''4''5''6''7''8''9']

fig, ax = plt.subplots(figsize=(1010))
plot_confusion_matrix(svm, X_test_std, y_test,
                      display_labels=class_names,
                      cmap=plt.cm.Blues, ax=ax, normalize='true')     # true로 하면 확률로 보여줌.normalize='true'
                                            # 안하면 개수로 보여줌

plt.show()

 

 


SVM을 이용한 MNIST 데이터 분류.ipynb
0.07MB