Python入門トップページ


アヤメ (iris) データを k-means でクラスタリングをしてみよう

目次

  1. データの準備
  2. クラスタリング
  3. グラフを描く
  4. 正解率を求める

目次に戻る

データの準備

ここでは有名なフィッシャーのアヤメ (iris) データを利用してクラスタリングを実行してみます.まず,必要なライブラリをインポートします.


import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

from IPython.display import set_matplotlib_formats
# from matplotlib_inline.backend_inline import set_matplotlib_formats # バージョンによってはこちらを有効に
set_matplotlib_formats('retina')

Seaborn に登録された iris データをロードします.このデータは萼片 (sepal) の長さと幅,花びら (petal) の長さと幅,アヤメの品種 (species) からなるデータです.


df = sns.load_dataset('iris')
print(df)
     sepal_length  sepal_width  petal_length  petal_width    species
0             5.1          3.5           1.4          0.2     setosa
1             4.9          3.0           1.4          0.2     setosa
2             4.7          3.2           1.3          0.2     setosa
3             4.6          3.1           1.5          0.2     setosa
4             5.0          3.6           1.4          0.2     setosa
..            ...          ...           ...          ...        ...
145           6.7          3.0           5.2          2.3  virginica
146           6.3          2.5           5.0          1.9  virginica
147           6.5          3.0           5.2          2.0  virginica
148           6.2          3.4           5.4          2.3  virginica
149           5.9          3.0           5.1          1.8  virginica

[150 rows x 5 columns]

このデータを概観するには ydata-profiling のページを参照してください.

目次に戻る

クラスタリング

上で読み込んだアヤメ (iris) データを用いて k-means のクラスタリングを行います.別のページでは2次元のデータでクラスタリングを行いましたが,アヤメのデータは萼片 (sepal) の長さと幅,花びら (petal) の長さと幅からなる4次元のデータです.品種 (species) は3種類 (setosa, versicolor, virginica) あることから,萼片・花びらの長さと幅の4次元データから品種を予測できるかどうかをクラスタリングを実行して考えます.

まず,変数のリストを作成します.


variables = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

データフレームから必要な変数だけを取り出して Numpy 配列に変更します.


mat_x = df[variables].values
print(mat_x)
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]

 ... (中略) ...

 [6.7 3.  5.2 2.3]
 [6.3 2.5 5.  1.9]
 [6.5 3.  5.2 2. ]
 [6.2 3.4 5.4 2.3]
 [5.9 3.  5.1 1.8]]

クラスタ数は 3 に設定してクラスタリンスを行います.


k = 3 # クラスタ数
clf = KMeans(n_clusters=k) # モデルの設定
# clf = KMeans(n_clusters=k, random_state=1) # 再現性を持たせたい場合
clf.fit(mat_x) # クラスタリングの計算
pred = clf.predict(mat_x) # 計算結果からサンプルデータがどのクラスタに属するかを予測する
print(pred)
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 2 2 2 2 0 2 2 2 2
 2 2 0 0 2 2 2 2 0 2 0 2 0 2 2 0 0 2 2 2 2 2 0 2 2 2 2 0 2 2 2 0 2 2 2 0 2
 2 0]

結果をデータフレームの列として追加します.


df['cluster_id'] = pred

品種の列とクラスタIDの列だけを表示して結果を確認します.そこそこの精度でクラス分けはできている様子です.


