'''
Created on Nov 06, 2017
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN or length N)
            dataSet: m known vectors, one per row (mxN)
            labels: data set labels (length-m sequence, labels[i] belongs to dataSet[i])
            k: number of neighbors to use for comparison (should be an odd number)

Output:     the most popular class label

Also contains two demos built on classify0: a dating-site classifier
(datingClassTest / classifyPerson) and a handwritten-digit recognizer
(handwritingClassTest). Run this file as a script to execute the demos.

@author: Liu Chuanfeng
'''
import operator
import os
from os import listdir

import numpy as np

# Default locations of the handwritten-digit data used by handwritingClassTest().
# Raw strings: the original literals contained an invalid "\d" escape sequence
# (DeprecationWarning/SyntaxWarning); these raw strings denote the same paths.
TRAINING_DIGITS_DIR = r'C:\Private\PycharmProjects\Algorithm\kNN\digits\traingDigits'
TEST_DIGITS_DIR = r'C:\Private\PycharmProjects\Algorithm\kNN\digits\testDigits'


def classify0(inX, dataSet, labels, k):
    """Classify inX by majority vote among its k nearest rows of dataSet.

    Distance is Euclidean. If several labels tie in the vote, the one that
    appears first among the neighbors ordered by distance wins.

    :param inX: query vector, shape (N,) or (1, N)
    :param dataSet: known vectors, shape (m, N)
    :param labels: sequence of m labels; labels[i] is the label of dataSet[i]
    :param k: number of neighbors to consult; values above m are clamped to m
    :return: the winning label
    :raises ValueError: if dataSet has no rows or k < 1
    """
    dataSetSize = dataSet.shape[0]
    if dataSetSize == 0:
        raise ValueError('dataSet must contain at least one row')
    if k < 1:
        raise ValueError('k must be at least 1')
    k = min(k, dataSetSize)  # avoid IndexError when asking for more neighbors than exist

    # Broadcasting subtracts inX from every row of dataSet.
    diffMat = dataSet - inX
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
    sortedDistIndicies = distances.argsort()

    classCount = {}
    for idx in sortedDistIndicies[:k]:
        voteIlabel = labels[idx]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # max() keeps the first maximal entry, i.e. the nearest label on a tie
    # (dicts preserve insertion order).
    return max(classCount.items(), key=operator.itemgetter(1))[0]


def file2matrix(filename):
    """Parse a tab-separated data file into a feature matrix and a label list.

    Each non-blank line must hold at least four tab-separated fields: the first
    three are numeric features and the last is an integer class label. Blank
    lines (e.g. a trailing empty line) are skipped.

    :param filename: path of the text file to read
    :return: (returnMat, classLabelVector) where returnMat has shape (n, 3)
             and classLabelVector is a list of n ints
    """
    with open(filename) as fr:
        rows = [line.strip().split('\t') for line in fr if line.strip()]
    returnMat = np.zeros((len(rows), 3))
    classLabelVector = []
    for index, listFromLine in enumerate(rows):
        returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max normalize each column of dataSet into [0, 1].

    Columns can have very different value ranges, which would let the widest
    column dominate the distance computation; scaling gives every column the
    same weight.

    :param dataSet: array of shape (m, N)
    :return: (normDataSet, ranges, minVals). New samples are normalized the
             same way with (x - minVals) / ranges. For a constant column the
             returned range is 1 instead of 0, so that column maps to 0
             rather than NaN.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    ranges = np.where(ranges == 0, 1, ranges)  # guard against division by zero
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals


