【火爐煉AI】機器學習014-用SVM構建非線性分類模型

煉丹老頑童發表於2018-08-07

【火爐煉AI】機器學習014-用SVM構建非線性分類模型

(本文所使用的Python庫和版本號: Python 3.5, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2 )

支援向量機(Support Vector Machine,SVM)是一種常見的判別方法,其基本模型是在特徵空間上找到最佳的分離超平面,使得資料集上的正負樣本間隔最大。SVM用來解決二分類問題的有監督學習演算法,其可以解決線性問題,也可以通過引入核函式的方法來解決非線性問題。

本專案旨在使用SVM構建非線性分類模型來判別資料集的不同類別。


1. 準備資料集

首先來載入和檢視資料集的一些特性。如下程式碼載入資料集並檢視其基本資訊

# 準備資料集
data_path='E:\PyProjects\DataSet\FireAI/data_multivar_2_class.txt'
df=pd.read_csv(data_path,header=None)
# print(df.head()) # 沒有問題
print(df.info()) # 檢視資料資訊,確保沒有錯誤
dataset_X,dataset_y=df.iloc[:,:-1],df.iloc[:,-1]
# print(dataset_X.head())
# print(dataset_X.info())
# print(dataset_y.head()) # 檢查沒問題
dataset_X=dataset_X.values
dataset_y=dataset_y.values
複製程式碼

-------------------------------------輸---------出--------------------------------

<class 'pandas.core.frame.DataFrame'> RangeIndex: 300 entries, 0 to 299 Data columns (total 3 columns): 0 300 non-null float64 1 300 non-null float64 2 300 non-null int64 dtypes: float64(2), int64(1) memory usage: 7.1 KB None

--------------------------------------------完-------------------------------------

從上面的df.info()函式的結果可以看出,這個資料集有兩個特徵屬性(0,1列,連續的float型別),一個標記(2列,離散的int型,只有兩個類別),每一列都沒有缺失值。然後來看看這個資料集中資料點的分佈情況,如下圖所示:

# 資料集視覺化
def visual_2D_dataset(dataset_X,dataset_y):
    '''將二維資料集dataset_X和對應的類別dataset_y顯示在散點圖中'''
    assert dataset_X.shape[1]==2,'only support dataset with 2 features'
    plt.figure()
    classes=list(set(dataset_y)) 
    markers=['.',',','o','v','^','<','>','1','2','3','4','8'
             ,'s','p','*','h','H','+','x','D','d','|']
    colors=['b','c','g','k','m','w','r','y']
    for class_id in classes:
        one_class=np.array([feature for (feature,label) in 
                   zip(dataset_X,dataset_y) if label==class_id])
        plt.scatter(one_class[:,0],one_class[:,1],marker=np.random.choice(markers,1)[0],
                    c=np.random.choice(colors,1)[0],label='class_'+str(class_id))
    plt.legend()
    
visual_2D_dataset(dataset_X,dataset_y)
複製程式碼

資料集中資料點的分佈情況

我以前的很多文章都講到了資料集的處理,拆分,準備等,此處的資料集比較簡單,故而簡單講述一下。


2. 用SVM構建線性分類器

你沒有看錯,我就是想用SVM構建一個線性分類器來判別這個資料集。當然,即使是入門級的機器學習攻城獅們,也能看出,這個資料集是一個線性不可分型別,需要用非線性分類器來解決。所以,此處,我就用線性分類器來擬合一下,看看會有什麼樣的“不好”的結果。

# 從資料集的分佈就可以看出,這個資料集不可能用直線分開
# 為了驗證我們的判斷,下面還是使用SVM來構建線性分類器將其分類

# 將整個資料集劃分為train set和test set
from sklearn.model_selection import train_test_split
train_X, test_X, train_y, test_y=train_test_split(
    dataset_X,dataset_y,test_size=0.25,random_state=42)

# print(train_X.shape)  # (225, 2)
# print(train_y.shape)  # (225,)
# print(test_X.shape)  # (75, 2)

# 使用線性核函式初始化一個SVM物件。
from sklearn.svm import SVC
classifier=SVC(kernel='linear') # 構建線性分類器
classifier.fit(train_X,train_y)
複製程式碼

-------------------------------------輸---------出--------------------------------

SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape='ovr', degree=3, gamma='auto', kernel='linear', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False)

