機器學習經典演算法之K-Means

程式設計師姜小白發表於2019-07-01

一、簡介

K-Means 是一種非監督學習,解決的是聚類問題。K 代表的是 K 類,Means 代表的是中心,你可以理解這個演算法的本質是確定 K 類的中心點,當你找到了這些中心點,也就完成了聚類。

/*請尊重作者勞動成果,轉載請標明原文連結:*/

/* https://www.cnblogs.com/jpcflyer/p/11117012.html * /

先請你和我思考一個場景,假設我有 20 支亞洲足球隊,想要將它們按照成績劃分成 3 個等級,可以怎樣劃分?

 

二、 K-Means 的工作原理

對亞洲足球隊的水平,你可能也有自己的判斷。比如一流的亞洲球隊有誰?你可能會說伊朗或韓國。二流的亞洲球隊呢?你可能說是中國。三流的亞洲球隊呢?你可能會說越南。

其實這些都是靠我們的經驗來劃分的,那麼伊朗、中國、越南可以說是三個等級的典型代表,也就是我們每個類的中心點。

所以回過頭來,如何確定 K 類的中心點?一開始我們是可以隨機指派的,當你確認了中心點後,就可以按照距離將其他足球隊劃分到不同的類別中。

這也就是 K-Means 的中心思想,就是這麼簡單直接。你可能會問:如果一開始,選擇一流球隊是中國,二流球隊是伊朗,三流球隊是韓國,中心點選擇錯了怎麼辦?其實不用擔心,K-Means 有自我糾正機制,在不斷的迭代過程中,會糾正中心點。中心點在整個迭代過程中,並不是唯一的,只是你需要一個初始值,一般演算法會隨機設定初始的中心點。

好了,那我來把 K-Means 的工作原理給你總結下:

選取 K 個點作為初始的類中心點,這些點一般都是從資料集中隨機抽取的;

將每個點分配到最近的類中心點,這樣就形成了 K 個類,然後重新計算每個類的中心點;

重複第二步,直到類不發生變化,或者你也可以設定最大迭代次數,這樣即使類中心點發生變化,但是隻要達到最大迭代次數就會結束。

 

三、 如何給亞洲球隊做聚類

對於機器來說需要資料才能判斷類中心點,所以我整理了 2015-2019 年亞洲球隊的排名,如下表所示。

我來說明一下資料概況。

其中 2019 年國際足聯的世界排名,2015 年亞洲盃排名均為實際排名。2018 年世界盃中,很多球隊沒有進入到決賽圈,所以只有進入到決賽圈的球隊才有實際的排名。如果是亞洲區預選賽 12 強的球隊,排名會設定為 40。如果沒有進入亞洲區預選賽 12 強,球隊排名會設定為 50。

針對上面的排名,我們首先需要做的是資料規範化。 我先把數值都規範化到 [0,1] 的空間中,得到了以下的數值表:

如果我們隨機選取中國、日本、韓國為三個類的中心點,我們就需要看下這些球隊到中心點的距離。

距離有多種計算的方式,有關距離的計算我在 KNN 演算法中也講到過。 歐氏距離是最常用的距離計算方式,這裡我選擇歐氏距離作為距離的標準,計算每個隊伍分別到中國、日本、韓國的距離,然後根據距離遠近來劃分。我們看到大部分的隊,會和中國隊聚類到一起。這裡我整理了距離的計算過程,比如中國和中國的歐氏距離為 0,中國和日本的歐式距離為 0.732003。如果按照中國、日本、韓國為 3 個分類的中心點,歐氏距離的計算結果如下表所示:

然後我們再重新計算這三個類的中心點,如何計算呢?最簡單的方式就是取平均值,然後根據新的中心點按照距離遠近重新分配球隊的分類,再根據球隊的分類更新中心點的位置。計算過程這裡不展開,最後一直迭代(重複上述的計算過程:計算中心點和劃分分類)到分類不再發生變化,可以得到以下的分類結果:

所以我們能看出來第一梯隊有日本、韓國、伊朗、沙特、澳洲;第二梯隊有中國、伊拉克、阿聯酋、烏茲別克;第三梯隊有卡達、泰國、越南、阿曼、巴林、朝鮮、印尼、敘利亞、約旦、科威特和巴勒斯坦。

 

四、 如何使用 sklearn 中的 K-Means 演算法

sklearn 是 Python 的機器學習工具庫,如果從功能上來劃分,sklearn 可以實現分類、聚類、迴歸、降維、模型選擇和預處理等功能。這裡我們使用的是 sklearn 的聚類函式庫,因此需要引用工具包,具體程式碼如下:

1 from sklearn.cluster import KMeans

當然 K-Means 只是 sklearn.cluster 中的一個聚類庫,實際上包括 K-Means 在內,sklearn.cluster 一共提供了 9 種聚類方法,比如 Mean-shift,DBSCAN,Spectral clustering(譜聚類)等。這些聚類方法的原理和 K-Means 不同,這裡不做介紹。

我們看下 K-Means 如何建立:

1 KMeans(n_clusters=8, init='k-means++', n_init=10, max_iter=300, tol=0.0001, precompute_distances='auto', verbose=0, random_state=None, copy_x=True, n_jobs=1, algorithm='auto')

我們能看到在 K-Means 類建立的過程中,有一些主要的引數:

n_clusters : 即 K 值,一般需要多試一些 K 值來保證更好的聚類效果。你可以隨機設定一些 K 值,然後選擇聚類效果最好的作為最終的 K 值;

max_iter : 最大迭代次數,如果聚類很難收斂的話,設定最大迭代次數可以讓我們及時得到反饋結果,否則程式執行時間會非常長;

n_init :初始化中心點的運算次數,預設是 10。程式是否能快速收斂和中心點的選擇關係非常大,所以在中心點選擇上多花一些時間,來爭取整體時間上的快速收斂還是非常值得的。由於每一次中心點都是隨機生成的,這樣得到的結果就有好有壞,非常不確定,所以要執行 n_init 次, 取其中最好的作為初始的中心點。如果 K 值比較大的時候,你可以適當增大 n_init 這個值;

algorithm :k-means 的實現演算法,有“auto” “full”“elkan”三種。一般來說建議直接用預設的"auto"。簡單說下這三個取值的區別,如果你選擇"full"採用的是傳統的 K-Means 演算法,“auto”會根據資料的特點自動選擇是選擇“full”還是“elkan”。我們一般選擇預設的取值,即“auto” 。

 

在建立好 K-Means 類之後,就可以使用它的方法,最常用的是 fit 和 predict 這個兩個函式。你可以單獨使用 fit 函式和 predict 函式,也可以合併使用 fit_predict 函式。其中 fit(data) 可以對 data 資料進行 k-Means 聚類。 predict(data) 可以針對 data 中的每個樣本,計算最近的類。

現在我們要完整地跑一遍 20 支亞洲球隊的聚類問題。

 1 # coding: utf-8
 2 
 3 from sklearn.cluster import KMeans
 4 
 5 from sklearn import preprocessing
 6 
 7 import pandas as pd
 8 
 9 import numpy as np
10 
11 # 輸入資料
12 
13 data = pd.read_csv('data.csv', encoding='gbk')
14 
15 train_x = data[["2019 年國際排名 ","2018 世界盃 ","2015 亞洲盃 "]]
16 
17 df = pd.DataFrame(train_x)
18 
19 kmeans = KMeans(n_clusters=3)
20 
21 # 規範化到 [0,1] 空間
22 
23 min_max_scaler=preprocessing.MinMaxScaler()
24 
25 train_x=min_max_scaler.fit_transform(train_x)
26 
27 # kmeans 演算法
28 
29 kmeans.fit(train_x)
30 
31 predict_y = kmeans.predict(train_x)
32 
33 # 合併聚類結果,插入到原資料中
34 
35 result = pd.concat((data,pd.DataFrame(predict_y)),axis=1)
36 
37 result.rename({0:u'聚類'},axis=1,inplace=True)
38 
39 print(result)

執行結果:

 1 國家  2019 年國際排名  2018 世界盃  2015 亞洲盃  聚類
 2 
 3 中國         73       40        7   2
 4 
 5 日本         60       15        5   0
 6 
 7 韓國         61       19        2   0
 8 
 9 伊朗         34       18        6   0
10 
11 沙特         67       26       10   0
12 
13 伊拉克         91       40        4   2
14 
15 卡達        101       40       13   1
16 
17 阿聯酋         81       40        6   2
18 
19 烏茲別克         88       40        8   2
20 
21 泰國        122       40       17   1
22 
23 越南        102       50       17   1
24 
25  阿曼         87       50       12   1
26 
27 巴林        116       50       11   1
28 
29 朝鮮        110       50       14   1
30 
31 印尼        164       50       17   1
32 
33 澳洲         40       30        1   0
34 
35 敘利亞         76       40       17   1
36 
37 約旦        118       50        9   1
38 
39  科威特        160       50       15   1
40 
41 巴勒斯坦         96       50       16   1

搜尋關注微信公眾號“程式設計師姜小白”,獲取更新精彩內容哦。

 

相關文章