Implementing ResNet
Overview
- ResNet is a deep-learning training technique. Stacking more layers lets a network capture more features, but the vanishing gradient problem becomes severe as depth grows.
- ResNet introduces residual learning to address this problem.
- The figure below (from the original paper) shows the resulting degradation: as the number of layers increases, accuracy across training iterations actually gets worse, with a 56-layer plain network underperforming a 20-layer one.

How ResNet Works
- Passing the input x through hidden layers to obtain F(x) works the same as in a plain DNN, but the block output is then formed as H(x) = F(x) + x, the residual learning block; the stacked layers therefore learn the residual F(x) = H(x) - x instead of fitting H(x) directly.

- F(x) + x is realized in the feedforward network through shortcut connections. A shortcut connection skips one or more layers; in the paper it is used to perform identity mapping.
- Shortcut connections add neither extra parameters nor extra computational complexity, as the sketch below shows.
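To make the idea concrete, here is a minimal sketch of a basic identity residual block in Keras. It illustrates the technique only; the filter count and layer choices are assumptions for the example, not code from the paper.

import tensorflow as tf
from tensorflow.keras import layers

def residual_block(x, filters=64):
    # F(x): two conv layers (filters must match x's channel count for the identity add)
    shortcut = x                                     # identity shortcut: carries x unchanged
    y = layers.Conv2D(filters, 3, padding='same')(x)
    y = layers.BatchNormalization()(y)
    y = layers.ReLU()(y)
    y = layers.Conv2D(filters, 3, padding='same')(y)
    y = layers.BatchNormalization()(y)
    y = layers.Add()([y, shortcut])                  # H(x) = F(x) + x: no extra parameters
    return layers.ReLU()(y)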
import numpy as np
import pandas as pd
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.utils import plot_model
from tensorflow.keras import regularizers, optimizers
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
(x_train, y_train), (x_test, y_test) = cifar100.load_data(
    label_mode='fine'  # 100 fine-grained classes (vs. 20 coarse superclasses)
)
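For reference, load_data returns 50,000 training and 10,000 test images of shape 32x32x3, with integer fine labels of shape (N, 1); a quick check:

print(x_train.shape, y_train.shape)   # (50000, 32, 32, 3) (50000, 1)
print(x_test.shape, y_test.shape)     # (10000, 32, 32, 3) (10000, 1)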
class_names = [
'apple', # id 0
'aquarium_fish',
'baby',
'bear',
'beaver',
'bed',
'bee',
'beetle',
'bicycle',
'bottle',
'bowl',
'boy',
'bridge',
'bus',
'butterfly',
'camel',
'can',
'castle',
'caterpillar',
'cattle',
'chair',
'chimpanzee',
'clock',
'cloud',
'cockroach',
'couch',
'crab',
'crocodile',
'cup',
'dinosaur',
'dolphin',
'elephant',
'flatfish',
'forest',
'fox',
'girl',
'hamster',
'house',
'kangaroo',
'computer_keyboard',
'lamp',
'lawn_mower',
'leopard',
'lion',
'lizard',
'lobster',
'man',
'maple_tree',
'motorcycle',
'mountain',
'mouse',
'mushroom',
'oak_tree',
'orange',
'orchid',
'otter',
'palm_tree',
'pear',
'pickup_truck',
'pine_tree',
'plain',
'plate',
'poppy',
'porcupine',
'possum',
'rabbit',
'raccoon',
'ray',
'road',
'rocket',
'rose',
'sea',
'seal',
'shark',
'shrew',
'skunk',
'skyscraper',
'snail',
'snake',
'spider',
'squirrel',
'streetcar',
'sunflower',
'sweet_pepper',
'table',
'tank',
'telephone',
'television',
'tiger',
'tractor',
'train',
'trout',
'tulip',
'turtle',
'wardrobe',
'whale',
'willow_tree',
'wolf',
'woman',
'worm',
]
plt.imshow(x_train[0])   # first training image (a 32x32 CIFAR-100 sample)

For ResNet50, preprocess_input converts input images from RGB to BGR and zero-centers each channel with the ImageNet channel means (Keras's 'caffe' mode); no scaling to [0, 1] is applied.
res_x_train = preprocess_input(x_train, data_format=None)
res_x_test = preprocess_input(x_test, data_format=None)
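As a sanity check, the 'caffe'-mode preprocessing applied above should be equivalent to reversing the channel order and subtracting Keras's ImageNet channel means; a hand-rolled version for comparison:

manual = x_train[..., ::-1].astype('float32')      # RGB -> BGR
manual -= np.array([103.939, 116.779, 123.68])     # subtract ImageNet means (BGR order)
# np.allclose(manual, res_x_train) should then hold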
y_train = to_categorical(y_train)   # one-hot encode the 100 fine labels
y_test = to_categorical(y_test)
import tensorflow as tf
IMG_SIZE = 224                                   # ResNet50's native ImageNet input size
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
input_tensor = tf.keras.Input(shape=(32, 32, 3))
resized_image = tf.keras.layers.Lambda(
    lambda image: tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
)(input_tensor)                                  # upsample 32x32 inputs to 224x224 inside the graph
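Resizing inside the graph keeps the stored arrays at 32x32 and upsamples on the fly, trading per-step compute for memory. As an aside, on newer TF versions (2.6+) a dedicated layer can replace the Lambda; this is an alternative sketch, not what was run here:

resized_image = tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE)(input_tensor)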
model = ResNet50(
    include_top=True,          # keep the fully-connected classification head
    weights=None,              # no pretrained weights: the network trains from scratch
    input_tensor=resized_image,
    input_shape=None,
    pooling=None,
    classes=100                # CIFAR-100 fine labels
)
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(learning_rate=0.0001),
              metrics=["accuracy", "top_k_categorical_accuracy"])   # top_k defaults to k=5
#model.summary()
plot_model(model, show_shapes=True)

from tensorflow.keras.callbacks import ModelCheckpoint
Res_mc = ModelCheckpoint('D:/Projects/딥러닝 체크포인트/ResNet for Blog',
                         monitor='val_accuracy',
                         mode='max',
                         verbose=1,
                         save_best_only=True)      # keep only the best model by validation accuracy
Res_history = model.fit(res_x_train,
                        y_train,
                        validation_split=0.3,      # hold out 30% of the training data for validation
                        epochs=10,
                        callbacks=[Res_mc])
Epoch 1/10
1094/1094 [==============================] - ETA: 0s - loss: 3.2561 - accuracy: 0.2103 - top_k_categorical_accuracy: 0.4924
Epoch 1: val_accuracy improved from -inf to 0.19740, saving model to D:/Projects/딥러닝 체크포인트\ResNet for Blog
INFO:tensorflow:Assets written to: D:/Projects/딥러닝 체크포인트\ResNet for Blog\assets
1094/1094 [==============================] - 201s 184ms/step - loss: 3.2561 - accuracy: 0.2103 - top_k_categorical_accuracy: 0.4924 - val_loss: 3.4077 - val_accuracy: 0.1974 - val_top_k_categorical_accuracy: 0.4837
Epoch 2/10
1094/1094 [==============================] - ETA: 0s - loss: 2.7923 - accuracy: 0.2966 - top_k_categorical_accuracy: 0.6038
Epoch 2: val_accuracy improved from 0.19740 to 0.26633, saving model to D:/Projects/딥러닝 체크포인트\ResNet for Blog
INFO:tensorflow:Assets written to: D:/Projects/딥러닝 체크포인트\ResNet for Blog\assets
1094/1094 [==============================] - 200s 183ms/step - loss: 2.7923 - accuracy: 0.2966 - top_k_categorical_accuracy: 0.6038 - val_loss: 3.0937 - val_accuracy: 0.2663 - val_top_k_categorical_accuracy: 0.5635
Epoch 3/10
1094/1094 [==============================] - ETA: 0s - loss: 2.3841 - accuracy: 0.3822 - top_k_categorical_accuracy: 0.6955
Epoch 3: val_accuracy improved from 0.26633 to 0.32720, saving model to D:/Projects/딥러닝 체크포인트\ResNet for Blog
INFO:tensorflow:Assets written to: D:/Projects/딥러닝 체크포인트\ResNet for Blog\assets
1094/1094 [==============================] - 202s 184ms/step - loss: 2.3841 - accuracy: 0.3822 - top_k_categorical_accuracy: 0.6955 - val_loss: 2.7444 - val_accuracy: 0.3272 - val_top_k_categorical_accuracy: 0.6377
Epoch 4/10
1094/1094 [==============================] - ETA: 0s - loss: 2.0407 - accuracy: 0.4564 - top_k_categorical_accuracy: 0.7679
Epoch 4: val_accuracy did not improve from 0.32720
1094/1094 [==============================] - 186s 170ms/step - loss: 2.0407 - accuracy: 0.4564 - top_k_categorical_accuracy: 0.7679 - val_loss: 3.0187 - val_accuracy: 0.3031 - val_top_k_categorical_accuracy: 0.6131
Epoch 5/10
1094/1094 [==============================] - ETA: 0s - loss: 1.7143 - accuracy: 0.5301 - top_k_categorical_accuracy: 0.8283
Epoch 5: val_accuracy improved from 0.32720 to 0.36320, saving model to D:/Projects/딥러닝 체크포인트\ResNet for Blog
INFO:tensorflow:Assets written to: D:/Projects/딥러닝 체크포인트\ResNet for Blog\assets
1094/1094 [==============================] - 201s 184ms/step - loss: 1.7143 - accuracy: 0.5301 - top_k_categorical_accuracy: 0.8283 - val_loss: 2.6407 - val_accuracy: 0.3632 - val_top_k_categorical_accuracy: 0.6792
Epoch 6/10
1094/1094 [==============================] - ETA: 0s - loss: 1.3972 - accuracy: 0.6087 - top_k_categorical_accuracy: 0.8786
Epoch 6: val_accuracy improved from 0.36320 to 0.36540, saving model to D:/Projects/딥러닝 체크포인트\ResNet for Blog
INFO:tensorflow:Assets written to: D:/Projects/딥러닝 체크포인트\ResNet for Blog\assets
1094/1094 [==============================] - 202s 184ms/step - loss: 1.3972 - accuracy: 0.6087 - top_k_categorical_accuracy: 0.8786 - val_loss: 2.7990 - val_accuracy: 0.3654 - val_top_k_categorical_accuracy: 0.6744
Epoch 7/10
1094/1094 [==============================] - ETA: 0s - loss: 1.0918 - accuracy: 0.6833 - top_k_categorical_accuracy: 0.9255
Epoch 7: val_accuracy improved from 0.36540 to 0.37107, saving model to D:/Projects/딥러닝 체크포인트\ResNet for Blog
INFO:tensorflow:Assets written to: D:/Projects/딥러닝 체크포인트\ResNet for Blog\assets
1094/1094 [==============================] - 196s 179ms/step - loss: 1.0918 - accuracy: 0.6833 - top_k_categorical_accuracy: 0.9255 - val_loss: 2.9228 - val_accuracy: 0.3711 - val_top_k_categorical_accuracy: 0.6776
Epoch 8/10
1094/1094 [==============================] - ETA: 0s - loss: 0.7851 - accuracy: 0.7667 - top_k_categorical_accuracy: 0.9631
Epoch 8: val_accuracy did not improve from 0.37107
1094/1094 [==============================] - 186s 170ms/step - loss: 0.7851 - accuracy: 0.7667 - top_k_categorical_accuracy: 0.9631 - val_loss: 3.2557 - val_accuracy: 0.3613 - val_top_k_categorical_accuracy: 0.6571
Epoch 9/10
1094/1094 [==============================] - ETA: 0s - loss: 0.5503 - accuracy: 0.8368 - top_k_categorical_accuracy: 0.9821
Epoch 9: val_accuracy improved from 0.37107 to 0.39353, saving model to D:/Projects/딥러닝 체크포인트\ResNet for Blog
INFO:tensorflow:Assets written to: D:/Projects/딥러닝 체크포인트\ResNet for Blog\assets
1094/1094 [==============================] - 202s 185ms/step - loss: 0.5503 - accuracy: 0.8368 - top_k_categorical_accuracy: 0.9821 - val_loss: 3.0606 - val_accuracy: 0.3935 - val_top_k_categorical_accuracy: 0.6821
Epoch 10/10
1094/1094 [==============================] - ETA: 0s - loss: 0.3801 - accuracy: 0.8893 - top_k_categorical_accuracy: 0.9925
Epoch 10: val_accuracy did not improve from 0.39353
1094/1094 [==============================] - 186s 170ms/step - loss: 0.3801 - accuracy: 0.8893 - top_k_categorical_accuracy: 0.9925 - val_loss: 3.3768 - val_accuracy: 0.3781 - val_top_k_categorical_accuracy: 0.6801
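Note that the final epoch did not improve on the epoch-9 checkpoint (val_accuracy 0.39353), so the evaluation below scores the last weights, not the saved best. To evaluate the checkpointed model instead, it could be reloaded first, sketched as:

best_model = tf.keras.models.load_model('D:/Projects/딥러닝 체크포인트/ResNet for Blog')
# best_model.evaluate(res_x_test, y_test, verbose=2)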
loss, acc, k_acc = model.evaluate(res_x_test, y_test, verbose=2)
print('Model top-1 accuracy: {:5.2f}%'.format(100*acc))
print('Model top-5 accuracy: {:5.2f}%'.format(100*k_acc))
313/313 - 15s - loss: 3.3377 - accuracy: 0.3832 - top_k_categorical_accuracy: 0.6860 - 15s/epoch - 47ms/step
Model top-1 accuracy: 38.32%
Model top-5 accuracy: 68.60%
pd.DataFrame(Res_history.history).plot()   # loss/accuracy curves per epoch

Because training ran for only a few epochs and no transfer learning was actually applied to this image set (with weights=None the network learns from scratch), val_accuracy and val_loss stop improving while the training metrics keep climbing, i.e. the model overfits. Some tuning and more epochs should help. Since reproducing the architecture was the goal, I will stop here. (If I find time later, I'll revisit this for better results; one possible direction is sketched below.)
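As a pointer for that follow-up, actual transfer learning would start from ImageNet weights and swap the classification head. The sketch below is one common setup under that assumption, not the configuration run above:

base = ResNet50(include_top=False, weights='imagenet',
                input_tensor=resized_image, pooling='avg')
base.trainable = False                              # freeze the pretrained backbone first
outputs = tf.keras.layers.Dense(100, activation='softmax')(base.output)
tl_model = tf.keras.Model(input_tensor, outputs)
tl_model.compile(loss='categorical_crossentropy',
                 optimizer=optimizers.Adam(learning_rate=0.0001),
                 metrics=['accuracy'])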