results = df[['species', 'cluster_id']].values
print(results)
[['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['setosa' 1]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 2]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 2]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['versicolor' 0]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]
 ['virginica' 2]
 ['virginica' 2]
 ['virginica' 0]]

目次に戻る

グラフを描く

横軸に花びらの長さ,縦軸に花びらの幅をとり,散布図を描きます.


fig, ax = plt.subplots(1, 1, figsize=(6, 6))
colors = ['Red', 'Blue', 'Pink']
markers = ['o', 'x', 'v']

for cls in range(k):
    x = df.loc[df['cluster_id'] == cls, 'petal_length']
    y = df.loc[df['cluster_id'] == cls, 'petal_width']
    ax.scatter(x, y,
        alpha=0.5,
        label=f"cluseter {cls}",
        color=colors[cls],
        marker=markers[cls]
    )

ax.set_title("Clustering results 1")
ax.set_xlabel('petal_length')
ax.set_ylabel('petal_width')
ax.legend(loc='upper left')
# plt.savefig('cluster_ilis1.png', dpi=300, facecolor='white')
plt.show()
cluster_ilis1.png

他の変数の組み合わせでも確認したいので,すべての変数の組み合わせについて散布図を描きます.


colors = ['Red', 'Blue', 'Pink']
markers = ['o', 'x', 'v']
# グラフ全体の大きさを指定
figsize = (12, 12)

# 行数と列数を指定
rows = len(variables)
cols = len(variables)

fig, ax = plt.subplots(
    rows,
    cols,
    figsize=figsize,
    constrained_layout=True, # Subplot 間の間隔を改善
)

for i, col in enumerate(variables):
    for j, col in enumerate(variables):
        r = i
        c = j
        for cls in range(k):
            x = df.loc[df['cluster_id'] == cls, variables[c]]
            y = df.loc[df['cluster_id'] == cls, variables[r]]
            ax[r][c].scatter(x, y,
                    alpha=0.5,
                    label=f"cluseter {cls}",
                    color=colors[cls],
                    marker=markers[cls]
            )
        if r == rows - 1:  # 最後の行だけX軸ラベルを追加
            ax[r][c].set_xlabel(variables[c])
        if c == 0:  # 最初の列だけY軸ラベルを追加
            ax[r][c].set_ylabel(variables[r])
        ax[r][c].grid()
# plt.savefig('cluster_ilis2.png', dpi=300, facecolor='white')
cluster_ilis2.png

なお,変数の次元が大きくなると,上のような複数の散布図からデータの特徴を捉えることが難しくなります.そのような場合には主成分分析によって次元を圧縮することでデータの特徴を捉えやすくできる可能性があります.

目次に戻る

正解率を求める

クラスタリングがどの程度うまくできているかを確認するために,Pandasのグループ化やピボットテーブルを利用して正解率を求めてみます.まず,グループ化を行って,cluster_id の個数をカウントします.setosa はすべてクラスタ0に分類されています.versicolor は48個がクラスタ1に,2個がクラスタ2に分類されています.さらに virginicaは36個がクラスタ2に,14個がクラスタ1に分類されている事がわかります.


tbl = df.groupby(['species','cluster_id'])['cluster_id'].count()
tbl
       species     cluster_id
setosa      0             50
versicolor  1             48
            2              2
virginica   1             14
            2             36
Name: cluster_id, dtype: int64

ピボットテーブルを作成します.


tbl = df.pivot_table(
    'sepal_length', index='species', columns='cluster_id', aggfunc='count'
    )
print(tbl)
cluster_id     0     1     2
  species
  setosa      50.0   NaN   NaN
  versicolor   NaN  48.0   2.0
  virginica    NaN  14.0  36.0

欠損値 (NaN) をゼロで埋めるには,fill_value=0 を指定します.


tbl = df.pivot_table(
    'sepal_length', index='species', columns='cluster_id', aggfunc='count', fill_value=0
    )
print(tbl)
cluster_id   0   1   2
species
setosa      50   0   0
versicolor   0  48   2
virginica    0  14  36

lambda式とapplyを利用して正解率を求めます.setosa の正解率は 100%,versicolor は 77.4%,virginica は 94.7% であることがわかりました.


tbl = df.pivot_table(
    'sepal_length', index='species', columns='cluster_id', aggfunc='count', fill_value=0
    ).apply(lambda x:x/sum(x), axis=0)
print(tbl)
cluster_id    0         1         2
species
setosa      1.0  0.000000  0.000000
versicolor  0.0  0.774194  0.052632
virginica   0.0  0.225806  0.947368

目次に戻る