Python入門トップページ


手書き数字を認識するAIを作ってみよう : 目次

  1. 画像データの準備と確認
  2. 画像データを読み込んでみよう
  3. 画像データの一覧を読み込んでみよう
  4. 学習データとテストデータを準備する
  5. 保存したデータを開いてみる
  6. モデルを作る
  7. 学習(トレーニング)させてみよう
  8. モデルを評価しよう
  9. 学習データで認識させてみよう(1)
  10. 学習データで認識させてみよう(2)
  11. 学習データで認識させてみよう(3)
  12. テストデータで認識させてみよう
  13. モデルと重みパラメータを保存しよう
  14. 学習済みモデルをロードしよう
  15. 学習済みモデルをロードして,認識してみよう

テストデータで認識させてみよう

次は,テストデータの認識をしてみよう.モデルの評価では認識精度が75%でしたが(データの作成時に並び順をランダムにシャッフルしているので,結果は必ずしも一致しないことに注意してください),どのような数字の認識が成功して,どのような数字の認識に失敗しているのかを確認しよう.

テストデータを認識させる (11-predict-test.py)
import numpy as np
from keras.utils.np_utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense

def print_train_test(idx):
    print('認識結果  :', np.argmax(pred_test[idx]))
    print('正解ラベル:', np.argmax(y_test[idx]))
    i = 1
    for x in x_test[idx]:
        if (x == 1):
            print("+ ", end="")
        else:
            print("  ", end="")
        if i % 15 == 0:
            print("")
        i += 1

# ファイルを開いて読み込む
x_train = np.load('train_X_data.npy')
y_train = np.load('train_Y_data.npy')
x_test = np.load('test_X_data.npy')
y_test = np.load('test_Y_data.npy')

# 正解ラベルを one-hot-encoding にする
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# モデルを作る
model = Sequential()
model.add(Dense(128, activation='relu', input_dim=225))  # input_dim = 15 x 15 = 225
model.add(Dense(10, activation='softmax'))

# モデルをコンパイルする
model.compile(optimizer='rmsprop',
          loss='categorical_crossentropy',
          metrics=['accuracy'])

model.summary()

# 学習してみよう(このコードだけで,学習状況も表示される)
model.fit(x_train, y_train,
        batch_size=20,
        epochs=30,
        verbose=1)

# モデルを評価する(テストデータを使う)
score = model.evaluate(x_test, y_test)
print(score)
print(model.metrics_names)
print(model.metrics_names[0], " : ", score[0])
print(model.metrics_names[1], " : ", score[1])

# 予測してみよう
pred_test = model.predict(x_test)


while True:
    print('---------------------')
    print('予測結果を表示したいテストデータの番号(0 から', x_test.shape[0]-1 ,'まで)を入力してください(-1で終了します):', end="")
    str_idx = input()

    # 空の場合の処理
    if str_idx == "":
        print('入力してください')
        continue
    # 入力した文字列を整数に変換するが,変換できない場合のために例外処理が必要
    try:
        idx = int(str_idx)
    except ValueError:
        print('エラー:数字以外の文字は入力できません')
        continue
    # 終了判定
    if idx == -1:
        break
    if idx < 0 or idx > x_test.shape[0]-1:
        print('正しい値を入れてください')
        continue
    print_train_test(idx)

print('------ 終了しました ------')
Using TensorFlow backend.
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_1 (Dense)              (None, 128)               28928
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1290
=================================================================
Total params: 30,218
Trainable params: 30,218
Non-trainable params: 0
_________________________________________________________________
Epoch 1/30
80/80 [==============================] - 0s 885us/step - loss: 2.2974 - accuracy: 0.2375
Epoch 2/30
80/80 [==============================] - 0s 62us/step - loss: 1.7375 - accuracy: 0.5000

 ...(中略)...

Epoch 30/30
80/80 [==============================] - 0s 224us/step - loss: 0.0309 - accuracy: 1.0000
20/20 [==============================] - 0s 947us/step
[0.7316001653671265, 0.75]
['loss', 'accuracy']
loss  :  0.7316001653671265
accuracy  :  0.75
---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):1 ⏎
認識結果  : 4      ←←←←← 認識失敗
正解ラベル: 9
+ + + + + + + + + + + + + + +
+ + + + + + + + + + + + + + +
+ +                       + +
+ +                       + +
+ +                       + +
+ +                       + +
+ + + + +                 + +
+ + + + + + + + + + + + + + +
        + + + + + + + + + + +
                          + +
                          + +
                          + +
                          + +
                          + +
                          + +
