機器學習演算法(九): 基於線性判別模型的LDA手寫數字分類識別

✨汀、發表於2023-03-29

1.機器學習演算法(九): 基於線性判別模型的LDA手寫數字分類識別

1.1 LDA演算法簡介和應用

線性判別模型(LDA)在模式識別領域(比如人臉識別等圖形影像識別領域)中有非常廣泛的應用。LDA是一種監督學習的降維技術,也就是說它的資料集的每個樣本是有類別輸出的。這點和PCA不同。PCA是不考慮樣本類別輸出的無監督降維技術。LDA的思想可以用一句話概括,就是“投影后類內方差最小,類間方差最大”。我們要將資料在低維度上進行投影,投影后希望每一種類別資料的投影點儘可能的接近,而不同類別的資料的類別中心之間的距離儘可能的大。即:將資料投影到維度更低的空間中,使得投影后的點,會形成按類別區分,一簇一簇的情況,相同類別的點,將會在投影后的空間中更接近方法。

LDA演算法的一個目標是使得不同類別之間的距離越遠越好,同一類別之中的距離越近越好。那麼不同類別之間的距離越遠越好,我們是可以理解的,就是越遠越好區分。同時,協方差不僅是反映了變數之間的相關性,同樣反映了多維樣本分佈的離散程度(一維樣本使用方差),協方差越大(對於負相關來說是絕對值越大),表示資料的分佈越分散。所以上面的“欲使同類樣例的投影點儘可能接近,可以讓同類樣本點的協方差矩陣儘可能小”就可以理解了。

$J(w)=\frac{w^T|\mu_1 - \mu_2~|2}{s2_1+s2_2}$

如上述公式 $J(w)$ 所示,分子為投影資料後的均值只差,分母為方差之後,LDA的目的就是使得 $J$ 值最大化,那麼可以理解為最大化分子,即使得類別之間的距離越遠,同時最小化分母,使得每個類別內部的方差越小,這樣就能使得每個類類別的資料可以在投影矩陣 $w$ 的對映下,分的越開。

需要注意的是,LDA模型適用於線性可分資料,對於上述實戰中用到的MNIST手寫資料(其實是分線性的),但是依然可以取得較好的分類效果;但在以後的實戰中需要注意LDA在非線性可分資料上的謹慎使用。

1.2.演算法應用

LDA在模式識別領域(比如人臉識別,艦艇識別等圖形影像識別領域)中有非常廣泛的應用,因此我們有必要了解一下它的演算法原理。不過在學習LDA之前,我們有必要將其與自然語言處理領域中的LDA區分開,在自然語言處理領域,LDA是隱含狄利克雷分佈(Latent DIrichlet Allocation,簡稱LDA),它是一種處理文件的主題模型,我們本文討論的是線性判別分析,因此後面所說的LDA均為線性判別分析。

LDA除了可以用於降維以外,還可以用於分類。一個常見的LDA分類基本思想是假設各個類別的樣本資料符合高斯分佈,這樣利用LDA進行投影后,可以利用極大似然估計計算各個類別投影資料的均值和方差,進而得到該類別高斯分佈的機率密度函式。當一個新的樣本到來後,我們可以將它投影,然後將投影后的樣本特徵分別帶入各個類別的高斯分佈機率密度函式,計算它屬於這個類別的機率,最大的機率對應的類別即為預測類別。

2.相關流程

  • 掌握LDA演算法基本原理
  • 掌握利用LDA進行程式碼實戰
  • Part 1 Demo實踐

    • Step1:庫函式匯入
    • Step2:模型訓練
    • Step3:模型引數檢視
    • Step4:資料和模型視覺化
    • Step5:模型預測
  • Part 2 基於LDA手寫數字分類實踐

    • Step1:庫函式匯入
    • Step2:資料讀取/載入
    • Step3:資料資訊簡單檢視與視覺化
    • Step4:利用LDA在手寫數字上進行訓練和預測

3.程式碼實戰

3.1 Demo實踐

  • Step1:庫函式匯入
# 基礎陣列運算庫匯入
import numpy as np 
# 畫相簿匯入
import matplotlib.pyplot as plt 
# 匯入三維顯示工具
from mpl_toolkits.mplot3d import Axes3D
# 匯入LDA模型
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# 匯入demo資料製作方法
from sklearn.datasets import make_classification
  • Step2:模型訓練
