Python入門トップページ


重回帰分析をしてみよう-1

データの読み込み

まずは,必要なモジュールをインポートします.

モジュールのインポート
import pandas as pd
import numpy as np
import seaborn as sns
import statsmodels.api as sm
import matplotlib.pyplot as plt

# 高解像度ディスプレイ用
from IPython.display import set_matplotlib_formats
# from matplotlib_inline.backend_inline import set_matplotlib_formats # バージョンによってはこちらを有効に
set_matplotlib_formats('retina')

次に,CSVファイルを読み込んで表示してみます.ここで利用するデータは GitHub のリポジトリで公開しているので,プログラムの中で読み込むことができます.このデータは,\(x_1\), \(x_2\), \(x_3\) の値が決まれば \(y\) の値がおおよそ決まるような200件のデータです.

ファイルを読み込んで表示する
url = "https://github.com/rinsaka/sample-data-sets/blob/master/mra-01.csv?raw=true"
df = pd.read_csv(url)
print(df)
      no     x1     x2     x3        y
0      0  9.852  2.531  7.608  124.111
1      1  8.516  3.811  4.539  137.368
2      2  0.416  9.448  1.251  150.623
3      3  7.852  0.355  2.133   90.407
4      4  5.855  9.925  2.039  211.147
..   ...    ...    ...    ...      ...
195  195  1.832  4.736  0.174   96.737
196  196  4.977  8.551  8.811  160.691
197  197  5.427  7.830  3.951  171.948
198  198  7.411  9.409  2.531  219.871
199  199  5.384  2.037  5.684   79.859

[200 rows x 5 columns]

相関係数と散布図行列

重回帰分析を行う前に,変数間の相関係数を計算し,散布図行列も描いて確認してみよう.まずは,ここと同じ方法で相関係数を計算します.これには Pandas のデータフレームを NumPy 配列に変換し,これを転置して np.corrcoef 関数に与えると良いでしょう.

相関係数
# x1, x2, x3, y 列だけとりだして NumPy 配列に変換
xy = df.loc[:,['x1','x2','x3','y']].values
# NumPy配列を転置して相関係数を求める
print(np.corrcoef(xy.T))
[[ 1.          0.03577473 -0.09138881  0.58709634]
 [ 0.03577473  1.          0.03582377  0.81193681]
 [-0.09138881  0.03582377  1.         -0.1891527 ]
 [ 0.58709634  0.81193681 -0.1891527   1.        ]]

上の結果から,\(x_2\)\(y\) には高い正の相関関係 (相関係数 0.8119) があり,次いで \(x_1\)\(y\) にも正の相関関係 (相関係数 0.5871) があることがわかります.また,\(x_3\)\(y\) には弱い負の相関 (相関係数 -0.1891) がある(と言ってよいか,無いというべきか・・・というレベルですが,負の相関があるようなデータにしています(詳細は後述)).さらに,\(x_1\), \(x_2\), \(x_3\) それぞれの間には相関関係はぼぼありません (相関係数は 0.03577, -0.09139, 0.03582).

次に,変数ごとの散布図を描いて変数間の関連を確認します.複数の散布図をまとめた散布図行列を描くには seaborn パッケージを使うと良いでしょう.このとき Pandas のデータフレームには no 列が含まれているので,この列を削除してから散布図を描きます.なお,df.drop でデータフレームの行や列を削除できるが,axis=1とすると列を削除し,axis=0や省略すると行を削除します.

データフレームから散布図行列を描く
# no列を削除したデータフレームについて, x1, x2, x3, y それぞれの組み合わせで散布図行列を描く
sns.pairplot(df.drop('no', axis=1))
plt.show()
seaborn1

上の散布図を確認すると,相関係数と同じ様に,\(x_2\)\(y\) の関連や\(x_1\)\(y\) の関連が強いことがわかります.

データの詳細

ここで使用しているデータは,ダミーデータとして作成したものです.具体的には,\(x_1\)\(x_2\)\(x_3\) は最小値0,最大値10の一様乱数です.これは上の図の対角線にある棒グラフからもそのことが読み取れます.また,\(y\) \begin{eqnarray} y &=& b + w_1 x_1 + w_2 x_2 + w_3 x_3 \end{eqnarray} に僅かな誤差を加えて作成しています.ここで,\( b = 10 \)\( w_1 = 10\)\( w_2 = 15 \)\( w_3 = -3 \) です.例えば,\( \rm{no} = 0 \) のデータは,\( x_1 = 9.852\)\( x_2 = 2.531 \)\( x_3 = 7.608 \) であるので, \begin{eqnarray} y &=& b + w_1 x_1 + w_2 x_2 + w_3 x_3 \\ &=& 10 + 10 * 9.852 + 15 * 2.531 - 3 * 7.608 \\ &=& 123.661 \end{eqnarray} のようになります.実際はこの得られた値に誤差を加えて作成しています.したがって,重回帰分析がうまく実行できれば,\( b = 10 \)\( w_1 = 10\)\( w_2 = 15 \)\( w_3 = -3 \) という結果が得られるはずです.

