第三篇:基於K-近鄰分類演算法的手寫識別系統

穆晨發表於2017-01-19

前言

       本文將繼續講解K-近鄰演算法的專案例項 - 手寫識別系統。

       該系統在獲取使用者的手寫輸入後,判斷使用者寫的是什麼。

       為了突出核心,簡化細節,本示例系統中的輸入為32x32矩陣,分類結果也均為數字。但對於漢字或者別的分類情形原理都是一樣的。

       有了前面學習的基礎,下面直接進入專案開發步驟。

第一步:收集並準備資料

       在使用者主目錄的trainingDigits子目錄中,存放的是2000個樣本資料。

       每個樣本一個檔案,其中一部分如下所示:

      

       檔案命名格式為:

              分類標籤_標籤內序號

       如 0_20.txt 就表示該樣本是分類標籤為0的第20個特徵集。20就是個序號以區分標籤內不同檔案而已,沒其他意義。

       樣本資料都是32x32矩陣:

      

       對於這樣的二維資料,如何判斷樣本和目標物件的距離呢?首先想到的是可以將二維降到一維。

       當然也可以考慮去找找二維的距離求解方法。

       下面給出降維函式:

 1 # ==============================================
 2 # 輸入:
 3 #        訓練集檔名(含路徑)
 4 # 輸出:
 5 #        降維後的樣本資料(這裡一個檔案一份樣本資料)
 6 # ==============================================
 7 def img2vector(filename):
 8     '將32x32的矩陣轉換為1024一維向量'
 9     
10     # 初始化返回向量
11     returnVect = numpy.zeros((1,1024))
12     
13     # 開啟樣本資料檔案
14     fr = open(filename)
15     
16     # 降維處理
17     for i in range(32):
18         lineStr = fr.readline()
19         for j in range(32):
20             returnVect[0,32*i+j] = int(lineStr[j])
21             
22     return returnVect

第二步:測試演算法

       K臨近的分類函式程式碼在之前的文章K-近鄰分類演算法原理分析與程式碼實現中給出了,這裡直接呼叫:

# =================================================
# 輸入:
#
# 輸出:
#        對指定的測試集檔案,指定的訓練集資料進行K近鄰分類
#        並列印結果資訊
# =================================================
def handwritingClassTest():
    '手寫數字識別系統測試程式碼'
    
    # 分類列表
    hwLabels = []
    
    # 獲取所有訓練集檔名
    trainingFileList = os.listdir('/home/fangmeng/trainingDigits')
    
    # 定義訓練集結構體
    m = len(trainingFileList)
    trainingMat = numpy.zeros((m, 1024))
    
    for i in range(m):
        # 當前訓練集檔名
        filenameStr = trainingFileList[i]
        # 檔名(filenameStr去掉.txt字尾)
        fileStr = filenameStr.split('.')[0]
        # 分類標籤
        classNumStr = int(fileStr.split('_')[0])
        # 將分類標籤加入分類列表
        hwLabels.append(classNumStr)
        # 將當前訓練集檔案降維後加入到訓練集結構體
        trainingMat[i] = img2vector('/home/fangmeng/trainingDigits/%s' % filenameStr)
    
    # 獲取所有測試集檔名
    testFileList = os.listdir('/home/fangmeng/testDigits')
    # 錯誤分類記數
    errorCount = 0
    # 測試集檔案個數
    mTest = len(testFileList)
    
    print "錯誤的分類結果如下:"
    for i in range(mTest):
        # 當前測試集檔名
        fileNameStr = testFileList[i]
        # 檔名(filenameStr去掉.txt字尾)
        fileStr = fileNameStr.split('.')[0]
        # 分類標籤
        classNumStr = int(fileStr.split('_')[0])
        # 將當前測試集檔案降維
        vectorUnderTest = img2vector('/home/fangmeng/testDigits/%s' % fileNameStr)
        # 對當前測試檔案進行分類
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        
        if (classifierResult != classNumStr): 
            print "分類結果: %d, 實際結果: %d" % (classifierResult, classNumStr)
            errorCount += 1.0
            
    print "\n總錯誤數: %d" % errorCount
    print "\n總錯誤數: %f" % (errorCount/float(mTest))

       執行結果:

      

小結

1. K-鄰近演算法的本質是用來分類的,要從分類的思想去思考這個演算法的運用。

2. 再強調一次K-鄰近演算法是沒有訓練過程的,這點和以後學習的其他分類方法,比如決策樹對比後就更清楚了。

3. K-鄰近演算法的效率很低,不論是從時間還是空間上看(單就這個簡單專案都跑得很慢)。因此需要學習更多更優化的演算法。

4. 有興趣有時間可以考慮在hadoop/spark叢集下實現這個專案或使用該演算法的其他類似專案,定能大幅度提升效能。

相關文章