次の表はあるスーパーマーケットの売り場面積と売上高の実績値である.ここでは,線形回帰を scipy.stats.linregress を使って行い,新規開店予定の店舗 (P, Q) の売上高を予測しよう.
店名 | 売り場面積 (X) | 売上高 (Y) |
---|---|---|
A | 26 | 65 |
B | 28 | 65 |
C | 25 | 62 |
D | 26 | 59 |
E | 28 | 69 |
F | 33 | 73 |
G | 32 | 79 |
H | 29 | 71 |
I | 30 | 73 |
J | 29 | 71 |
K | 38 | 86 |
L | 40 | 88 |
M | 32 | 71 |
N | 26 | 63 |
O | 34 | 80 |
P | 27 | ? |
Q | 38 | ? |
まずは必要なモジュールをインポートする.
モジュールのインポートimport numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt
# 高解像度ディスプレイ用
from IPython.display import set_matplotlib_formats
# from matplotlib_inline.backend_inline import set_matplotlib_formats # バージョンによってはこちらを有効に
set_matplotlib_formats('retina')
データは GitHub のリポジトリで公開しているので,これを Pandas のデータフレームに直接読み込む.
問題のデータを読み込むurl = "https://github.com/rinsaka/sample-data-sets/blob/master/store-space-sales.csv?raw=true"
df = pd.read_csv(url)
df
Pandas のデータフレーム df
から「space」列と「sales」列を取り出してそれぞれ NumPy 配列に変換する.
Pandas データフレームを NumPy 配列に変換x = df.loc[:, 'space'].values
y = df.loc[:, 'sales'].values
print(x)
print(y)
[26 28 25 26 28 33 32 29 30 29 38 40 32 26 34] [65 65 62 59 69 73 79 71 73 71 86 88 71 63 80]
回帰分析の前に散布図を描いてみよう.
散布図fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(x, y)
ax.set_xlabel('space')
ax.set_ylabel('sales')
ax.set_xlim(20,45)
ax.set_ylim(40,100)
plt.show()
回帰分析は x と y のリストを引数に与えるだけで行える.
回帰分析の実行result = scipy.stats.linregress(x, y)
結果を表示してみよう.
結果の表示result
LinregressResult(slope=1.8335734870317, intercept=15.926032660902997, rvalue=0.9540596161676781, pvalue=3.544817731843698e-08, stderr=0.15970457065329158, intercept_stderr=4.903389816647174)
上の結果から必要な推定結果を次のような方法で取得すると良い.
結果の取得a = result.slope
b = result.intercept
print("傾き:", a.round(4))
print("切片:", b.round(4))
傾き: 1.8336 切片: 15.926
グラフを表示しよう.
グラフを描画fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(x, y)
ax.plot(x, [a * t + b for t in x])
ax.set_xlabel('space')
ax.set_ylabel('sales')
ax.set_xlim(20,45)
ax.set_ylim(40,100)
ax.set_title('OSL results')
plt.show()
新規店舗の売上を予測する.
予測# 予測する
print('P (27) ==> ', a * 27 + b)
print('Q (38) ==> ', a * 38 + b)
P (27) ==> 65.4325168107589 Q (38) ==> 85.60182516810758