educoder 機器學習 --- kNN演算法

kafuuchino 發表於 2024-07-01

第一關:

#encoding=utf8
import numpy as np

from collections import Counter

class kNNClassifier(object):
    '''
    Minimal k-nearest-neighbours classifier: Euclidean distance plus
    majority vote among the k closest training samples.
    '''

    def __init__(self, k):
        '''
        Initializer.
        :param k: number of nearest neighbours that vote on each prediction (k in kNN)
        '''
        self.k = k
        # Training samples, stored as an ndarray (set by fit)
        self.train_feature = None
        # Training labels, stored as an ndarray (set by fit)
        self.train_label = None


    def fit(self, feature, label):
        '''
        Training step of kNN: the model simply memorizes the training set.
        :param feature: training samples, ndarray (2-D: one row per sample)
        :param label: training labels, ndarray (one label per row of feature)
        :return: None
        '''

        #********* Begin *********#
        # np.asarray is a no-op for ndarrays. It also lets plain lists work and
        # keeps label lookups in predict positional (a pandas Series, for
        # example, would otherwise be indexed by its index labels).
        self.train_feature = np.asarray(feature)
        self.train_label = np.asarray(label)
        #********* End *********#


    def predict(self, feature):
        '''
        Prediction step of kNN.

        For each query sample, the k training samples closest in Euclidean
        distance vote with their labels and the most frequent label wins.
        If several labels receive the same number of votes, the label whose
        voting neighbours have the smallest total distance wins. If that also
        ties, the label whose nearest neighbour is closest wins.
        :param feature: query samples, ndarray; each row is one sample
        :return: predicted labels, list with one entry per query sample
        :raises ValueError: if k is smaller than 1
        '''

        #********* Begin *********#
        if self.k < 1:
            raise ValueError('k must be at least 1')
        result = []
        for sample in feature:
            # Euclidean distance from this sample to every training sample
            dist = np.sqrt(np.sum((self.train_feature - sample) ** 2, axis=1))
            # Stable sort: equally distant training samples keep their original order
            neighbors = np.argsort(dist, kind='stable')[:self.k]
            votes = Counter()
            total_dist = {}
            for idx in neighbors:
                lbl = self.train_label[idx]
                votes[lbl] += 1
                total_dist[lbl] = total_dist.get(lbl, 0.0) + dist[idx]
            # More votes wins; among equal votes, a smaller total distance wins.
            # max() keeps the first maximum, and `votes` iterates in insertion
            # (nearest-first) order, so any remaining tie goes to the label
            # that appeared first among the neighbours.
            best = max(votes, key=lambda c: (votes[c], -total_dist[c]))
            result.append(best)
        return result
        #********* End *********#

第二關:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

def classification(train_feature, train_label, test_feature):
    '''
    Classify the wine samples in test_feature with a k-nearest-neighbours model.

    kNN compares samples by distance, so all features are put on a common
    scale with StandardScaler before training and prediction.
    :param train_feature: training samples, ndarray
    :param train_label: training labels, ndarray
    :param test_feature: samples to classify, ndarray
    :return: predicted labels for test_feature
    '''

    #********* Begin *********#
    # Learn the scaling statistics from the training data only, then apply
    # that same transform to the training set and the test set (in that order).
    scaler = StandardScaler().fit(train_feature)
    scaled_train, scaled_test = (scaler.transform(x) for x in (train_feature, test_feature))

    # Default KNeighborsClassifier settings; fit() returns the estimator itself.
    model = KNeighborsClassifier().fit(scaled_train, train_label)
    return model.predict(scaled_test)
    #********* End **********#

相關文章