Machine Learning (12) - Support Vector Machine (SVM)

Rachel發表於2019-06-10

引言

Support Vector Machine 簡稱 SVM, 是一種非常常用的分類演算法。

下圖是著名的 Iris 花, 相信大家已經不陌生了。它一共有 3 個品種,由 petal 花瓣的寬度和長度和 sepal 花瓣的寬度和長度 4 個特徵,決定了每朵花的品種。

Machine Learning (12) - Support Vector Machine (SVM)

首先下面只是簡單地從 Petal 的 Length 和 Width 兩個維度做劃分,可以看到明顯地分出 Setosa 和 Versicolor 兩個品種的花。然而,兩種花的分界線要如何畫呢,看似有很多種畫法,哪條才是最佳的呢?

Machine Learning (12) - Support Vector Machine (SVM)

當然是選一條離每個離每個臨界點都最遠的線是最佳方案。這些邊界點被稱為 Support Vectors, 而 SVM 就可以幫我們畫出這條最佳的分割線。

Machine Learning (12) - Support Vector Machine (SVM)

SVM 演算法的作用是對由 n 個維度構成的資料進行劃分, 也就是做資料分類, 比如對於二維屬性的資料, 想將兩類資料分開, 就需要一條線, 對於三維屬性的資料, 就需要一個面, 那麼 n 維屬性呢(這在實際應用中是非常常見的, 比如 iris 花的種類就是分別由兩種花瓣各自的長寬值決定的, 也就是 4 個維度), 我們把它稱作做超面.

Machine Learning (12) - Support Vector Machine (SVM)

在 SVM 模型中, 線和麵的形狀都是可以通過引數來調節, 從而達到最大的準確度。比如下圖中的 gamma 和 regularization 都是在訓練模型時可調的引數。

Machine Learning (12) - Support Vector Machine (SVM)

Machine Learning (12) - Support Vector Machine (SVM)

比如有如下的一個資料集:

Machine Learning (12) - Support Vector Machine (SVM)
我們發現,如果單從 x 軸 和 y 軸的維度,很難對資料做劃分,所以可以增加一個 z 軸,將資料做一下轉換,從而可以輕鬆地將資料分類:

Machine Learning (12) - Support Vector Machine (SVM)

正文

引入並分析 The Iris Dataset 資料

import pandas as pd
from sklearn.datasets import load_iris
iris = load_iris()
iris

輸出:

這裡只擷取了 data 屬性的部分資料(在 jupyter 中向下拉, 還可以看到 iris 的其他屬性), 可以看出 data 是一個二維陣列.

利用 pandas 將這個二維陣列轉成 dataframe, 設定列名為它的 feature_names 屬性, 也就是兩種花瓣的長和寬

df = pd.DataFrame(iris.data, columns = iris.feature_names)
df.head()

輸出:

Machine Learning (12) - Support Vector Machine (SVM)

把 iris 的 target 屬性加入 dataframe:

df['target'] = iris.target
df.head()

Machine Learning (12) - Support Vector Machine (SVM)

檢視 target 為 2 的所有資料:

df[df.target == 2].head()

Machine Learning (12) - Support Vector Machine (SVM)

給這個 dataframe 增加 flower_name 列(也就是花的種類):

df['flower_name'] = df.target.apply(lambda x: iris.target_names[x])
df.head()

Machine Learning (12) - Support Vector Machine (SVM)

以上就是我們準備好的, 可以用於分析的資料, 下面根據花的品種, 將資料分成三份:

from matplotlib import pyplot as plt
%matplotlib inline
df0 = df[df.target == 0]
df1 = df[df.target == 1]
df2 = df[df.target == 2]
df0.head()

Machine Learning (12) - Support Vector Machine (SVM)

以圖表的形式, 同時取第一種花和第三種花的 sepal 花瓣的長和寬輸出:

plt.scatter(df0['sepal length (cm)'], df0['sepal width (cm)'],color='green', marker='+')
plt.scatter(df2['sepal length (cm)'], df2['sepal width (cm)'],color='red', marker='+')
plt.xlabel('sepal length (cm)')
plt.ylabel('sepal width (cm)')

從下圖中, 我們可以非常直觀且明確地區分出這兩種花

Machine Learning (12) - Support Vector Machine (SVM)

下面把第二種花的值也加上:

plt.scatter(df0['sepal length (cm)'], df0['sepal width (cm)'],color='green', marker='+')
plt.scatter(df1['sepal length (cm)'], df1['sepal width (cm)'],color='blue', marker='+')
plt.scatter(df2['sepal length (cm)'], df2['sepal width (cm)'],color='red', marker='+')
plt.xlabel('sepal length (cm)')
plt.ylabel('sepal width (cm)')

我們發現每種花之間的分割線變得沒那麼容易畫了, 尤其是在第二種(藍)和第三種(紅)之間, 出現了混合現象, 也就是說, 在指定 sepal 花瓣的長寬值的情況下, 我們並不容易區分出這是哪種花.

Machine Learning (12) - Support Vector Machine (SVM)

不過, 我們可以根據現有的資料來訓練 SVM 模型, 通過調節引數, 建立一個比較準確的模型.

訓練模型

from sklearn.model_selection import train_test_split
X = df.drop(['target', 'flower_name'], axis='columns')
X.head()

Machine Learning (12) - Support Vector Machine (SVM)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

// 引入模型
from sklearn.svm import SVC
model = SVC(C = 100) // 這裡就可以傳參, 來微調模型的準確度
model.fit(X_train, y_train)

//輸出, 這些引數都是可以用來微調模型的準確度的
SVC(C=100, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
  kernel='rbf', max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False)

// 測試模型準確度
model.score(X_test, y_test)

//輸出
0.9666666666666667

相關文章