前のページでは2次式で回帰式を求めました.ここではさらに項を加えた3次式
まず,必要なライブラリを読み込んだ後に回帰式と残差2乗和を取得する関数を定義します.回帰式(15行目)に3次の項が増えただけです.
import numpy as np
import pandas as pd
import scipy.optimize as optimize # 最適化
import matplotlib.pyplot as plt
# 高解像度ディスプレイ用
from IPython.display import set_matplotlib_formats
# from matplotlib_inline.backend_inline import set_matplotlib_formats # バージョンによってはこちらを有効に
set_matplotlib_formats('retina')
"""
関数の定義
"""
def my_func(w, x):
y = w[0] + w[1] * x + w[2] * x ** 2 + w[3] * x ** 3
return y
"""
残差2乗和を求める関数
Residual Sum-of-Squares
"""
def get_rss(w, x, y):
y_pred = my_func(w, x)
error = (y - y_pred)**2
return np.sum(error)
1次式や2次式の場合とほぼ同じ方法で最適化を行うことができます.僅かな違いは推定するパラメータ数が4に変化したことです.
# CSV ファイルを読み込む
url = "https://github.com/rinsaka/sample-data-sets/blob/master/lr.csv?raw=true"
df = pd.read_csv(url)
# NumPy 配列に変換する
x_data = df.loc[:, 'x'].values
y_data = df.loc[:, 'y'].values
# ネルダーミード法による最適化を行う
w = np.array([0.1, 0.1, 0.1, 0.1]) # 初期値を設定
results_nm = optimize.minimize(get_rss, w, args=(x_data, y_data), method='Nelder-Mead')
print(results_nm)
final_simplex: (array([[ 1.36047925, 3.18744739, -0.57981438, 0.02415969], [ 1.3605272 , 3.1873661 , -0.57979383, 0.02415829], [ 1.36046164, 3.18742938, -0.57980545, 0.02415883], [ 1.36041651, 3.18746626, -0.57981562, 0.02415965], [ 1.36049654, 3.18745606, -0.57981861, 0.02416001]]), array([143.93440151, 143.93440151, 143.93440152, 143.93440152, 143.93440153])) fun: 143.93440151132944 message: 'Optimization terminated successfully.' nfev: 757 nit: 453 status: 0 success: True x: array([ 1.36047925, 3.18744739, -0.57981438, 0.02415969])
上の結果を確認すると,目的関数すなわち残差2乗和が 143.93 となり,2次式の 148.92 から改善されたことがわかります.また推定パラメータ数が4に増えたことで,問題が僅かに難しくなったことから,最適化での繰り返し回数 nit (Number of iterations performed by the optimizer) が1次式の 62,2次式の 143 から 453 に増加していることもわかります.
最後に回帰式を散布図に重ねて描いてみます.下のコードは1次式や2次式の場合から変更の必要はありません.
# 最適解を使って回帰直線のデータを作成する
x_plot = np.linspace(0, 10, 100)
y_pred = my_func(results_nm["x"], x_plot)
# グラフを描く
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(x_data, y_data, label="data")
ax.plot(x_plot, y_pred, label='model 3')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_xlim(0,10)
ax.set_ylim(0,10)
ax.legend()
# plt.savefig('lr-model3.png', dpi=300, facecolor='white')
plt.show()
比較対象として,2次式の場合の結果も示しておきます.