次の表はあるスーパーマーケットの売り場面積と売上高の実績値です.ここでは,線形回帰を scipy.stats.linregress を使って行い,新規開店予定の店舗 (P, Q) の売上高を予測しよう.
店名 | 売り場面積 (X) | 売上高 (Y) |
---|---|---|
A | 26 | 65 |
B | 28 | 65 |
C | 25 | 62 |
D | 26 | 59 |
E | 28 | 69 |
F | 33 | 73 |
G | 32 | 79 |
H | 29 | 71 |
I | 30 | 73 |
J | 29 | 71 |
K | 38 | 86 |
L | 40 | 88 |
M | 32 | 71 |
N | 26 | 63 |
O | 34 | 80 |
P | 27 | ? |
Q | 38 | ? |
まずは必要なモジュールをインポートします.
モジュールのインポート
import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt
# 高解像度ディスプレイ用
from IPython.display import set_matplotlib_formats
# from matplotlib_inline.backend_inline import set_matplotlib_formats # バージョンによってはこちらを有効に
set_matplotlib_formats('retina')
データは GitHub のリポジトリで公開しているので,これを Pandas のデータフレームに直接読み込みます.
問題のデータを読み込む
url = "https://github.com/rinsaka/sample-data-sets/blob/master/store-space-sales.csv?raw=true"
df = pd.read_csv(url)
df
Pandas のデータフレーム df
から「space」列と「sales」列を取り出してそれぞれ NumPy 配列に変換します.
Pandas データフレームを NumPy 配列に変換
x = df.loc[:, 'space'].values
y = df.loc[:, 'sales'].values
print(x)
print(y)
[26 28 25 26 28 33 32 29 30 29 38 40 32 26 34] [65 65 62 59 69 73 79 71 73 71 86 88 71 63 80]
回帰分析の前に散布図を描いてみよう.
散布図
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(x, y)
ax.set_xlabel('space')
ax.set_ylabel('sales')
ax.set_xlim(20,45)
ax.set_ylim(40,100)
plt.show()
回帰分析は x と y のリストを引数に与えるだけで行えます.
回帰分析の実行
result = scipy.stats.linregress(x, y)
結果を表示してみよう.
結果の表示
result
LinregressResult(slope=1.8335734870317, intercept=15.926032660902997, rvalue=0.9540596161676781, pvalue=3.544817731843698e-08, stderr=0.15970457065329158, intercept_stderr=4.903389816647174)
上の結果から必要な推定結果を次のような方法で取得すると良いでしょう.
結果の取得
a = result.slope
b = result.intercept
print(f'傾き: {a:7.4f}')
print(f'切片: {b:7.4f}')
傾き: 1.8336 切片: 15.9260
グラフを表示しよう.
グラフを描画
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(x, y)
ax.plot(x, [a * t + b for t in x])
ax.set_xlabel('space')
ax.set_ylabel('sales')
ax.set_xlim(20,45)
ax.set_ylim(40,100)
ax.set_title('OSL results')
plt.show()
新規店舗の売上を予測します.
予測
print(f'P (27) ==> {a * 27 + b:.2f}')
print(f'Q (38) ==> {a * 38 + b:.2f}')
P (27) ==> 65.43 Q (38) ==> 85.60