# 製作四個類別的資料,每個類別100個樣本
X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0,
                           n_classes=4, n_informative=2, n_clusters_per_class=1,
                           class_sep=3, random_state=10)
# 將四個類別的資料進行三維顯示
fig = plt.figure()
ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=y)
plt.show()

# 建立 LDA 模型
lda = LinearDiscriminantAnalysis()
# 進行模型訓練
lda.fit(X, y)
LinearDiscriminantAnalysis()
  • Step3:模型引數檢視
# 檢視 LDA 模型的引數
lda.get_params()
{'covariance_estimator': None,
 'n_components': None,
 'priors': None,
 'shrinkage': None,
 'solver': 'svd',
 'store_covariance': False,
 'tol': 0.0001}
  • Step4:資料和模型視覺化
# 進行模型預測
X_new = lda.transform(X)
# 視覺化預測資料
plt.scatter(X_new[:, 0], X_new[:, 1], marker='o', c=y)
plt.show()

  • Step5:模型預測
# 進行新的測試資料測試
a = np.array([[-1, 0.1, 0.1]])
print(f"{a} 類別是: ", lda.predict(a))
print(f"{a} 類別機率分別是: ", lda.predict_proba(a))

a = np.array([[-12, -100, -91]])
print(f"{a} 類別是: ", lda.predict(a))
print(f"{a} 類別機率分別是: ", lda.predict_proba(a))

a = np.array([[-12, -0.1, -0.1]])
print(f"{a} 類別是: ", lda.predict(a))
print(f"{a} 類別機率分別是: ", lda.predict_proba(a))

a = np.array([[0.1, 90.1, 9.1]])
print(f"{a} 類別是: ", lda.predict(a))
print(f"{a} 類別機率分別是: ", lda.predict_proba(a))
[[-1.   0.1  0.1]] 類別是:  [0]
[[-1.   0.1  0.1]] 類別機率分別是:  [[9.37611354e-01 1.88760664e-05 3.36891510e-02 2.86806189e-02]]
[[ -12 -100  -91]] 類別是:  [1]
[[ -12 -100  -91]] 類別機率分別是:  [[1.08769337e-028 1.00000000e+000 1.54515810e-221 9.05666876e-183]]
[[-12.   -0.1  -0.1]] 類別是:  [2]
[[-12.   -0.1  -0.1]] 類別機率分別是:  [[1.60268201e-07 1.46912978e-39 9.99999840e-01 3.57001075e-28]]
[[ 0.1 90.1  9.1]] 類別是:  [3]
[[ 0.1 90.1  9.1]] 類別機率分別是:  [[8.42065614e-08 9.45021749e-11 8.63060269e-02 9.13693889e-01]]

3.2 Part 2 基於LDA手寫數字分類實踐

  • Step1:庫函式匯入
# 匯入手寫資料集 MNIST
from sklearn.datasets import load_digits
# 匯入訓練集分割方法
from sklearn.model_selection import train_test_split
# 匯入LDA模型
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# 匯入預測指標計算函式和混淆矩陣計算函式
from sklearn.metrics import classification_report, confusion_matrix
# 匯入繪圖包
import seaborn as sns
import matplotlib
  • Step2:資料讀取/載入
# 匯入MNIST資料集
mnist = load_digits()

# 檢視資料集資訊
print('The Mnist dataeset:\n',mnist)