---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):5 ⏎
認識結果  : 6
正解ラベル: 6

            + + + + + +
          + + + + + + +
          + + +
        + + +
        + + + + + +
      + + + + + + + +
      + +       + + + +
      + +         + + +
      + + +         + +
      + + + +     + + +
        + + + + + + + +
            + + + + +


---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):9 ⏎
認識結果  : 0
正解ラベル: 0
          + + + + + +
      + + + + + + + +
    + + + +       + + +
    + + +         + + +
  + + +             + +
  + + +             + +
  + +               + +
  + +             + + +
  + +             + + +
  + +             + +
  + +           + + +
  + + +         + + +
  + + + + + + + + +
    + + + + + + +
        + + + +
---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):10 ⏎
認識結果  : 6
正解ラベル: 6
            + + +
            + + +
          + + +
          + + +
          + +
        + + + + + +
        + + + + + + + +
        + +       + + + +
        + +         + + + +
        + +           + + +
        + +             + +
        + + +           + +
        + + + +       + + +
          + + + + + + + + +
            + + + + + + +
---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):11 ⏎
認識結果  : 5      ←←←←← 認識失敗
正解ラベル: 6

            + + + + + +
          + + + + + + +
          + + +
        + + +
        + + +
        + +
        + + + + + + +
        + + + + + + + +
        + +       + + + +
        + +         + + +
        + + +       + + +
        + + + + + + + +
          + + + + + + +

---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):12 ⏎
認識結果  : 5
正解ラベル: 5
      + +
      + +
      + + + + + + + + + +
      + + + + + + + + + +
      + +
      + +
      + + + + + + +
      + + + + + + + + +
                + + + +
                    + +
                    + +
                  + + +
    + + + + +     + + +
    + + + + + + + + +
          + + + + + +
---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):13 ⏎
認識結果  : 0
正解ラベル: 0
        + + + + + +
      + + + + + + + +
    + + + +     + + + +
    + + +         + + + +
    + +             + + +
    + +               + + +
    + +               + + +
    + + +               + +
    + + +               + +
      + + +           + + +
      + + + +       + + +
        + + + + + + + + +
          + + + + + + +


---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):14 ⏎
認識結果  : 8
正解ラベル: 8
          + + + + + +
          + + + + + + + +
          + +       + + + +
          + + +         + +
          + + + +   + + +
            + + + + + + +
              + + + + +
            + + + + + + +
          + + + + + + + +
        + + + +       + +
      + + +           + +
      + + +         + + +
      + + + + + + + + + +
        + + + + + + + +

---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):15 ⏎
認識結果  : 0
正解ラベル: 0

          + + + +
        + + + + + + +
      + + + +   + + + +
    + + +         + + + +
    + + +           + + +
    + +               + +
  + + +               + +
  + + +             + + +
  + + +             + +
  + + +             + +
    + + +         + + +
    + + + + + + + + +
      + + + + + + +

---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):16 ⏎
認識結果  : 1
正解ラベル: 1
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
            + + +
---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):17 ⏎
認識結果  : 6
正解ラベル: 6
          + + + + + + +
        + + + + + + + +
      + + + +
      + + +
    + + +
    + + + + + + +
    + + + + + + + + +
    + +         + + + +
    + +           + + + +
    + + +           + + +
      + + +           + +
      + + + +       + + +
        + + + + + + + + +
          + + + + + + +

---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):18 ⏎
認識結果  : 7
正解ラベル: 7
    + + + + + + + + + + + +
    + + + + + + + + + + + +
                  + + + +
                  + + +
                + + +
                + + +
              + + +
              + + +
            + + +
            + + +
            + +
            + +
            + +
            + +
            + +
---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):19 ⏎
認識結果  : 7
正解ラベル: 7
    + + + +
    + + + + + + + + + + + +
        + + + + + + + + + +
                    + + +
                  + + +
                  + + +
                + + +
                + + +
              + + +
              + + +
              + +
            + + +
            + +
            + +
            + +
---------------------
予測結果を表示したいテストデータの番号(0 から 19 まで)を入力してください(-1で終了します):-1 ⏎
------ 終了しました ------

いくつかの数字の認識に失敗していることがわかりました.

ここでは,わずか80個のデータでモデルの学習を行いました.実際にはもっと多くのデータを使って学習する必要があるでしょう.データの数を増やすとともに,中間層のニューロン数(現在は128)や学習時のバッチサイズ(現在は20)を色々変更して試してみると良いでしょう.また,畳み込みやドロップアウトなどを追加したり,学習画像の水増しなども行ってみると良いでしょう.

目次に戻る