--------------------------------------------完-------------------------------------

然後檢視一下這個訓練好的SVM線性分類器在訓練集和測試集上的表現,如下為在訓練集上的效能報告:

# 模型在訓練集上的效能報告:
from sklearn.metrics import classification_report
plot_classifier(classifier,train_X,train_y)  # 分類器在訓練集上的分類效果
target_names = ['Class-0', 'Class-1']
y_pred=classifier.predict(train_X)
print(classification_report(train_y, y_pred, target_names=target_names))
複製程式碼

SVM線性分類器在訓練集上的分類效果

-------------------------------------輸---------出--------------------------------

precision recall f1-score support

Class-0 0.60 0.96 0.74 114 Class-1 0.89 0.35 0.50 111

avg / total 0.74 0.66 0.62 225

--------------------------------------------完-------------------------------------

很明顯,從分類效果圖和效能報告中,都可以看出這個模型很差,差到姥姥家了。。。 所以,更不用說,在測試集上的表現了,當然是一個差字了得。。。

# 分類器在測試集上的分類效果
plot_classifier(classifier,test_X,test_y)  

target_names = ['Class-0', 'Class-1']
y_pred=classifier.predict(test_X)
print(classification_report(test_y, y_pred, target_names=target_names))
複製程式碼

SVM線性分類器在測試集上的分類效果

-------------------------------------輸---------出--------------------------------

precision recall f1-score support

Class-0 0.57 1.00 0.73 36 Class-1 1.00 0.31 0.47 39

avg / total 0.79 0.64 0.59 75

--------------------------------------------完-------------------------------------


3. 用SVM構建非線性分類器

很明顯,用線性分類器解決不了這個資料集的判別問題,所以我們就上馬非線性分類器吧。

使用SVM構建非線性分類器主要是考慮使用不同的核函式,此處指講述兩種核函式:多項式核函式和徑向基函式。

# 從上面的效能報告中可以看出,分類效果並不好
# 故而我們使用SVM建立非線性分類器,看看其分類效果
# 使用SVM建立非線性分類器主要是使用不同的核函式
# 第一種:使用多項式核函式:
classifier_poly=SVC(kernel='poly',degree=3) # 三次多項式方程
classifier_poly.fit(train_X,train_y)

# 在訓練集上的表現為:
plot_classifier(classifier_poly,train_X,train_y)  

target_names = ['Class-0', 'Class-1']
y_pred=classifier_poly.predict(train_X)
print(classification_report(train_y, y_pred, target_names=target_names))
複製程式碼

SVM多項式核函式的非線性分類器的分類效果

-------------------------------------輸---------出--------------------------------

precision recall f1-score support

Class-0 0.92 0.85 0.89 114 Class-1 0.86 0.93 0.89 111

avg / total 0.89 0.89 0.89 225

--------------------------------------------完-------------------------------------

# 第二種:使用徑向基函式建立非線性分類器
classifier_rbf=SVC(kernel='rbf') 
classifier_rbf.fit(train_X,train_y)

# 在訓練集上的表現為:
plot_classifier(classifier_rbf,train_X,train_y)  

target_names = ['Class-0', 'Class-1']
y_pred=classifier_rbf.predict(train_X)
print(classification_report(train_y, y_pred, target_names=target_names))
複製程式碼

SVM徑向基函式的非線性分類器的分類效果

-------------------------------------輸---------出--------------------------------

precision recall f1-score support

Class-0 0.96 0.96 0.96 114 Class-1 0.96 0.95 0.96 111

avg / total 0.96 0.96 0.96 225

--------------------------------------------完-------------------------------------

########################小**********結###############################

1. 用SVM構建非線性分類器很簡單,只要使用不同的核函式就可以。

2. 對於這個資料集而言,使用了非線性分類器之後,分類效果得到了極大的改善,這個可以從效能報告中看出,而且,很明顯兩種非線性核函式,徑向基函式rbf的分類效果要比多項式核函式的效果更好一些。

3. 這個模型也許還可以繼續優化一些超引數,從而得到更好的分類效果。

#################################################################


注:本部分程式碼已經全部上傳到(我的github)上,歡迎下載。

參考資料:

1, Python機器學習經典例項,Prateek Joshi著,陶俊傑,陳小莉譯

相關文章