# 分割資料為訓練集和測試集
x, test_x, y, test_y = train_test_split(mnist.data, mnist.target, test_size=0.1, random_state=2)
The Mnist dataeset:
 {'data': array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ..., 10.,  0.,  0.],
       [ 0.,  0.,  0., ..., 16.,  9.,  0.],
       ...,
       [ 0.,  0.,  1., ...,  6.,  0.,  0.],
       [ 0.,  0.,  2., ..., 12.,  0.,  0.],
       [ 0.,  0., 10., ..., 12.,  1.,  0.]]), 'target': array([0, 1, 2, ..., 8, 9, 8]), 'frame': None, 'feature_names': ['pixel_0_0', 'pixel_0_1', 'pixel_0_2', 'pixel_0_3', 'pixel_0_4', 'pixel_0_5', 'pixel_0_6', 'pixel_0_7', 'pixel_1_0', 'pixel_1_1', 'pixel_1_2', 'pixel_1_3', 'pixel_1_4', 'pixel_1_5', 'pixel_1_6', 'pixel_1_7', 'pixel_2_0', 'pixel_2_1', 'pixel_2_2', 'pixel_2_3', 'pixel_2_4', 'pixel_2_5', 'pixel_2_6', 'pixel_2_7', 'pixel_3_0', 'pixel_3_1', 'pixel_3_2', 'pixel_3_3', 'pixel_3_4', 'pixel_3_5', 'pixel_3_6', 'pixel_3_7', 'pixel_4_0', 'pixel_4_1', 'pixel_4_2', 'pixel_4_3', 'pixel_4_4', 'pixel_4_5', 'pixel_4_6', 'pixel_4_7', 'pixel_5_0', 'pixel_5_1', 'pixel_5_2', 'pixel_5_3', 'pixel_5_4', 'pixel_5_5', 'pixel_5_6', 'pixel_5_7', 'pixel_6_0', 'pixel_6_1', 'pixel_6_2', 'pixel_6_3', 'pixel_6_4', 'pixel_6_5', 'pixel_6_6', 'pixel_6_7', 'pixel_7_0', 'pixel_7_1', 'pixel_7_2', 'pixel_7_3', 'pixel_7_4', 'pixel_7_5', 'pixel_7_6', 'pixel_7_7'], 'target_names': array([0, 1, 2, 3, 4, 5, 6, 7,
        [ 0.,  0., 13., ..., 15.,  5.,  0.],
        [ 0.,  3., 15., ..., 11.,  8.,  0.],
        ...,
        [ 0.,  4., 11., ..., 12.,  7.,  0.],
        [ 0.,  2., 14., ..., 12.,  0.,  0.],
        [ 0.,  0.,  6., ...,  0.,  0.,  0.]],

       [[ 0.,  0.,  0., ...,  5.,  0.,  0.],
        [ 0.,  0.,  0., ...,  9.,  0.,  0.],
        [ 0.,  0.,  3., ...,  6.,  0.,  0.],
        ...,
        [ 0.,  0.,  1., ...,  6.,  0.,  0.],
        [ 0.,  0.,  1., ...,  6.,  0.,  0.],
        [ 0.,  0.,  0., ..., 10.,  0.,  0.]],

       [[ 0.,  0.,  0., ..., 12.,  0.,  0.],
        [ 0.,  0.,  3., ..., 14.,  0.,  0.],
        [ 0.,  0.,  8., ..., 16.,  0.,  0.],
        ...,
        [ 0.,  9., 16., ...,  0.,  0.,  0.],
        [ 0.,  3., 13., ..., 11.,  5.,  0.],
        [ 0.,  0.,  0., ..., 16.,  9.,  0.]],

       ...,

       [[ 0.,  0.,  1., ...,  1.,  0.,  0.],
        [ 0.,  0., 13., ...,  2.,  1.,  0.],
        [ 0.,  0., 16., ..., 16.,  5.,  0.],
        ...,
        [ 0.,  0., 16., ..., 15.,  0.,  0.],
        [ 0.,  0., 15., ..., 16.,  0.,  0.],
        [ 0.,  0.,  2., ...,  6.,  0.,  0.]],

       [[ 0.,  0.,  2., ...,  0.,  0.,  0.],
        [ 0.,  0., 14., ..., 15.,  1.,  0.],
        [ 0.,  4., 16., ..., 16.,  7.,  0.],
        ...,
        [ 0.,  0.,  0., ..., 16.,  2.,  0.],
        [ 0.,  0.,  4., ..., 16.,  2.,  0.],
        [ 0.,  0.,  5., ..., 12.,  0.,  0.]],

       [[ 0.,  0., 10., ...,  1.,  0.,  0.],
        [ 0.,  2., 16., ...,  1.,  0.,  0.],
        [ 0.,  0., 15., ..., 15.,  0.,  0.],
        ...,
        [ 0.,  4., 16., ..., 16.,  6.,  0.],
        [ 0.,  8., 16., ..., 16.,  8.,  0.],
        [ 0.,  1.,  8., ..., 12.,  1.,  0.]]]), 'DESCR': ".. _digits_dataset:\n\nOptical recognition of handwritten digits dataset\n--------------------------------------------------\n\n**Data Set Characteristics:**\n\n    :Number of Instances: 1797\n    :Number of Attributes: 64\n    :Attribute Information: 8x8 image of integer pixels in the range 0..16.\n    :Missing Attribute Values: None\n    :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)\n    :Date: July; 1998\n\nThis is a copy of the test set of the UCI ML hand-written digits datasets\nhttps://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits\n\nThe data set contains images of hand-written digits: 10 classes where\neach class refers to a digit.\n\nPreprocessing programs made available by NIST were used to extract\nnormalized bitmaps of handwritten digits from a preprinted form. From a\ntotal of 43 people, 30 contributed to the training set and different 13\nto the test set. 32x32 bitmaps are divided into
  • Step3:資料資訊簡單檢視與視覺化
