引言
Iris flower data set 是關於一種花的資料集。這種花有三個品種, 分別是 setosa, virginica 和 versicolor。每朵花都有兩種花瓣(sepals 和 petals)。早在 20 世紀 30 年代, 一位學者對每個品種收集了 50 個樣本, 分別測量兩種花瓣的長度和寬度, 最終形成了一個有 150 條資料的資料集。這個資料集被廣泛用於機器學習的初學者做資料分析的練習。
正文
1. 引入 iris 資料集
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
iris = load_iris()
2. 檢視 iris 資料集的相關屬性
dir(iris)
['DESCR', 'data', 'feature_names', 'filename', 'target', 'target_names']
// 檢視 iris 資料集的前 5 條資料
iris.data[0:5]
// 輸出, 分別是每朵花的每種花瓣的長度和寬度
array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2]])
// 檢視 iris 資料集的屬性名稱
iris.feature_names
// 輸出
['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']
iris.target
// 輸出
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
// 這裡就是 iris 花的三個品種的名字, 應該是分別對應了 target 值的 0, 1, 2
iris.target_names
//輸出
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
3. 把資料集拆分為訓練資料和測試資料
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
len(X_train) // 120
len(X_test) // 30
// 訓練模型
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)
// 檢視模型準確度
model.score(X_test, y_test) // 0.9
// 通過模型進行預測
model.predict([[4.4, 3., 1.6, 0.9]])
// 輸出
array([0])
4. 想要更加細緻地瞭解誤差的位置, 可以通過 confusion_matrix 類實現
// 通過模型預測的值
y_predicted = model.predict(X_test)
// 引入 confusion_matrix 包
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_predicted)
cm
// 輸出
array([[ 7, 0, 0],
[ 0, 8, 2],
[ 0, 0, 13]])
// 為了將上面的輸出視覺化更強, 引入 seaborn 包
import seaborn as sn
plt.figure(figsize = (10, 7))
sn.heatmap(cm, annot=True)
plt.xlabel('Predicted')
plt.ylabel('Truth')