import numpy as np
from sklearn.datasets import fetch_openml
# mnist 784 dataset을 불러오는 코드
mnist = fetch_openml('mnist_784',version=1,cache=True)
X, y = mnist['data'], mnist['target'] #가져 올 때 데이터 형식이 달라지므로 지정
y = y.astype(np.int8)
X, y = mnist['data'], mnist['target']
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
sklearn.datasets mnist 데이터를 불러와서 SGD분류모델에 fitting을 시켰더니 아래와 같이 에러가 나왔다.
이 에러는 현재 1의 class로만 분류되어 나타나는 오류로, 2개 이상의 class로 class를 재분류하라는 뜻이다. 이 오류의 원인을 찾아보자.
먼저 model fit 인스턴스에 class가 하나만 들어가있다. 그렇기 때문에 위에서 data를 분류할 때 문제가 있었다는 것이다. 불러와진 mnist 데이터를 살펴보니 데이터가 숫자가 아닌 문자로 되어있고, 뒤에 데이터 셔플할 때 다시 문자로 변환되었기 때문에 숫자로 바꿔줘야 한다. 그래서 아래와 같이 숫자로 변환하는 코드를 삽입해줘야 한다.
해결방법
import numpy as np
from sklearn.datasets import fetch_openml
# mnist 784 dataset을 불러오는 코드
mnist = fetch_openml('mnist_784',version=1,cache=True)
X, y = mnist['data'], mnist['target'] #가져 올 때 데이터 형식이 달라지므로 지정
y = y.astype(np.int8)
X, y = mnist['data'], mnist['target']
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
y_train = y_train.astype(np.int8) # 해결방법 : 추가코드
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
y_train 부분에 y_train.astype(np.int8) 코드를 삽입하여 문자 값을 숫자 값으로 바꿔준다.
728x90
반응형
'기타 정보 > 오류 코드 해결 모음' 카테고리의 다른 글
matplotlib(plt)에서두개의 선 그래프를 겹쳐 그리기 & plt.legend( ) (0) | 2021.02.20 |
---|---|
jvmnotfoundexception(konlpy error) - no jvm shared library file (jvm.dll) found (0) | 2021.02.09 |
Jupyter Notebook 오류 - 이상한 경로 (0) | 2020.08.11 |
아나콘다 프롬프트 InvalidArchiveError (0) | 2020.08.09 |
파이썬 에러 : UnicodeDecodeError (5) | 2020.07.16 |
최근댓글