Machine Learning (6) - 關於 Logistic Regression (Multiclass Classification) 的小練習

Rachel發表於2019-04-14

Machine Learning (6) - 關於 Logistic Regression (Multiclass Classification) 的小練習

Iris flower data set 是關於一種花的資料集. 這種花有三個品種, 分別是 setosa, virginica 和 versicolor. 每朵花都有兩種花瓣(sepals 和 petals).早在 20 世紀 30 年代, 一位學者對每個品種收集了 50 個樣本, 分別測量兩種花瓣的長度和寬度, 最終形成了一個有 150 條資料的資料集. 這個資料集被廣泛用於機器學習的初學者做資料分析的練習.

import pandas as pd
import matplotlib.pyplot as plt
// 引入 iris 資料集
from sklearn.datasets import load_iris
iris = load_iris()

// 檢視 iris 資料集的屬性
dir(iris)
['DESCR', 'data', 'feature_names', 'filename', 'target', 'target_names']

// 檢視 iris 資料集的前5條資料
iris.data[0:5]
// 輸出, 分別是每朵花的每種花瓣的長度和寬度
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2]])

// 檢視 iris 資料集的屬性名稱
iris.feature_names
// 輸出
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

iris.target
// 輸出
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

// 這裡就是 iris 花的三個品種的名字, 應該是分別對應了 target 值的 0, 1, 2
iris.target_names
//輸出
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

// 把資料集拆分為訓練資料和測試資料
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
len(X_train) // 120
len(X_test) // 30

// 訓練模型
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)

// 檢視模型準確度
model.score(X_test, y_test) // 0.9

// 透過模型進行預測
model.predict([[4.4, 3., 1.6, 0.9]])
// 輸出
array([0])

想要更加細緻地瞭解誤差的位置, 可以透過 confusion_matrix 類實現:

// 透過模型預測的值
y_predicted = model.predict(X_test)

// 引入 confusion_matrix 包
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_predicted)
cm
// 輸出
array([[ 7,  0,  0],
       [ 0,  8,  2],
       [ 0,  0, 13]])

// 為了將上面的輸出視覺化更強, 引入 seaborn 包     
import seaborn as sn
plt.figure(figsize = (10, 7))
sn.heatmap(cm, annot=True)
plt.xlabel('Predicted')
plt.ylabel('Truth')

Machine Learning (6) - 關於 Logistic Regression (Multiclass Classification) 的小練習

本作品採用《CC 協議》,轉載必須註明作者和本文連結

相關文章