def datingClassTest(filename='datingTestSet2.txt', hoRatio=0.10, k=3):
    """Estimate the dating-site classifier's error rate with a hold-out split.

    The whole file is normalized with autoNorm, then the first
    int(m * hoRatio) rows are classified against the remaining rows.
    Each prediction and the overall error rate are printed.

    :param filename: tab-separated data file (see file2matrix)
    :param hoRatio: fraction of rows held out for testing
    :param k: number of neighbors for classify0
    :return: error rate as a fraction in [0, 1]
    :raises ValueError: if the hold-out set would be empty
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0:
        raise ValueError('hold-out set is empty (m=%d, hoRatio=%s)' % (m, hoRatio))

    trainMat = normMat[numTestVecs:m, :]
    trainLabels = datingLabels[numTestVecs:m]
    errorCount = 0.0
    for i in range(numTestVecs):
        classifyResult = classify0(normMat[i, :], trainMat, trainLabels, k)
        print('the classifier came back with: %d, the real answer is: %d' % (classifyResult, datingLabels[i]))
        if classifyResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / float(numTestVecs)
    print('the total error rate is: %.1f%%' % (errorRate * 100))
    return errorRate


def classifyPerson(filename='datingTestSet2.txt', k=3):
    """Interactively predict how much the user would like a person.

    Reads three feature values from stdin, normalizes them with the ranges of
    the training file and prints the predicted category.

    :param filename: tab-separated training file (see file2matrix)
    :param k: number of neighbors for classify0
    :return: the predicted category string
    """
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input("percentage of time spent playing video games?"))
    ffMiles = float(input("frequent flier miles earned per year?"))
    iceCream = float(input("liters of ice cream consumed per year?"))
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    # NOTE(review): assumes the data file stores columns as
    # (flier miles, game-time %, ice cream); this array must use that order.
    inArr = np.array([ffMiles, percentTats, iceCream])
    classifyResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, k)
    # Labels in the file are 1-based (1..3); resultList is 0-based.
    result = resultList[classifyResult - 1]
    print("You will probably like this person:", result)
    return result


def showDatingScatter(datingDataMat, datingLabels):
    """Scatter-plot columns 1 and 2 of the dating data, sized and colored by label.

    matplotlib is imported here so the rest of the module works without it.
    """
    import matplotlib.pyplot as plt
    fig = plt.figure()
    ax = fig.add_subplot(111)
    labelArr = 15.0 * np.array(datingLabels)
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], labelArr, labelArr)
    plt.show()


# Handwritten-digit recognition =============================================
def img2vector(filename):
    """Read a 32x32 text image of '0'/'1' characters into a (1, 1024) vector.

    Row i of the image occupies entries 32*i .. 32*i+31 of the result.
    """
    returnVect = np.zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def _digitLabel(fileNameStr):
    """Return the digit encoded in a file name such as '7_45.txt' (-> 7)."""
    fileName = fileNameStr.split('.')[0]
    return int(fileName.split('_')[0])


def handwritingClassTest(trainingDir=TRAINING_DIGITS_DIR, testDir=TEST_DIGITS_DIR, k=3):
    """Classify every digit image in testDir against those in trainingDir.

    Every file in both directories is read with img2vector; its label is the
    number before the first '_' in the file name. Each prediction, the error
    count and the error rate are printed.

    :param trainingDir: directory of training images
    :param testDir: directory of test images
    :param k: number of neighbors for classify0
    :return: error rate as a fraction in [0, 1]
    :raises ValueError: if testDir contains no files
    """
    trainingFileList = listdir(trainingDir)
    m = len(trainingFileList)
    trainingMat = np.zeros((m, 1024))
    hwLabels = []
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_digitLabel(fileNameStr))
        # Reshape each 32x32 image into one 1x1024 row.
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    testFileList = listdir(testDir)
    mTest = len(testFileList)
    if mTest == 0:
        raise ValueError('no test files found in %s' % testDir)
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumber = _digitLabel(fileNameStr)
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        # Classify by Euclidean distance to the training images.
        classifyResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print('The classifier came back with: %d, the real answer is: %d' % (classifyResult, classNumber))
        if classifyResult != classNumber:
            errorCount += 1.0
    print('The total number of errors is: %d' % (errorCount))
    errorRate = errorCount / float(mTest)
    print('The total error rate is: %.1f%%' % (errorRate * 100))
    return errorRate


if __name__ == '__main__':
    # Dating-site classifier test
    datingClassTest()

    # Dating-site prediction (interactive)
    classifyPerson()

    # Handwritten-digit recognition test
    handwritingClassTest()
theclassifier came back with: 3, the real answer is: 3
the total error rate is: 0.0%
theclassifier came back with: 2, the real answer is: 2
the total error rate is: 0.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 0.0%
theclassifier came back with: 2, the real answer is: 2
the total error rate is: 4.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 4.0%
theclassifier came back with: 3, the real answer is: 1
the total error rate is: 5.0%
percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice cream consumed per year?0.5
You will probably like this persoon: in small doses
The classifier came back with: 9, the real answer is: 9
The total number of errors is: 27
The total error rate is: 6.8%