次は,テストデータの認識をしてみよう.モデルの評価では認識精度が75%でしたが(データの作成時に並び順をランダムにシャッフルしているので,結果は必ずしも一致しないことに注意してください),どのような数字の認識が成功して,どのような数字の認識に失敗しているのかを確認しよう.
テストデータを認識させる (11-predict-test.py)
import numpy as np
from keras.utils.np_utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense
def print_train_test(idx):
print('認識結果 :', np.argmax(pred_test[idx]))
print('正解ラベル:', np.argmax(y_test[idx]))
i = 1
for x in x_test[idx]:
if (x == 1):
print("+ ", end="")
else:
print(" ", end="")
if i % 15 == 0:
print("")
i += 1
# ファイルを開いて読み込む
x_train = np.load('train_X_data.npy')
y_train = np.load('train_Y_data.npy')
x_test = np.load('test_X_data.npy')
y_test = np.load('test_Y_data.npy')
# 正解ラベルを one-hot-encoding にする
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
# モデルを作る
model = Sequential()
model.add(Dense(128, activation='relu', input_dim=225)) # input_dim = 15 x 15 = 225
model.add(Dense(10, activation='softmax'))
# モデルをコンパイルする
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.summary()
# 学習してみよう(このコードだけで,学習状況も表示される)
model.fit(x_train, y_train,
batch_size=20,
epochs=30,
verbose=1)
# モデルを評価する(テストデータを使う)
score = model.evaluate(x_test, y_test)
print(score)
print(model.metrics_names)
print(model.metrics_names[0], " : ", score[0])
print(model.metrics_names[1], " : ", score[1])
# 予測してみよう
pred_test = model.predict(x_test)
while True:
print('---------------------')
print('予測結果を表示したいテストデータの番号(0 から', x_test.shape[0]-1 ,'まで)を入力してください(-1で終了します):', end="")
str_idx = input()
# 空の場合の処理
if str_idx == "":
print('入力してください')
continue
# 入力した文字列を整数に変換するが,変換できない場合のために例外処理が必要
try:
idx = int(str_idx)
except ValueError:
print('エラー:数字以外の文字は入力できません')
continue
# 終了判定
if idx == -1:
break
if idx < 0 or idx > x_test.shape[0]-1:
print('正しい値を入れてください')
continue
print_train_test(idx)
print('------ 終了しました ------')
Using TensorFlow backend. Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 128) 28928 _________________________________________________________________ dense_2 (Dense) (None, 10) 1290 ================================================================= Total params: 30,218 Trainable params: 30,218 Non-trainable params: 0 _________________________________________________________________ Epoch 1/30 80/80 [==============================] - 0s 885us/step - loss: 2.2974 - accuracy: 0.2375 Epoch 2/30 80/80 [==============================] - 0s 62us/step - loss: 1.7375 - accuracy: 0.5000 ...(中略)... Epoch 30/30 80/80 [==============================] - 0s 224us/step - loss: 0.0309 - accuracy: 1.0000 20/20 [==============================] - 0s 947us/step [0.7316001653671265, 0.75] ['loss', 'accuracy'] loss : 0.7316001653671265 accuracy : 0.75 --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):1 ⏎ 認識結果 : 4 ←←←←← 認識失敗 正解ラベル: 9 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):5 ⏎ 認識結果 : 6 正解ラベル: 6 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):9 ⏎ 認識結果 : 0 正解ラベル: 0 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):10 ⏎ 認識結果 : 6 正解ラベル: 6 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):11 ⏎ 認識結果 : 5 ←←←←← 認識失敗 正解ラベル: 6 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):12 ⏎ 認識結果 : 5 正解ラベル: 5 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):13 ⏎ 認識結果 : 0 正解ラベル: 0 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):14 ⏎ 認識結果 : 8 正解ラベル: 8 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):15 ⏎ 認識結果 : 0 正解ラベル: 0 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):16 ⏎ 認識結果 : 1 正解ラベル: 1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):17 ⏎ 認識結果 : 6 正解ラベル: 6 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):18 ⏎ 認識結果 : 7 正解ラベル: 7 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):19 ⏎ 認識結果 : 7 正解ラベル: 7 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + --------------------- 予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):-1 ⏎ ------ 終了しました ------
いくつかの数字の認識に失敗していることがわかりました.
ここでは,わずか80個のデータでモデルの学習を行いました.実際にはもっと多くのデータを使って学習する必要があるでしょう.データの数を増やすとともに,中間層のニューロン数(現在は128)や学習時のバッチサイズ(現在は20)を色々変更して試してみると良いでしょう.また,畳み込みやドロップアウトなどを追加したり,学習画像の水増しなども行ってみると良いでしょう.