1.1 資料集簡介
- 來源:archive.ics.uci.edu/ml/datasets…
- 類別:0-9 共10個數字
- 樣本數:1797
- 特徵數:64
- 特徵含義:8x8畫素,每個畫素由0到16之間的整數表示
import numpy as np
from sklearn import datasets
digits = datasets.load_digits()
# 輸出資料集的樣本數與特徵數
print digits.data.shape
# 輸出所有目標類別
print np.unique(digits.target)
# 輸出資料集
print digits.data複製程式碼
(1797, 64)
[0 1 2 3 4 5 6 7 8 9]
[[ 0. 0. 5. ..., 0. 0. 0.]
[ 0. 0. 0. ..., 10. 0. 0.]
[ 0. 0. 0. ..., 16. 9. 0.]
...,
[ 0. 0. 1. ..., 6. 0. 0.]
[ 0. 0. 2. ..., 12. 0. 0.]
[ 0. 0. 10. ..., 12. 1. 0.]]複製程式碼
1.2 資料集視覺化
import matplotlib.pyplot as plt
# 匯入字型管理器,用於提供中文支援
import matplotlib.font_manager as fm
font_set= fm.FontProperties(fname='C:/Windows/Fonts/msyh.ttc', size=14)
# 將影像和目標標籤合併到一個列表中
images_and_labels = list(zip(digits.images, digits.target))
# 列印資料集的前8個影像
plt.figure(figsize=(8, 6))
for index, (image, label) in enumerate(images_and_labels[:8]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r,interpolation='nearest')
plt.title(u'訓練樣本:' + str(label), fontproperties=font_set)
plt.show()複製程式碼
# 樣本圖片效果
plt.figure(figsize=(6, 6))
plt.imshow(digits.images[0], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()複製程式碼
1.3 用 PCA 降維
由於該資料集有 64 個特徵值,也就是說有 64 個維度,因此沒辦法直觀地看到資料的分佈及其之間的關係。但是,實際起作用的維度可能位元徵值的個數要少得多,我們可以通過主成分分析來降低資料集的維度,從而觀察樣本點之間的關係。
主成分分析(PCA):找到兩個變數的線性組合,儘可能保留大部分的資訊,這個新的變數(主成分)就可以替代原來的變數。也就是說,PCA就是通過線性變換來產生新的變數,並最大化保留了資料的差異。
from sklearn.decomposition import *
# 建立一個 PCA 模型
pca = PCA(n_components=2)
# 將資料應用到模型上
reduced_data_pca = pca.fit_transform(digits.data)
# 檢視維度
print reduced_data_pca.shape複製程式碼
(1797, 2)複製程式碼
1.4 繪製散點圖
colors = ['black', 'blue', 'purple', 'yellow', 'white', 'red', 'lime', 'cyan', 'orange', 'gray']
plt.figure(figsize=(8, 6))
for i in range(len(colors)):
x = reduced_data_pca[:, 0][digits.target == i]
y = reduced_data_pca[:, 1][digits.target == i]
plt.scatter(x, y, c=colors[i])
plt.legend(digits.target_names, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.xlabel(u'第一個主成分', fontproperties=font_set)
plt.ylabel(u'第二個主成分', fontproperties=font_set)
plt.title(u"PCA 散點圖", fontproperties=font_set)
plt.show()複製程式碼
2.1 資料歸一化
from sklearn.preprocessing import scale
data = scale(digits.data)
print data複製程式碼
[[ 0. -0.33501649 -0.04308102 ..., -1.14664746 -0.5056698
-0.19600752]
[ 0. -0.33501649 -1.09493684 ..., 0.54856067 -0.5056698
-0.19600752]
[ 0. -0.33501649 -1.09493684 ..., 1.56568555 1.6951369
-0.19600752]
...,
[ 0. -0.33501649 -0.88456568 ..., -0.12952258 -0.5056698
-0.19600752]
[ 0. -0.33501649 -0.67419451 ..., 0.8876023 -0.5056698
-0.19600752]
[ 0. -0.33501649 1.00877481 ..., 0.8876023 -0.26113572
-0.19600752]]複製程式碼
2.2 拆分資料集
將資料集拆分成訓練集和測試集
from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test, images_train, images_test = train_test_split(data, digits.target, digits.images, test_size=0.25, random_state=42)
print "訓練集", X_train.shape
print "測試集", X_test.shape複製程式碼
訓練集 (1347, 64)
測試集 (450, 64)複製程式碼
2.3 使用 SVM 分類器
from sklearn import svm
# 建立 SVC 模型
svc_model = svm.SVC(gamma=0.001, C=100, kernel='linear')
# 將訓練集應用到 SVC 模型上
svc_model.fit(X_train, y_train)
# 評估模型的預測效果
print svc_model.score(X_test, y_test)複製程式碼
0.97777777777777775複製程式碼
2.4 優化引數
svc_model = svm.SVC(gamma=0.001, C=10, kernel='rbf')
svc_model.fit(X_train, y_train)
print svc_model.score(X_test, y_test)複製程式碼
0.98222222222222222複製程式碼
3.1 預測結果
import matplotlib.pyplot as plt
# 使用建立的 SVC 模型對測試集進行預測
predicted = svc_model.predict(X_test)
# 將測試集的影像與預測的標籤合併到一個列表中
images_and_predictions = list(zip(images_test, predicted))
# 列印前 4 個預測的影像和結果
plt.figure(figsize=(8, 2))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
plt.subplot(1, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title(u'預測結果: ' + str(prediction), fontproperties=font_set)
plt.show()複製程式碼
3.2 分析結果的準確性
X = np.arange(len(y_test))
# 生成比較列表,如果預測的結果正確,則對應位置為0,錯誤則為1
comp = [0 if y1 == y2 else 1 for y1, y2 in zip(y_test, predicted)]
plt.figure(figsize=(8, 6))
# 影像發生波動的地方,說明預測的結果有誤
plt.plot(X, comp)
plt.ylim(-1, 2)
plt.yticks([])
plt.show()
print "測試集數量:", len(y_test)
print "錯誤識別數:", sum(comp)
print "識別準確率:", 1 - float(sum(comp)) / len(y_test)複製程式碼
測試集數量: 450
錯誤識別數: 8
識別準確率: 0.982222222222複製程式碼
3.3 錯誤識別樣本分析
# 收集錯誤識別的樣本下標
wrong_index = []
for i, value in enumerate(comp):
if value: wrong_index.append(i)
# 輸出錯誤識別的樣本影像
plt.figure(figsize=(8, 6))
for plot_index, image_index in enumerate(wrong_index):
image = images_test[image_index]
plt.subplot(2, 4, plot_index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r,interpolation='nearest')
# 影像說明,8->9 表示正確值為8,被錯誤地識別成了9
info = "{right}->{wrong}".format(right=y_test[image_index], wrong=predicted[image_index])
plt.title(info, fontsize=16)
plt.show()複製程式碼
參考文章:Python Machine Learning: Scikit-Learn Tutorial (Article)