## 輸出示例影像
images = range(0,9)

plt.figure(dpi=100)
for i in images:
    plt.subplot(330 + 1 + i)
    plt.imshow(x[i].reshape(8, 8), cmap = matplotlib.cm.binary,interpolation="nearest")
# show the plot

plt.show()

  • Step4:利用LDA在手寫數字上進行訓練和預測
# 建立 LDA 模型
m_lda = LinearDiscriminantAnalysis()
# 進行模型訓練
m_lda.fit(x, y)
LinearDiscriminantAnalysis()
# 進行模型預測
x_new = m_lda.transform(x)
# 視覺化預測資料
plt.scatter(x_new[:, 0], x_new[:, 1], marker='o', c=y)
plt.title('MNIST with LDA Model')
plt.show()

# 進行測試集資料的類別預測
y_test_pred = m_lda.predict(test_x)
print("測試集的真實標籤:\n", test_y)
print("測試集的預測標籤:\n", y_test_pred)
測試集的真實標籤:
 [4 0 9 1 4 7 1 5 1 6 6 7 6 1 5 5 4 6 2 7 4 6 4 1 5 2 9 5 4 6 5 6 3 4 0 9 9
 8 4 6 8 8 5 7 9 6 9 6 1 3 0 1 9 7 3 3 1 1 8 8 9 8 5 4 4 7 3 5 8 4 3 1 3 8
 7 3 3 0 8 7 2 8 5 3 8 7 6 4 6 2 2 0 1 1 5 3 5 7 6 8 2 2 6 4 6 7 3 7 3 9 4
 7 0 3 5 8 5 0 3 9 2 7 3 2 0 8 1 9 2 1 9 1 0 3 4 3 0 9 3 2 2 7 3 1 6 7 2 8
 3 1 1 6 4 8 2 1 8 4 1 3 1 1 9 5 4 8 7 4 8 9 5 7 6 9 0 0 4 0 0 4]
測試集的預測標籤:
 [4 0 9 1 8 7 1 5 1 6 6 7 6 2 5 5 8 6 2 7 4 6 4 1 5 2 9 5 4 6 5 6 3 4 0 9 9
 8 4 6 8 1 5 7 9 6 9 6 1 3 0 1 9 7 3 3 1 1 8 8 9 8 5 8 4 9 3 5 8 4 3 9 3 8
 7 3 3 0 8 7 2 8 5 3 8 7 6 4 6 2 2 0 1 1 5 3 5 7 1 8 2 2 6 4 6 7 3 7 3 9 4
 7 0 3 5 1 5 0 3 9 2 7 3 2 0 8 1 9 2 1 9 9 0 3 4 3 0 8 3 2 2 7 3 1 6 7 2 8
 3 1 1 6 4 8 2 1 8 4 1 3 1 1 9 5 4 9 7 4 8 9 5 7 6 9 6 0 4 0 0 9]
# 進行預測結果指標統計 統計每一類別的預測準確率、召回率、F1分數
print(classification_report(test_y, y_test_pred))
              precision    recall  f1-score   support

           0       1.00      0.93      0.96        14
           1       0.86      0.86      0.86        22
           2       0.93      1.00      0.97        14
           3       1.00      1.00      1.00        22
           4       1.00      0.81      0.89        21
           5       1.00      1.00      1.00        16
           6       0.94      0.94      0.94        18
           7       1.00      0.94      0.97        18
           8       0.80      0.84      0.82        19
           9       0.75      0.94      0.83        16

    accuracy                           0.92       180
   macro avg       0.93      0.93      0.93       180
