一.瞭解什麼是K-NN演算法
1.KNN演算法原理
KNN(K-Nearest Neighbor)演算法是機器學習演算法中最基礎、最簡單的演算法之一。它既能用於分類,也能用於迴歸。KNN透過測量不同特徵值之間的距離來進行分類。
KNN演算法的思想非常簡單:對於任意n維輸入向量,分別對應於特徵空間中的一個點,輸出為該特徵向量所對應的類別標籤或預測值。
KNN演算法是一種非常特別的機器學習演算法,因為它沒有一般意義上的學習過程。它的工作原理是利用訓練資料對特徵向量空間進行劃分,並將劃分結果作為最終演算法模型。存在一個樣本資料集合,也稱作訓練樣本集,並且樣本集中的每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類的對應關係。
輸入沒有標籤的資料後,將這個沒有標籤的資料的每個特徵與樣本集中的資料對應的特徵進行比較,然後提取樣本中特徵最相近的資料(最近鄰)的分類標籤。
2.KNN演算法三要素:k值的選取,距離度量的方式和分類決策規則。
3.K值的選擇方法:
一種常見的方法是從k等於訓練集中樣本數量的平方根開始。比如訓練集中有100個案例,則k可以從10開始進行進行進一步篩選。
另一種方法是基於各種測試資料測試多個k值,並選擇一個可以提供最好分類效能的k值。除非資料的噪聲非常大,否則大的訓練集可以使k值的選擇不那麼重要。
還有一種方法是選擇一個較大的k值,同時用一個權重投票,在這個過程中,認為較近鄰的投票比遠的投票權重更大。
4.距離度量的方式
距離度量的方式有三種:歐式距離、曼哈頓距離、閔可夫斯基距離。
海倫一直使用線上約會網站尋找適合自己的約會物件。她曾交往過三種型別的人:
-
- 不喜歡的人
- 一般喜歡的人
- 非常喜歡的人
這些人包含以下三種特徵
-
- 每年獲得的飛行常客里程數
- 玩影片遊戲所耗時間百分比
- 每週消費的冰淇淋公升數
該網站現在需要儘可能向海倫推薦她喜歡的人,需要我們設計一個分類器,根據使用者的以上三種特徵,識別出是否該向海倫推薦。
from numpy import * import operator def createDataSet(): group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]]) labels = ['A', 'A', 'B', 'B'] return group, labels group, labels = createDataSet() print(group) print(labels) import matplotlib.pyplot as plt def createDataSet(): group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]]) labels = ['A', 'A', 'B', 'B'] return group, labels def plotDataSet(group, labels): label_dict = {'A': 'red', 'B': 'blue'} colors = [label_dict[label] for label in labels] fig, ax = plt.subplots() ax.scatter(group[:, 0], group[:, 1], c=colors) ax.set_xlabel('x') ax.set_ylabel('y') ax.set_title('Data Set') plt.show() group, labels = createDataSet() plotDataSet(group, labels)
def classify0(inX, dataSet, labels, k): #numpy函式shape[0]返回dataSet的行數 dataSetSize = dataSet.shape[0] #在列向量方向上重複inX共1次(橫向),行向量方向上重複inX共dataSetSize次(縱向) diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet #二維特徵相減後平方 sqDiffMat = diffMat**2 #sum()所有元素相加,sum(0)列相加,sum(1)行相加 sqDistances = sqDiffMat.sum(axis=1) #開方,計算出距離 distances = sqDistances**0.5 #返回distances中元素從小到大排序後的索引值 sortedDistIndices = distances.argsort() #定一個記錄類別次數的字典 classCount = {} for i in range(k): #取出前k個元素的類別 voteIlabel = labels[sortedDistIndices[i]] #dict.get(key,default=None),字典的get()方法,返回指定鍵的值,如果值不在字典中返回預設值。 #計算類別次數 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) # import pdb # pdb.set_trace() #返回次數最多的類別,即所要分類的類別 return sortedClassCount[0][0]
三.總結
經過對KNN演算法的學習後,我學會了一種用於分類和迴歸的統計方法,KNN演算法是一種簡單而直觀的分類演算法,其基本思想是透過找到與待分類樣本最近的K個已知類別樣本來確定其類別。其主要優點是實現簡單、可解釋性強,並且適用於多分類問題和非線性分類問題。但是,KNN演算法的缺點也很明顯,主要包括計算複雜度高、需要大量的儲存空間等。世界上沒有完美的演算法,每個演算法都有優有劣,但KNN演算法較為簡單,對於我們剛入門的新手來說比較適合,KNN演算法雖然有一些侷限性和缺點,但是在實際應用中仍然具有很大的價值和意義。理解KNN演算法的基本思想,掌握其實現方法和最佳化技巧,可以對我們進一步學習和應用其他機器學習演算法提供很好的基礎。
參考文獻:KNN演算法原理-CSDN部落格