引言
上一節我們重點關注了 binary classification, 也就是預測結果非是則非,這一節我們來學習多種分類值的模型。
正文
我們通過識別手寫數字的例子, 學習 Multiclass Classification。需要藉助 The Digit Dataset 的資料集, 它是由 1797 個 8*8 的影象組成的。如下圖所示, 每一個影象都是一個手寫數字。也就是說, 我們通過這個資料集, 可以獲得 1797 條資料用來訓練模型。
與 binary classification 的不同之處在於, 這裡獲得的結果不是 yes or no 的二選一, 而是數字從 0-9 的 10 種分類.
第一步:引入資料, 並對其特性做簡單的瞭解
%matplotlib inline
import matplotlib.pyplot as plt
// 引入 load_digits 資料集
from sklearn.datasets import load_digits
digits = load_digits()
// 檢視這個資料集都有哪些屬性
dir(digits)
// 輸出
['DESCR', 'data', 'images', 'target', 'target_names']
digits.data
// 輸出
array([[ 0., 0., 5., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 10., 0., 0.],
[ 0., 0., 0., ..., 16., 9., 0.],
...,
[ 0., 0., 1., ..., 6., 0., 0.],
[ 0., 0., 2., ..., 12., 0., 0.],
[ 0., 0., 10., ..., 12., 1., 0.]])
// 檢視這個資料集的整體結構, 從輸出可以看到這是一個二維陣列, 一共是有 1797 條資料, 每條資料由 64 個數字組層
digits.data.shape // 輸出 (1797, 64)
// 檢視二維資料的第一個, 也就是代表了第一個數字
digits.data[0]
// 輸出
array([ 0., 0., 5., 13., 9., 1., 0., 0., 0., 0., 13., 15., 10.,
15., 5., 0., 0., 3., 15., 2., 0., 11., 8., 0., 0., 4.,
12., 0., 0., 8., 8., 0., 0., 5., 8., 0., 0., 9., 8.,
0., 0., 4., 11., 0., 1., 12., 7., 0., 0., 2., 14., 5.,
10., 12., 0., 0., 0., 0., 6., 13., 10., 0., 0., 0.])
// 再來看下 images 屬性, 輸出前 4 個數字看看
for i in range(4):
plt.matshow(digits.images[i])
// target 屬性, 也就是 data 裡一堆數字到底是代表數字幾
digits.target[0:4] // 輸出 array([0, 1, 2, 3])
digits.target_names[0:4] // 輸出 array([0, 1, 2, 3])
第二步:準備訓練模型
// 按照 20% 測試資料的比例, 把資料集分為訓練資料和測試資料
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.2)
len(X_train) // 輸出 1437
len(X_test) // 輸出 360
// 訓練模型
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)
// 檢視模型準確度
model.score(X_test, y_test) // 輸出 0.9777777777777777
以上, 就是藉助於 load_digits 資料集完成了對 multiclass 模型的訓練, 並且得出了精確度還比較準確. 如果想更加細緻地瞭解誤差的位置, 可以通過 confusion_matrix 類實現:
y_predicted = model.predict(X_test)
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_predicted)
cm
// 輸出
array([[36, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 0, 45, 0, 0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 36, 0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 39, 0, 0, 0, 0, 0, 0],
[ 0, 1, 0, 0, 27, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 35, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 36, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 34, 0, 1],
[ 0, 2, 0, 0, 0, 1, 0, 0, 26, 1],
[ 0, 0, 0, 0, 0, 0, 0, 2, 0, 38]])
為了將上面的輸出視覺化更強, 引入 seaborn 包, 需要在終端安裝 pip3 install seaborn
import seaborn as sn
plt.figure(figsize = (10, 7))
sn.heatmap(cm, annot=True)
plt.xlabel('Predicted')
plt.ylabel('Truth')
簡單解釋這個圖表:
x 軸是預測的值, y 軸是實際的值
以左上角的 36 為例, 表示一共有 36 個數字 0 (看 y 軸), 而預測的也都是數字 0 (x 軸), 表示對 0 的預測沒有誤差.
再看最下面一行的 2, 表示有兩個數字 9 (y 軸) 被預測成了數字 7 (x 軸), 也就是說在數字 9 的預測上有兩個錯誤.
這個表上所有數字加起來正好是 360, 也就是測試資料的資料量.
這就非常直觀地看出我們這個資料模型的準確度表現了.