Python入門トップページ


重回帰分析をしてみよう-1

データの読み込み

まずは,必要なモジュールをインポートする.

モジュールのインポートimport pandas as pd
import numpy as np
import seaborn as sns
import statsmodels.api as sm
import matplotlib.pyplot as plt

# 高解像度ディスプレイ用
from IPython.display import set_matplotlib_formats
# from matplotlib_inline.backend_inline import set_matplotlib_formats # バージョンによってはこちらを有効に
set_matplotlib_formats('retina')

次に,CSVファイルを読み込んで表示してみる.ここで利用するデータは GitHub のリポジトリで公開しているので,プログラムの中で読み込むことができる.このデータは,\(x_1\), \(x_2\), \(x_3\) の値が決まれば \(y\) の値がおおよそ決まるような200件のデータである.

ファイルを読み込んで表示するurl = "https://github.com/rinsaka/sample-data-sets/blob/master/mra-01.csv?raw=true"
df = pd.read_csv(url)
print(df)
      no     x1     x2     x3        y
0      0  9.852  2.531  7.608  124.111
1      1  8.516  3.811  4.539  137.368
2      2  0.416  9.448  1.251  150.623
3      3  7.852  0.355  2.133   90.407
4      4  5.855  9.925  2.039  211.147
..   ...    ...    ...    ...      ...
195  195  1.832  4.736  0.174   96.737
196  196  4.977  8.551  8.811  160.691
197  197  5.427  7.830  3.951  171.948
198  198  7.411  9.409  2.531  219.871
199  199  5.384  2.037  5.684   79.859

[200 rows x 5 columns]

相関係数と散布図行列

重回帰分析を行う前に,変数間の相関係数を計算し,散布図行列も描いて確認してみよう.まずは,ここと同じ方法で相関係数を計算する.これには Pandas のデータフレームを NumPy 配列に変換し,これを転置して np.corrcoef 関数に与えると良い.

相関係数# x1, x2, x3, y 列だけとりだして NumPy 配列に変換
xy = df.loc[:,['x1','x2','x3','y']].values
# NumPy配列を転置して相関係数を求める
print(np.corrcoef(xy.T))
[[ 1.          0.03577473 -0.09138881  0.58709634]
 [ 0.03577473  1.          0.03582377  0.81193681]
 [-0.09138881  0.03582377  1.         -0.1891527 ]
 [ 0.58709634  0.81193681 -0.1891527   1.        ]]

上の結果から,\(x_2\)\(y\) には高い正の相関関係 (相関係数 0.8119) があり,次いで \(x_1\)\(y\) にも正の相関関係 (相関係数 0.5871) があることがわかる.また,\(x_3\)\(y\) には弱い負の相関 (相関係数 -0.1891) がある(と言ってよいか,無いというべきか・・・というレベルですが,負の相関があるようなデータにしています(詳細は後述)).さらに,\(x_1\), \(x_2\), \(x_3\) それぞれの間には相関関係はぼぼありません (相関係数は 0.03577, -0.09139, 0.03582).

次に,変数ごとの散布図を描いて変数間の関連を確認する.複数の散布図をまとめた散布図行列を描くには seaborn パッケージを使うとよい.このとき Pandas のデータフレームには no 列が含まれているので,この列を削除してから散布図を描く.なお,df.drop でデータフレームの行や列を削除できるが,axis=1とすると列を削除し,axis=0や省略すると行を削除する.

データフレームから散布図行列を描く# no列を削除したデータフレームについて, x1, x2, x3, y それぞれの組み合わせで散布図行列を描く
sns.pairplot(df.drop('no', axis=1))
plt.show()
seaborn1

上の散布図を確認すると,相関係数と同じ様に,\(x_2\)\(y\) の関連や\(x_1\)\(y\) の関連が強いことがわかる.

データの詳細

ここで使用しているデータは,ダミーデータとして作成したものです.具体的には,\(x_1\)\(x_2\)\(x_3\) は最小値0,最大値10の一様乱数です.これは上の図の対角線にある棒グラフからもそのことが読み取れます.また,\(y\) \begin{eqnarray} y &=& b + w_1 x_1 + w_2 x_2 + w_3 x_3 \end{eqnarray} に僅かな誤差を加えて作成しています.ここで,\( b = 10 \)\( w_1 = 10\)\( w_2 = 15 \)\( w_3 = -3 \) です.例えば,\( \rm{no} = 0 \) のデータは,\( x_1 = 9.852\)\( x_2 = 2.531 \)\( x_3 = 7.608 \) であるので, \begin{eqnarray} y &=& b + w_1 x_1 + w_2 x_2 + w_3 x_3 \\ &=& 10 + 10 * 9.852 + 15 * 2.531 - 3 * 7.608 \\ &=& 123.661 \end{eqnarray} のようになる.実際はこの得られた値に誤差を加えて作成しています.したがって,重回帰分析がうまく実行できれば,\( b = 10 \)\( w_1 = 10\)\( w_2 = 15 \)\( w_3 = -3 \) という結果が得られるはずです.