weighted avg       0.93      0.92      0.92       180
# 計算混淆矩陣
C2 = confusion_matrix(test_y, y_test_pred)
# 打混淆矩陣
print(C2)

# 將混淆矩陣以熱力圖的防線顯示
sns.set()
f, ax = plt.subplots()
# 畫熱力圖
sns.heatmap(C2, cmap="YlGnBu_r", annot=True, ax=ax)  
# 標題 
ax.set_title('confusion matrix')
# x軸為預測類別
ax.set_xlabel('predict')  
# y軸實際類別
ax.set_ylabel('true')  
plt.show()
[[13  0  0  0  0  0  1  0  0  0]
 [ 0 19  1  0  0  0  0  0  0  2]
 [ 0  0 14  0  0  0  0  0  0  0]
 [ 0  0  0 22  0  0  0  0  0  0]
 [ 0  0  0  0 17  0  0  0  3  1]
 [ 0  0  0  0  0 16  0  0  0  0]
 [ 0  1  0  0  0  0 17  0  0  0]
 [ 0  0  0  0  0  0  0 17  0  1]
 [ 0  2  0  0  0  0  0  0 16  1]
 [ 0  0  0  0  0  0  0  0  1 15]]

4.總結

LDA演算法的主要優點:

  1. 在降維過程中可以使用類別的先驗知識經驗,而像PCA這樣的無監督學習則無法使用類別先驗知識;
  2. LDA在樣本分類資訊依賴均值而不是方差的時候,比PCA之類的演算法較優。

LDA演算法的主要缺點:

  1. LDA不適合對非高斯分佈樣本進行降維,PCA也有這個問題
  2. LDA降維最多降到類別數 k-1 的維數,如果我們降維的維度大於 k-1,則不能使用 LDA。當然目前有一些LDA的進化版演算法可以繞過這個問題
  3. LDA在樣本分類資訊依賴方差而不是均值的時候,降維效果不好
  4. LDA可能過度擬合資料,

本專案連結:https://www.heywhale.com/home/column/64141d6b1c8c8b518ba97dcc

參考連結:https://tianchi.aliyun.com/course/278/3426


本人最近打算整合ML、DRL、NLP等相關領域的體系化專案課程,方便入門同學快速掌握相關知識。宣告:部分專案為網路經典專案方便大家快速學習,後續會不斷增添實戰環節(比賽、論文、現實應用等)。

  • 對於機器學習這塊規劃為:基礎入門機器學習演算法--->簡單專案實戰--->資料建模比賽----->相關現實中應用場景問題解決。一條路線幫助大家學習,快速實戰。
  • 對於深度強化學習這塊規劃為:基礎單智慧演算法教學(gym環境為主)---->主流多智慧演算法教學(gym環境為主)---->單智慧多智慧題實戰(論文復現偏業務如:無人機最佳化排程、電力資源排程等專案應用)
  • 自然語言處理相關規劃:除了單點演算法技術外,主要圍繞知識圖譜構建進行:資訊抽取相關技術(含智慧標註)--->知識融合---->知識推理---->圖譜應用

上述對於你掌握後的期許:

  1. 對於ML,希望你後續可以亂殺數學建模相關比賽(參加就獲獎保底,top還是難的需要鑽研)
  2. 可以實際解決現實中一些最佳化排程問題,而非停留在gym環境下的一些遊戲demo玩玩。(更深層次可能需要自己鑽研了,難度還是很大的)
  3. 掌握可知識圖譜全流程構建其中各個重要環節演算法,包含圖資料庫相關知識。

這三塊領域耦合情況比較大,後續會透過比如:搜尋推薦系統整個專案進行耦合,各項演算法都會耦合在其中。舉例:知識圖譜就會用到(圖演算法、NLP、ML相關演算法),搜尋推薦系統(除了該領域召回粗排精排重排混排等演算法外,還有強化學習、知識圖譜等耦合在其中)。餅畫的有點大,後面慢慢實現。

相關文章