Python入門トップページ


手書き数字を認識するAIを作ってみよう : 目次

  1. 画像データの準備と確認
  2. 画像データを読み込んでみよう
  3. 画像データの一覧を読み込んでみよう
  4. 学習データとテストデータを準備する
  5. 保存したデータを開いてみる
  6. モデルを作る
  7. 学習(トレーニング)させてみよう
  8. モデルを評価しよう
  9. 学習データで認識させてみよう(1)
  10. 学習データで認識させてみよう(2)
  11. 学習データで認識させてみよう(3)
  12. テストデータで認識させてみよう
  13. モデルと重みパラメータを保存しよう
  14. 学習済みモデルをロードしよう
  15. 学習済みモデルをロードして,認識してみよう

画像データの一覧を読み込んでみよう

ここでは,png フォルダ(ディレクトリ)にある png 画像をまとめて読み込み,学習用データやテスト用データを生成してみる.

ディレクトリを走査する

まず,png ディレクトリを走査して,png ファイルの一覧を取得してみる.

ディレクトリを走査する (02-readfiles.py)
import glob
import os

# png 画像データが保存されているディレクトリの指定
search_path = os.path.sep.join(['png', '*.png'])
# ファイルを検索
files = glob.glob(search_path)

print(files)
['png\\0-00.png', 'png\\0-01.png', 'png\\0-02.png', 'png\\0-03.png',
'png\\0-04.png', 'png\\0-05.png', 'png\\0-06.png', 'png\\0-07.png',
'png\\0-08.png', 'png\\0-09.png', 'png\\1-00.png', 'png\\1-01.png',
'png\\1-02.png', 'png\\1-03.png', 'png\\1-04.png', 'png\\1-05.png',
'png\\1-06.png', 'png\\1-07.png', 'png\\1-08.png', 'png\\1-09.png',
'png\\2-00.png', 'png\\2-01.png', 'png\\2-02.png', 'png\\2-03.png',
'png\\2-04.png', 'png\\2-05.png', 'png\\2-06.png', 'png\\2-07.png',
'png\\2-08.png', 'png\\2-09.png', 'png\\3-00.png', 'png\\3-01.png',
'png\\3-02.png', 'png\\3-03.png', 'png\\3-04.png', 'png\\3-05.png',
'png\\3-06.png', 'png\\3-07.png', 'png\\3-08.png', 'png\\3-09.png',
'png\\4-00.png', 'png\\4-01.png', 'png\\4-02.png', 'png\\4-03.png',
'png\\4-04.png', 'png\\4-05.png', 'png\\4-06.png', 'png\\4-07.png',
'png\\4-08.png', 'png\\4-09.png', 'png\\5-00.png', 'png\\5-01.png',
'png\\5-02.png', 'png\\5-03.png', 'png\\5-04.png', 'png\\5-05.png',
'png\\5-06.png', 'png\\5-07.png', 'png\\5-08.png', 'png\\5-09.png',
'png\\6-00.png', 'png\\6-01.png', 'png\\6-02.png', 'png\\6-03.png',
'png\\6-04.png', 'png\\6-05.png', 'png\\6-06.png', 'png\\6-07.png',
'png\\6-08.png', 'png\\6-09.png', 'png\\7-00.png', 'png\\7-01.png',
'png\\7-02.png', 'png\\7-03.png', 'png\\7-04.png', 'png\\7-05.png',
'png\\7-06.png', 'png\\7-07.png', 'png\\7-08.png', 'png\\7-09.png',
'png\\8-00.png', 'png\\8-01.png', 'png\\8-02.png', 'png\\8-03.png',
'png\\8-04.png', 'png\\8-05.png', 'png\\8-06.png', 'png\\8-07.png',
'png\\8-08.png', 'png\\8-09.png', 'png\\9-00.png', 'png\\9-01.png',
'png\\9-02.png', 'png\\9-03.png', 'png\\9-04.png', 'png\\9-05.png',
'png\\9-06.png', 'png\\9-07.png', 'png\\9-08.png', 'png\\9-09.png']

