Python入門トップページ


JIT を使って Python プログラムを高速化しよう

はじめに

Python はインタプリタ言語であるため,プログラムの動作はそれほど高速ではありません.特に,大きなデータを用いた繰り返し処理では多くの計算時間が必要になることがあります.これを解決するために Numpy を使って効率的にデータを処理したり,Cython を使って C言語のプログラムにコンパイル(変換)したりします.

ここでは別の方法として JIT (Just In Time) コンパイラを使った高速化を試します.JIT を使うと,わずか数行のコードを追加するだけで高速化が実現できます.

準備

JIT を利用するには次のコマンドで numba をインストールします.

% pip install numba ⏎
Collecting numba
  Downloading numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (2.7 kB)
Collecting llvmlite<0.44,>=0.43.0dev0 (from numba)
  Downloading llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (4.8 kB)
Requirement already satisfied: numpy<2.1,>=1.22 in /Users/rinsaka/miniforge3/envs/py311Jupyter/lib/python3.11/site-packages (from numba) (1.26.4)
Downloading numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl (2.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━ 2.6/2.6 MB 26.2 MB/s eta 0:00:00
Downloading llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl (28.8 MB)
   ━━━━━━━━━━━━━━━━━━━━ 28.8/28.8 MB 44.4 MB/s eta 0:00:00
Installing collected packages: llvmlite, numba
Successfully installed llvmlite-0.43.0 numba-0.60.0

遅い Python コードを作成

まず,JIT を使わずに処理の遅い Python コードを準備します.このとき,時間のかかる処理を関数化しておくことがポイントです.次の関数 sqrtlist()0 から n までの値について平方根を求めて,それを巨大なリストにしています.さらに計算した平方根のリストの中からインデックス 4,084,441 を返すというものです.この関数を利用して 5億までの平方根を求めると25秒程度の計算時間を要しました.


import math
import time

def sqrtlist(n):
    y = [0] * (n+1)
    for i in range(0, n+1):
        y[i] = math.sqrt(i)
    return y[4_084_441]

print("計算処理開始")
time_start = time.time()
n = 500_000_000
s = sqrtlist(n)
print(s)
time_end = time.time()
time_gap = time_end - time_start
print(f"処理時間は {time_gap:.2f} 秒でした")
計算処理開始
2021.0
処理時間は 24.95 秒でした

上の処理では関数 sqrtlist の中で5億回もの繰り返し(厳密には5億1回)が実行されていることに注意して下さい.

JIT で高速化

非常に時間のかかる処理を JIT で高速化しましょう.JIT の利用は実は非常に簡単です.次の通り,numba をインポートし,高速化したい関数の直前にデコレータを追加するだけです.これによってプログラムの動作が驚くほど高速化されます.実行時にプログラムをコンパイルするための処理時間が必要になりますが,sqrtlist 関数自体の実行時間は25秒からわずか1.00秒にまで高速化できました.


import math
import time
from numba import jit

@jit(nopython=True)
def sqrtlist(n):
    y = [0] * (n+1)
    for i in range(0, n+1):
        y[i] = math.sqrt(i)
    return y[4_084_441]

print("JITコンパイル終了:計算処理開始")
time_start = time.time()
n = 500_000_000
s = sqrtlist(n)
print(s)
time_end = time.time()
time_gap = time_end - time_start
print(f"処理時間は {time_gap:.2f} 秒でした")
JITコンパイル終了:計算処理開始
2021
処理時間は 1.00 秒でした

なお,実行時にはどのタイミングでそれぞれの文字列が出力されるか,注意深く観察してください.JITのプログラムでは,「JITコンパイル終了:計算処理開始」という文字列が表示されるまでに若干時間を要していることでしょう.これは実行時にコンパイルの処理が行われていることを意味しています.コンパイルができてしまえば動作が非常に高速になることが理解できるはずです.