Machine Learning (8) - Logistic Regression (Multiclass Classification)

Rachel發表於2019-06-07

引言

Machine Learning (8) - Logistic Regression (Multiclass Classification)

上一節我們重點關注了 binary classification, 也就是預測結果非是則非,這一節我們來學習多種分類值的模型。

正文

Machine Learning (8) - Logistic Regression (Multiclass Classification)

我們通過識別手寫數字的例子, 學習 Multiclass Classification。需要藉助 The Digit Dataset 的資料集, 它是由 1797 個 8*8 的影象組成的。如下圖所示, 每一個影象都是一個手寫數字。也就是說, 我們通過這個資料集, 可以獲得 1797 條資料用來訓練模型。

Machine Learning (8) - Logistic Regression (Multiclass Classification)

與 binary classification 的不同之處在於, 這裡獲得的結果不是 yes or no 的二選一, 而是數字從 0-9 的 10 種分類.

第一步:引入資料, 並對其特性做簡單的瞭解

%matplotlib inline
import matplotlib.pyplot as plt

// 引入 load_digits 資料集
from sklearn.datasets import load_digits
digits = load_digits()

// 檢視這個資料集都有哪些屬性
dir(digits)
// 輸出
['DESCR', 'data', 'images', 'target', 'target_names']

digits.data
// 輸出
array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ..., 10.,  0.,  0.],
       [ 0.,  0.,  0., ..., 16.,  9.,  0.],
       ...,
       [ 0.,  0.,  1., ...,  6.,  0.,  0.],
       [ 0.,  0.,  2., ..., 12.,  0.,  0.],
       [ 0.,  0., 10., ..., 12.,  1.,  0.]])

// 檢視這個資料集的整體結構, 從輸出可以看到這是一個二維陣列, 一共是有 1797 條資料, 每條資料由 64 個數字組層
digits.data.shape  // 輸出 (1797, 64)

// 檢視二維資料的第一個, 也就是代表了第一個數字
digits.data[0]
// 輸出
array([ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.,  0.,  0., 13., 15., 10.,
       15.,  5.,  0.,  0.,  3., 15.,  2.,  0., 11.,  8.,  0.,  0.,  4.,
       12.,  0.,  0.,  8.,  8.,  0.,  0.,  5.,  8.,  0.,  0.,  9.,  8.,
        0.,  0.,  4., 11.,  0.,  1., 12.,  7.,  0.,  0.,  2., 14.,  5.,
       10., 12.,  0.,  0.,  0.,  0.,  6., 13., 10.,  0.,  0.,  0.])

// 再來看下 images 屬性, 輸出前 4 個數字看看      
for i in range(4):
    plt.matshow(digits.images[i])

Machine Learning (8) - Logistic Regression (Multiclass Classification)

Machine Learning (8) - Logistic Regression (Multiclass Classification)

// target 屬性, 也就是 data 裡一堆數字到底是代表數字幾
digits.target[0:4] // 輸出 array([0, 1, 2, 3])
digits.target_names[0:4] // 輸出 array([0, 1, 2, 3])

第二步:準備訓練模型

// 按照 20% 測試資料的比例, 把資料集分為訓練資料和測試資料
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.2)
len(X_train) // 輸出 1437
len(X_test) // 輸出 360

// 訓練模型
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)

// 檢視模型準確度
model.score(X_test, y_test) // 輸出 0.9777777777777777

以上, 就是藉助於 load_digits 資料集完成了對 multiclass 模型的訓練, 並且得出了精確度還比較準確. 如果想更加細緻地瞭解誤差的位置, 可以通過 confusion_matrix 類實現:

y_predicted = model.predict(X_test)
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, y_predicted)
cm
// 輸出
array([[36,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0, 45,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0, 36,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0, 39,  0,  0,  0,  0,  0,  0],
       [ 0,  1,  0,  0, 27,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0, 35,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0, 36,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0, 34,  0,  1],
       [ 0,  2,  0,  0,  0,  1,  0,  0, 26,  1],
       [ 0,  0,  0,  0,  0,  0,  0,  2,  0, 38]])

為了將上面的輸出視覺化更強, 引入 seaborn 包, 需要在終端安裝 pip3 install seaborn

import seaborn as sn
plt.figure(figsize = (10, 7))
sn.heatmap(cm, annot=True)
plt.xlabel('Predicted')
plt.ylabel('Truth')

Machine Learning (8) - Logistic Regression (Multiclass Classification)

簡單解釋這個圖表:
x 軸是預測的值, y 軸是實際的值
以左上角的 36 為例, 表示一共有 36 個數字 0 (看 y 軸), 而預測的也都是數字 0 (x 軸), 表示對 0 的預測沒有誤差.
再看最下面一行的 2, 表示有兩個數字 9 (y 軸) 被預測成了數字 7 (x 軸), 也就是說在數字 9 的預測上有兩個錯誤.
這個表上所有數字加起來正好是 360, 也就是測試資料的資料量.
這就非常直觀地看出我們這個資料模型的準確度表現了.

相關文章