前言
本文將繼續講解K-近鄰演算法的專案例項 - 手寫識別系統。
該系統在獲取使用者的手寫輸入後,判斷使用者寫的是什麼。
為了突出核心,簡化細節,本示例系統中的輸入為32x32矩陣,分類結果也均為數字。但對於漢字或者別的分類情形原理都是一樣的。
有了前面學習的基礎,下面直接進入專案開發步驟。
第一步:收集並準備資料
在使用者主目錄的trainingDigits子目錄中,存放的是2000個樣本資料。
每個樣本一個檔案,其中一部分如下所示:
檔案命名格式為:
分類標籤_標籤內序號
如 0_20.txt 就表示該樣本是分類標籤為0的第20個特徵集。20就是個序號以區分標籤內不同檔案而已,沒其他意義。
樣本資料都是32x32矩陣:
對於這樣的二維資料,如何判斷樣本和目標物件的距離呢?首先想到的是可以將二維降到一維。
當然也可以考慮去找找二維的距離求解方法。
下面給出降維函式:
1 # ============================================== 2 # 輸入: 3 # 訓練集檔名(含路徑) 4 # 輸出: 5 # 降維後的樣本資料(這裡一個檔案一份樣本資料) 6 # ============================================== 7 def img2vector(filename): 8 '將32x32的矩陣轉換為1024一維向量' 9 10 # 初始化返回向量 11 returnVect = numpy.zeros((1,1024)) 12 13 # 開啟樣本資料檔案 14 fr = open(filename) 15 16 # 降維處理 17 for i in range(32): 18 lineStr = fr.readline() 19 for j in range(32): 20 returnVect[0,32*i+j] = int(lineStr[j]) 21 22 return returnVect
第二步:測試演算法
K臨近的分類函式程式碼在之前的文章K-近鄰分類演算法原理分析與程式碼實現中給出了,這裡直接呼叫:
# ================================================= # 輸入: # 空 # 輸出: # 對指定的測試集檔案,指定的訓練集資料進行K近鄰分類 # 並列印結果資訊 # ================================================= def handwritingClassTest(): '手寫數字識別系統測試程式碼' # 分類列表 hwLabels = [] # 獲取所有訓練集檔名 trainingFileList = os.listdir('/home/fangmeng/trainingDigits') # 定義訓練集結構體 m = len(trainingFileList) trainingMat = numpy.zeros((m, 1024)) for i in range(m): # 當前訓練集檔名 filenameStr = trainingFileList[i] # 檔名(filenameStr去掉.txt字尾) fileStr = filenameStr.split('.')[0] # 分類標籤 classNumStr = int(fileStr.split('_')[0]) # 將分類標籤加入分類列表 hwLabels.append(classNumStr) # 將當前訓練集檔案降維後加入到訓練集結構體 trainingMat[i] = img2vector('/home/fangmeng/trainingDigits/%s' % filenameStr) # 獲取所有測試集檔名 testFileList = os.listdir('/home/fangmeng/testDigits') # 錯誤分類記數 errorCount = 0 # 測試集檔案個數 mTest = len(testFileList) print "錯誤的分類結果如下:" for i in range(mTest): # 當前測試集檔名 fileNameStr = testFileList[i] # 檔名(filenameStr去掉.txt字尾) fileStr = fileNameStr.split('.')[0] # 分類標籤 classNumStr = int(fileStr.split('_')[0]) # 將當前測試集檔案降維 vectorUnderTest = img2vector('/home/fangmeng/testDigits/%s' % fileNameStr) # 對當前測試檔案進行分類 classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) if (classifierResult != classNumStr): print "分類結果: %d, 實際結果: %d" % (classifierResult, classNumStr) errorCount += 1.0 print "\n總錯誤數: %d" % errorCount print "\n總錯誤數: %f" % (errorCount/float(mTest))
執行結果:
小結
1. K-鄰近演算法的本質是用來分類的,要從分類的思想去思考這個演算法的運用。
2. 再強調一次K-鄰近演算法是沒有訓練過程的,這點和以後學習的其他分類方法,比如決策樹對比後就更清楚了。
3. K-鄰近演算法的效率很低,不論是從時間還是空間上看(單就這個簡單專案都跑得很慢)。因此需要學習更多更優化的演算法。
4. 有興趣有時間可以考慮在hadoop/spark叢集下實現這個專案或使用該演算法的其他類似專案,定能大幅度提升效能。