これまで,実行のたびにモデルを構築,データを読み込み,学習を行っていた.いま使用しているデータは小さなデータセットであるので,学習は数秒で完了します.しかしながら,大きなデータセットを使った学習では数時間,数日あるいは数週間を要することも少なくない.このため,モデルと学習済みの重みパラメータを保存しておき,いつでも予測ができるような仕組みが必要です.
まずは,モデルと学習済みの重みパラメータを保存してみよう.
モデルと重みパラメータを保存する (12-save.py)
import numpy as np
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense
import json
# ファイルを開いて読み込む
x_train = np.load('train_X_data.npy')
y_train = np.load('train_Y_data.npy')
x_test = np.load('test_X_data.npy')
y_test = np.load('test_Y_data.npy')
# 正解ラベルを one-hot-encoding にする
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
# モデルを作る
model = Sequential()
model.add(Dense(128, activation='relu', input_dim=225)) # input_dim = 15 x 15 = 225
model.add(Dense(10, activation='softmax'))
# モデルをコンパイルする
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.summary()
# 学習する
model.fit(x_train, y_train,
batch_size=20,
epochs=30,
verbose=1)
# モデルの保存
json_string = model.to_json()
open('tegaki-model.json', 'w').write(json_string)
# 重みの保存
hdf5_file = "tegaki-predict.weights.h5"
model.save_weights(hdf5_file)
Using TensorFlow backend. Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 128) 28928 _________________________________________________________________ dense_2 (Dense) (None, 10) 1290 ================================================================= Total params: 30,218 Trainable params: 30,218 Non-trainable params: 0 _________________________________________________________________ Epoch 1/30 80/80 [==============================] - 0s 935us/step - loss: 2.3075 - accuracy: 0.1875 Epoch 2/30 80/80 [==============================] - 0s 62us/step - loss: 1.7684 - accuracy: 0.5125 ...(中略)... Epoch 30/30 80/80 [==============================] - 0s 212us/step - loss: 0.0315 - accuracy: 1.0000
これにより,プロジェクトのディレクトリに「tegaki-model.json」と「tegaki-predict.weights.h5」が作成されました.次は,これらのファイルを読み込んでみよう.