データの確認
print(df)
      no     x1     x2     x3        y
0      0  9.852  2.531  7.608  124.111
1      1  8.516  3.811  4.539  137.368
2      2  0.416  9.448  1.251  150.623
3      3  7.852  0.355  2.133   90.407
4      4  5.855  9.925  2.039  211.147
..   ...    ...    ...    ...      ...
195  195  1.832  4.736  0.174   96.737
196  196  4.977  8.551  8.811  160.691
197  197  5.427  7.830  3.951  171.948
198  198  7.411  9.409  2.531  219.871
199  199  5.384  2.037  5.684   79.859

[200 rows x 5 columns]

重回帰分析

ここでは,線形(単)回帰分析と同じ方法で重回帰分析を実行してみよう.まずは,データフレームから必要な列を取り出して NumPy 配列に格納します.なお,直接データフレームを指定して重回帰分析を行うことも可能です.

x1, x2, x3 列を NumPy 配列 x_data に格納
x_data = df.loc[:,['x1', 'x2', 'x3']].values
print(x_data)
[[9.852 2.531 7.608]
 [8.516 3.811 4.539]
 [0.416 9.448 1.251]
 [7.852 0.355 2.133]
 [5.855 9.925 2.039]

 ...(中略)...

 [1.832 4.736 0.174]
 [4.977 8.551 8.811]
 [5.427 7.83  3.951]
 [7.411 9.409 2.531]
 [5.384 2.037 5.684]]
y 列を NumPy 配列 y_data に格納
y_data = df.loc[:, 'y'].values
print(y_data)
[124.111 137.368 150.623  90.407 211.147 ...(中略)... 96.737 160.691 171.948 219.871  79.859]

y 切片も含めて重回帰分析をしたいので,定数項を変数xに加えます.加えなければ当てはめた直線(平面,超平面)は原点を通ることになります.(つまり,\(x_1 = 0\)\(x_2 = 0\)\(x_3 = 0\) のとき,\(y = 0\) となります.)

定数項を加える
x = sm.add_constant(x_data)

モデルを定義します.

モデルの定義
model = sm.OLS(y_data, x)

モデルの当てはめ(パラメータ推定)を行います.

パラメータ推定
results = model.fit()

結果を表示します.うまく推定できていることが確認できました.


print(results.params)
[10.21769085  9.99196356 15.05821129 -3.03373984]

詳細な推定結果を表示します.


print(results.summary2())
                 Results: Ordinary least squares
==================================================================
Model:              OLS              Adj. R-squared:     0.999
Dependent Variable: y                AIC:                777.8703
Date:               20yy-mm-dd hh:mm BIC:                791.0636
No. Observations:   200              Log-Likelihood:     -384.94
Df Model:           3                F-statistic:        6.832e+04
Df Residuals:       196              Prob (F-statistic): 1.27e-295
R-squared:          0.999            Scale:              2.8058
--------------------------------------------------------------------
           Coef.    Std.Err.      t       P>|t|     [0.025    0.975]
--------------------------------------------------------------------
const     10.2177     0.3792    26.9483   0.0000    9.4699   10.9654
x1         9.9920     0.0408   244.8085   0.0000    9.9115   10.0725
x2        15.0582     0.0417   361.1830   0.0000   14.9760   15.1404
x3        -3.0337     0.0400   -75.7727   0.0000   -3.1127   -2.9548
------------------------------------------------------------------
Omnibus:               99.771       Durbin-Watson:          2.041
Prob(Omnibus):         0.000        Jarque-Bera (JB):       13.187
Skew:                  0.169        Prob(JB):               0.001
Kurtosis:              1.788        Condition No.:          30
==================================================================

上の結果を確認すると,決定係数 (R-squared : \(R^2\)) が 0.999,自由度補正済み決定係数 (Adj. R-squared) も 0.999 となっており,ほぼ完璧に推定できていることがわかります.次に,\( (x_1, x_2, x_3) = (1, 2, 3)\) というようなデータに対する予測を行ってみます.

予測 (1)
x1, x2, x3 = 1, 2, 3
y = results.params[0] + results.params[1] * x1 + results.params[2] * x2 + results.params[3] * x3
print(y)
41.22485746231744

また,\( (x_1, x_2, x_3) = (7, 5, 3)\) というようなデータに対しても予測を行ってみます.

予測 (2)
x1, x2, x3 = 7, 5, 3
y = results.params[0] + results.params[1] * x1 + results.params[2] * x2 + results.params[3] * x3
print(y)
146.3512726961657

なお,yのデータと重回帰モデルによるyの予測値を散布図としてプロットすると,ほぼ直線になっている(うまく予測できている)ことがわかります.

散布図を作成する
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(y_data, results.fittedvalues)
ax.set_xlabel('y_data')
ax.set_ylabel('y_pred')
ax.set_xlim(0,250)
ax.set_ylim(0,250)
# plt.savefig('mra1.png', dpi=300, facecolor='white')
plt.show()
mra1