データの確認print(df)
      no     x1     x2     x3        y
0      0  9.852  2.531  7.608  124.111
1      1  8.516  3.811  4.539  137.368
2      2  0.416  9.448  1.251  150.623
3      3  7.852  0.355  2.133   90.407
4      4  5.855  9.925  2.039  211.147
..   ...    ...    ...    ...      ...
195  195  1.832  4.736  0.174   96.737
196  196  4.977  8.551  8.811  160.691
197  197  5.427  7.830  3.951  171.948
198  198  7.411  9.409  2.531  219.871
199  199  5.384  2.037  5.684   79.859

[200 rows x 5 columns]

重回帰分析

ここでは,線形(単)回帰分析と同じ方法で重回帰分析を実行してみよう.まずは,データフレームから必要な列を取り出して NumPy 配列に格納する.なお,直接データフレームを指定して重回帰分析を行うことも可能です.

x1, x2, x3 列を NumPy 配列 x_data に格納x_data = df.loc[:,['x1', 'x2', 'x3']].values
print(x_data)
[[9.852 2.531 7.608]
 [8.516 3.811 4.539]
 [0.416 9.448 1.251]
 [7.852 0.355 2.133]
 [5.855 9.925 2.039]

 ...(中略)...

 [1.832 4.736 0.174]
 [4.977 8.551 8.811]
 [5.427 7.83  3.951]
 [7.411 9.409 2.531]
 [5.384 2.037 5.684]]
y 列を NumPy 配列 y_data に格納y_data = df.loc[:, 'y'].values
print(y_data)
[124.111 137.368 150.623  90.407 211.147 ...(中略)... 96.737 160.691 171.948 219.871  79.859]

y 切片も含めて重回帰分析をしたいので,定数項を変数xに加える.加えなければ当てはめた直線(平面,超平面)は原点を通ることになる.(つまり,\(x_1 = 0\)\(x_2 = 0\)\(x_3 = 0\) のとき,\(y = 0\) となる.)

定数項を加えるx = sm.add_constant(x_data)

モデルを定義する.

モデルの定義model = sm.OLS(y_data, x)

モデルの当てはめ(パラメータ推定)を行う.

パラメータ推定results = model.fit()

結果を表示する.うまく推定できていることが確認できた.

print(results.params)
[10.21769085  9.99196356 15.05821129 -3.03373984]

詳細な推定結果を表示する.

print(results.summary2())
                 Results: Ordinary least squares
==================================================================
Model:              OLS              Adj. R-squared:     0.999
Dependent Variable: y                AIC:                777.8703
Date:               20yy-mm-dd hh:mm BIC:                791.0636
No. Observations:   200              Log-Likelihood:     -384.94
Df Model:           3                F-statistic:        6.832e+04
Df Residuals:       196              Prob (F-statistic): 1.27e-295
R-squared:          0.999            Scale:              2.8058
--------------------------------------------------------------------
           Coef.    Std.Err.      t       P>|t|     [0.025    0.975]
--------------------------------------------------------------------
const     10.2177     0.3792    26.9483   0.0000    9.4699   10.9654
x1         9.9920     0.0408   244.8085   0.0000    9.9115   10.0725
x2        15.0582     0.0417   361.1830   0.0000   14.9760   15.1404
x3        -3.0337     0.0400   -75.7727   0.0000   -3.1127   -2.9548
------------------------------------------------------------------
Omnibus:               99.771       Durbin-Watson:          2.041
Prob(Omnibus):         0.000        Jarque-Bera (JB):       13.187
Skew:                  0.169        Prob(JB):               0.001
Kurtosis:              1.788        Condition No.:          30
==================================================================

上の結果を確認すると,決定係数 (R-squared : \(R^2\)) が 0.999,自由度補正済み決定係数 (Adj. R-squared) も 0.999 となっており,ほぼ完璧に推定できていることがわかる.次に,\( (x_1, x_2, x_3) = (1, 2, 3)\) というようなデータに対する予測を行ってみる.

予測 (1)x1, x2, x3 = 1, 2, 3
y = results.params[0] + results.params[1] * x1 + results.params[2] * x2 + results.params[3] * x3
print(y)
41.22485746231744

また,\( (x_1, x_2, x_3) = (7, 5, 3)\) というようなデータに対しても予測を行ってみる.

予測 (2)x1, x2, x3 = 7, 5, 3
y = results.params[0] + results.params[1] * x1 + results.params[2] * x2 + results.params[3] * x3
print(y)
146.3512726961657

なお,yのデータと重回帰モデルによるyの予測値を散布図としてプロットすると,ほぼ直線になっている(うまく予測できている)ことがわかる.

散布図を作成するfig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(y_data, results.fittedvalues)
ax.set_xlabel('y_data')
ax.set_ylabel('y_pred')
ax.set_xlim(0,250)
ax.set_ylim(0,250)
# plt.savefig('mra1.png', dpi=300, facecolor='white')
plt.show()
mra1