안녕하세요!
오늘은 딥러닝계의 Hello World와 같은 MNIST 데이터를 사용하여 신경망에 대해 공부해보겠습니다.
MNIST 데이터셋은 넘파이 배열 형태로 케라스에 저장이 되어 있습니다.
import tensorflow as tf
(train_X,train_Y),(test_X,test_Y) = tf.keras.datasets.mnist.load_data()
train_X와 train_Y로 훈련 세트를 구성하며, test_X와 test_Y로 테스트 세트를 구성합니다.
데이터를 살펴볼까요?
print(train_X.shape) # (60000,28,28)
print(len(train_X),len(test_X)) # (60000,10000)
plt.imshow(train_X[0],cmap='gray')
plt.colorbar()
print(train_Y[0])
먼저 shape로 데이터를 확인해보면 훈련 세트와 테스트 세트가 각각 60000개, 10000개의 데이터로 구성됩니다.
샘플을 확인해 보겠습니다
train_X = train_X / 255.0
test_X = test_X / 255.0
데이터를 0과 1사이의 값을 가지게 하기 위해 정규화를 해줍니다
신경망의 핵심 구성 요소는 Layer입니다.
주어진 문제에 더 의미있는 표현을 입력된 데이터로부터 추출합니다.
from keras.models import Sequential
from keras.layers import Dense,Flatten
from keras.optimizers import Adam
model = Sequential([
Flatten(input_shape=(28,28)),
Dense(units=128,activation='relu'),
Dense(units=10,activation='softmax')])
Sequential 모델은 작성한 레이어를 선형으로 연결해줍니다.
https://keras.io/ko/getting-started/sequential-model-guide/
각각의 Layer들을 만들어줘서 모델의 구조를 작성하는 과정입니다.
신경망이 훈련을 마치기 위해서 compile단계를 필요로 합니다.
여기에서는 3가지가 필요합니다.
- Loss Function : 훈련 데이터에서 신경망의 성능을 측정하는 방법으로 네트워크가 올바른 방향으로 학습될 수 있도록 도움
- Optimizer : 입력된 데이터와 손실 함수를 기반으로 네트워크를 업데이트하는 방법
- 모니터링 지표 : 정확도 등
model.compile(optimizer=Adam(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
오늘 글에서는 각 방법이 아닌 신경망의 구조만 다루도록 하겠습니다.
model.summary()
history = model.fit(train_X,train_Y,epochs=25)
summary는 모델을 간략하게 표현합니다.
fit함수를 사용하여 훈련 데이터에 모델을 학습시킵니다.
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(history.history['loss'],'b-',label='loss')
plt.xlabel('Epoch')
plt.legend()
plt.subplot(1,2,2)
plt.plot(history.history['accuracy'],'g-',label='accuracy')
plt.xlabel('Epoch')
plt.ylim(0.7,1)
plt.legend()
model.evaluate(test_X,test_Y)
손실 함수의 변화와 정확도의 변화를 기록하여 그래프로 그려보겠습니다.
epoch를 반복할 수록 손실이 줄어드는 것을 확인할 수 있습니다.
또한 정확도도 상승하는 것을 보실 수 있습니다.
하지만, 위의 훈련 과정에서는 정확도가 0.96이지만, 테스트 세트의 정확도는 0.95가 나왔습니다.
훈련 세트의 정확도보다는 약간 낮죠?
훈련 정확도와 테스트 정확도 사이의 차이는 과대적합 때문입니다! 이는 머신 러닝 모델이 훈련 데이터보다 새로운 데이터에서
성능이 낮아지는 것을 말합니다.
'AI > AI' 카테고리의 다른 글
[AI] CNN 구현 시 고려해야할 사항들 (1) | 2022.12.07 |
---|---|
[AI] 전이학습 | Cifar10 | MobileNetV2 (0) | 2022.11.23 |
[AI] Perceptron (0) | 2022.11.07 |
[AI] K-Nearest Neighbors 직접 구현 (0) | 2022.10.19 |
[OpenCV] 이미지 마스킹 기초 (2) | 2022.10.12 |