次に,すべてのファイルを読み込んで,1つのデータを表示してみよう.ここで,20行目の png_files のように,Python のリストの添字(インデックス)は0番目からスタートすることに注意しよう.つまり,0〜99の添字から任意に選んでよい.

すべてを読み込み,1つを表示 (02-readfiles.py)
import glob
import os
from PIL import Image
import numpy as np

# png 画像データが保存されているディレクトリの指定
search_path = os.path.sep.join(['png', '*.png'])
# ファイルを検索
files = glob.glob(search_path)

# 全画像データを格納するためのリストを準備する
png_files = []

for file in files:
    img = Image.open(file)
    img = img.convert("P") # 白黒データに変換
    data = np.asarray(img) # numpy 配列に変換
    png_files.append(data) # リストに追加

print(png_files[1])  # 数字(0-99)を適当に変えて試すと良い
[[0 0 0 0 1 1 1 1 1 1 0 0 0 0 0]
 [0 0 0 1 1 1 1 1 1 1 1 0 0 0 0]
 [0 0 1 1 1 1 0 0 1 1 1 1 0 0 0]
 [0 0 1 1 1 0 0 0 0 1 1 1 1 0 0]
 [0 0 1 1 0 0 0 0 0 0 1 1 1 0 0]
 [0 0 1 1 0 0 0 0 0 0 0 1 1 1 0]
 [0 0 1 1 0 0 0 0 0 0 0 1 1 1 0]
 [0 0 1 1 1 0 0 0 0 0 0 0 1 1 0]
 [0 0 1 1 1 0 0 0 0 0 0 0 1 1 0]
 [0 0 0 1 1 1 0 0 0 0 0 1 1 1 0]
 [0 0 0 1 1 1 1 0 0 0 1 1 1 0 0]
 [0 0 0 0 1 1 1 1 1 1 1 1 1 0 0]
 [0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

目次に戻る

カテゴリデータも取得する

画像データだけでなく,正解ラベルとなるカテゴリデータも取得しておこう.これはファイル名から取得できる.つまりファイルパス名が「png/8-01.png」であれば,正解ラベルは「8」である.つまり,ファイルパス名の「4」番目の文字を取り出せば良い.なお,次の実行例では「png/1-01.png」から正解ラベル「1」を取り出している.

正解ラベルを取得する (02-readfiles.py)
import glob
import os
from PIL import Image
import numpy as np

# png 画像データが保存されているディレクトリの指定
search_path = os.path.sep.join(['png', '*.png'])
# ファイルを検索
files = glob.glob(search_path)

# 全画像データを格納するためのリストを準備する
png_files = []

for file in files:
    img = Image.open(file)
    img = img.convert("P") # 白黒データに変換
    data = np.asarray(img) # numpy 配列に変換
    cat = file[4]
    png_files.append([cat, data]) # リストに追加

print(png_files[1])  # 数字(0-99)を適当に変えて試すと良い
['0', array([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0],
       [0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
       [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
       [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0],
       [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0],
       [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
       [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
       [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0],
       [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0],
       [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
       [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=uint8)]

目次に戻る

NumPy 配列を一次元化する

画像データは2次元配列の形式になっているが,ニューラルネットワークの入力データとして利用するために,これを一次元配列に変更する.

一次元化する (02-readfiles.py)
import glob
import os
from PIL import Image
import numpy as np

# png 画像データが保存されているディレクトリの指定
search_path = os.path.sep.join(['png', '*.png'])
# ファイルを検索
files = glob.glob(search_path)

# 全画像データを格納するためのリストを準備する
png_files = []

for file in files:
    img = Image.open(file)
    img = img.convert("P") # 白黒データに変換
    data = np.asarray(img) # numpy 配列に変換
    cat = file[4]
    data = data.flatten()  # numpy 配列を一次元化
    png_files.append([cat, data]) # リストに追加

print(png_files[1])  # 数字(0-99)を適当に変えて試すと良い
['0', array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
       1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0,
       0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
       1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
       1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1,
       1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0], dtype=uint8)]

目次に戻る