K - 近鄰演算法

小潘發表於2020-12-19

K-近鄰演算法(KNN)原理

K Nearest Neighbor演算法又叫KNN演算法,這個演算法是機器學習裡面一個比較經典的演算法,總體來說KNN演算法是相對比較容易理解的演算法

定義

如果一個樣本在特徵空間中的k個最相似(即特徵空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別。
距離公式
兩個樣本的距離

  • 曼哈頓距離(絕對值距離)
  • 歐氏距離
  • 明可夫斯基距離

K-近鄰演算法 API

  • sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')
    • n_neighborsint,可選(預設=5),k_neighbors查詢預設使用的鄰居數。
    • algorithm:{'auto','ball_tree','kd_tree','brute'},可選用於計算最近鄰居的演算法:'ball_tree'將會使用BallTree,'kd_tree'將使用KDTree'auto'將嘗試根據傳遞給fit方法的值來決定最合適的演算法。(不同實現方式影響效率)

案例:鳶尾花種類預測

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

if __name__ == '__main__':
    #  獲取資料
    iris = load_iris()

    # 劃分資料集
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)

    # 特徵工程:標準化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.fit_transform(x_test)     # 控制變數, 用同樣的引數進行標準化

    # KNN 演算法預估器
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)

    # 模型評估
    # 方法一:直接對比真實值和預測值
    y_predict = estimator.predict(x_test)
    print('y_predict:\n', y_predict)
    print('直接對比真實值和預測值:\n', y_test == y_predict)

    # 方法二:計算準確率
    score = estimator.score(x_test, y_test)
    print('準確率:\n', score)

結果分析

  • k 值取多大?有什麼影響?
    • k值取很小:容易受到異常點的影響
    • k值取很大:受到樣本均衡的問題
  • 效能問題?
    • 距離計算上面,時間複雜度高

K-近鄰總結

  • 優點:
    • 簡單,易於理解,易於實現,無需訓
  • 缺點:
    • 懶情演算法,對測試樣本分類時的量大, 記憶體開銷大
    • 必須指定K值,K值選擇不當則分類精度不能保證
  • 使用場景:小資料場景,幾千~幾萬樣本,具體場景具體業務去測試

相關文章