ここでは有名なフィッシャーのアヤメ (iris) データを利用してクラスタリングを実行してみます.まず,必要なライブラリをインポートします.
import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_formats
# from matplotlib_inline.backend_inline import set_matplotlib_formats # バージョンによってはこちらを有効に
set_matplotlib_formats('retina')
Seaborn
に登録された iris
データをロードします.このデータは萼片 (sepal) の長さと幅,花びら (petal) の長さと幅,アヤメの品種 (species) からなるデータです.
df = sns.load_dataset('iris')
print(df)
sepal_length sepal_width petal_length petal_width species 0 5.1 3.5 1.4 0.2 setosa 1 4.9 3.0 1.4 0.2 setosa 2 4.7 3.2 1.3 0.2 setosa 3 4.6 3.1 1.5 0.2 setosa 4 5.0 3.6 1.4 0.2 setosa .. ... ... ... ... ... 145 6.7 3.0 5.2 2.3 virginica 146 6.3 2.5 5.0 1.9 virginica 147 6.5 3.0 5.2 2.0 virginica 148 6.2 3.4 5.4 2.3 virginica 149 5.9 3.0 5.1 1.8 virginica [150 rows x 5 columns]
このデータを概観するには ydata-profiling のページを参照してください.
上で読み込んだアヤメ (iris) データを用いて k-means のクラスタリングを行います.別のページでは2次元のデータでクラスタリングを行いましたが,アヤメのデータは萼片 (sepal) の長さと幅,花びら (petal) の長さと幅からなる4次元のデータです.品種 (species) は3種類 (setosa, versicolor, virginica) あることから,萼片・花びらの長さと幅の4次元データから品種を予測できるかどうかをクラスタリングを実行して考えます.
まず,変数のリストを作成します.
variables = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
データフレームから必要な変数だけを取り出して Numpy 配列に変更します.
mat_x = df[variables].values
print(mat_x)
[[5.1 3.5 1.4 0.2] [4.9 3. 1.4 0.2] [4.7 3.2 1.3 0.2] [4.6 3.1 1.5 0.2] [5. 3.6 1.4 0.2] ... (中略) ... [6.7 3. 5.2 2.3] [6.3 2.5 5. 1.9] [6.5 3. 5.2 2. ] [6.2 3.4 5.4 2.3] [5.9 3. 5.1 1.8]]
クラスタ数は 3 に設定してクラスタリンスを行います.
k = 3 # クラスタ数
clf = KMeans(n_clusters=k) # モデルの設定
# clf = KMeans(n_clusters=k, random_state=1) # 再現性を持たせたい場合
clf.fit(mat_x) # クラスタリングの計算
pred = clf.predict(mat_x) # 計算結果からサンプルデータがどのクラスタに属するかを予測する
print(pred)
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 2 2 2 2 0 2 2 2 2 2 2 0 0 2 2 2 2 0 2 0 2 0 2 2 0 0 2 2 2 2 2 0 2 2 2 2 0 2 2 2 0 2 2 2 0 2 2 0]
結果をデータフレームの列として追加します.
df['cluster_id'] = pred
品種の列とクラスタIDの列だけを表示して結果を確認します.そこそこの精度でクラス分けはできている様子です.
results = df[['species', 'cluster_id']].values
print(results)
[['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['setosa' 1] ['versicolor' 0] ['versicolor' 0] ['versicolor' 2] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 2] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['versicolor' 0] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 0] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 0] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 2] ['virginica' 0] ['virginica' 2] ['virginica' 2] ['virginica' 0]]
横軸に花びらの長さ,縦軸に花びらの幅をとり,散布図を描きます.
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
colors = ['Red', 'Blue', 'Pink']
markers = ['o', 'x', 'v']
for cls in range(k):
x = df.loc[df['cluster_id'] == cls, 'petal_length']
y = df.loc[df['cluster_id'] == cls, 'petal_width']
ax.scatter(x, y,
alpha=0.5,
label=f"cluseter {cls}",
color=colors[cls],
marker=markers[cls]
)
ax.set_title("Clustering results 1")
ax.set_xlabel('petal_length')
ax.set_ylabel('petal_width')
ax.legend(loc='upper left')
# plt.savefig('cluster_ilis1.png', dpi=300, facecolor='white')
plt.show()
他の変数の組み合わせでも確認したいので,すべての変数の組み合わせについて散布図を描きます.
colors = ['Red', 'Blue', 'Pink']
markers = ['o', 'x', 'v']
# グラフ全体の大きさを指定
figsize = (12, 12)
# 行数と列数を指定
rows = len(variables)
cols = len(variables)
fig, ax = plt.subplots(
rows,
cols,
figsize=figsize,
constrained_layout=True, # Subplot 間の間隔を改善
)
for i, col in enumerate(variables):
for j, col in enumerate(variables):
r = i
c = j
for cls in range(k):
x = df.loc[df['cluster_id'] == cls, variables[c]]
y = df.loc[df['cluster_id'] == cls, variables[r]]
ax[r][c].scatter(x, y,
alpha=0.5,
label=f"cluseter {cls}",
color=colors[cls],
marker=markers[cls]
)
if r == rows - 1: # 最後の行だけX軸ラベルを追加
ax[r][c].set_xlabel(variables[c])
if c == 0: # 最初の列だけY軸ラベルを追加
ax[r][c].set_ylabel(variables[r])
ax[r][c].grid()
# plt.savefig('cluster_ilis2.png', dpi=300, facecolor='white')
なお,変数の次元が大きくなると,上のような複数の散布図からデータの特徴を捉えることが難しくなります.そのような場合には主成分分析によって次元を圧縮することでデータの特徴を捉えやすくできる可能性があります.
クラスタリングがどの程度うまくできているかを確認するために,Pandasのグループ化やピボットテーブルを利用して正解率を求めてみます.まず,グループ化を行って,cluster_id
の個数をカウントします.setosa
はすべてクラスタ0に分類されています.versicolor
は48個がクラスタ1に,2個がクラスタ2に分類されています.さらに virginica
は36個がクラスタ2に,14個がクラスタ1に分類されている事がわかります.
tbl = df.groupby(['species','cluster_id'])['cluster_id'].count()
tbl
species cluster_id setosa 0 50 versicolor 1 48 2 2 virginica 1 14 2 36 Name: cluster_id, dtype: int64
ピボットテーブルを作成します.
tbl = df.pivot_table(
'sepal_length', index='species', columns='cluster_id', aggfunc='count'
)
print(tbl)
cluster_id 0 1 2 species setosa 50.0 NaN NaN versicolor NaN 48.0 2.0 virginica NaN 14.0 36.0
欠損値 (NaN) をゼロで埋めるには,fill_value=0
を指定します.
tbl = df.pivot_table(
'sepal_length', index='species', columns='cluster_id', aggfunc='count', fill_value=0
)
print(tbl)
cluster_id 0 1 2 species setosa 50 0 0 versicolor 0 48 2 virginica 0 14 36
lambda
式とapply
を利用して正解率を求めます.setosa
の正解率は 100%,versicolor
は 77.4%,virginica
は 94.7% であることがわかりました.
tbl = df.pivot_table(
'sepal_length', index='species', columns='cluster_id', aggfunc='count', fill_value=0
).apply(lambda x:x/sum(x), axis=0)
print(tbl)
cluster_id 0 1 2 species setosa 1.0 0.000000 0.000000 versicolor 0.0 0.774194 0.052632 virginica 0.0 0.225806 0.947368