這篇部落格整理K均值聚類的內容,包括:
1、K均值聚類的原理;
2、初始類中心的選擇和類別數K的確定;
3、K均值聚類和EM演算法、高斯混合模型的關係。
一、K均值聚類的原理
K均值聚類(K-means)是一種基於中心的聚類演算法,通過迭代,將樣本分到K個類中,使得每個樣本與其所屬類的中心或均值的距離之和最小。
1、定義損失函式
假設我們有一個資料集{x1, x2,..., xN},每個樣本的特徵維度是m維,我們的目標是將資料集劃分為K個類別。假定K的值已經給定,那麼第k個類別的中心定義為μk,k=1,2,..., K,μk是一個m維的特徵向量。我們需要找到每個樣本所屬的類別,以及一組向量{μk},使得每個樣本與它所屬的類別的中心μk的距離平方和最小。
首先,這個距離是什麼距離呢?聚類需要根據樣本之間的相似度,對樣本集合進行劃分,將相似度較高的樣本歸為一類。度量樣本之間相似度的方法包括計算樣本之間的歐氏距離、馬氏距離、餘弦距離或相關係數,而K均值聚類是用歐氏距離的平方來度量樣本之間的相似度。歐式距離的平方公式如下:
把所有樣本與所屬類的中心之間距離的平方之和定義為損失函式:
其中rnk∈{0,1},n=1,2,...,N,k=1,2,...,K,如果rnk=1,那麼表示樣本xn屬於第k類,且對於j≠k,有rnj=0,也就是樣本xn只能屬於一個類別。
於是我們需要找到{rnk}和{μk}的值,使得距離平方之和J最小化。
2、進行迭代
K均值聚類的演算法是一種迭代演算法,每次迭代涉及到兩個連續的步驟,分別對應rnk的最優化和μk的最優化,也對應著EM演算法的E步(求期望)和M步(求極大)兩步。
首先,為μk選擇一些初始值,也就是選擇K個類中心。然後第一步:保持μk固定,選擇rnk來最小化J,也就是把樣本指派到與其最近的中心所屬的類中,得到一個聚類結果;第二步:保持rnk固定,計算μk來最小化J,也就是更新每個類別的中心。不斷重複這兩個步驟直到收斂。
具體來說:
E步:在類中心μk已經確定的情況下,最優化rnk
這一步比較簡單,我們可以對每個樣本xn獨立地進行最優化。將某個樣本分配到第k個類別,如果這個樣本和第k個類別的距離最小,那麼令rnk=1。對N個樣本都這樣進行分配,自然就得到了使所有樣本與類中心的距離平方和最小的{rnk},從而得到了一個聚類結果。
M步:確定了資料集的一種劃分,也就是{rnk}確定後,最優化μk
目標函式J是μk的一個二次函式,令它關於μk的導數為零,就可以使目標函式達到最小值,即
解出μk的結果為:
這個μk就是類別k中所有樣本的均值,所以把這個演算法稱為K均值聚類。
重新為樣本分配類別,再重新計算每個類別的均值,不斷重複這兩個步驟,直到聚類的結果不再改變。
需要注意以下幾點:
①K均值聚類演算法可能收斂到目標函式J的區域性極小值,不能保證收斂到全域性最小值;
②在聚類之前,需要對資料集進行標準化,使得每個樣本的均值為0,標準差為1;
③初始類中心的選擇會直接影響聚類結果,選擇不同的初始類中心,可能會得到不同的聚類結果。
④K均值聚類演算法的複雜度是O(mnk),m是樣本的特徵維度,n是樣本個數,k是類別個數。
二、初始類中心的選擇和類別數K的確定
K均值聚類演算法的思想比較簡單,不涉及到什麼數學知識,關鍵點在於初始類中心的選擇和類別數K的確定,這對聚類的結果有比較大的影響。
(一)初始類中心的選擇
1、第一種方法
用層次聚類演算法進行初始聚類,然後用這些類別的中心作為K均值聚類的初始類中心。層次聚類的複雜度為O(mn3),m是樣本的特徵維度,n是樣本個數,複雜度也是蠻高的,那為什麼用層次聚類的結果作為初始類呢?我想是因為層次聚類的結果完全是由演算法確定的,完全沒有人工的干預,是一個客觀的結果,這樣就把K均值聚類的初始類選擇問題,由主觀確定變成了客觀決定。
2、第二種方法
首先隨機選擇一個點作為第一個初始類中心點,然後計算該點與其他所有樣本點的距離,選擇距離最遠的點作為第二個初始類的中心點,以此類推,直到選出K個初始類中心點。
(二)類別數K的確定
1、輪廓係數
輪廓係數(Silhouette Coefficient)可以用來判定聚類結果的好壞,也可以用來確定類別數K。好的聚類要保證類別內部樣本之間的距離儘可能小(密集度),而類與類之間樣本的距離儘可能大(離散度),輪廓係數就是一個用來度量類的密集度和離散度的綜合指標。
輪廓係數的計算過程和使用如下:
①計算樣本xi到同類Ck其他樣本的平均距離ai,將ai稱為樣本xi的簇內不相似度,ai越小,說明樣本xi越應該被分配到該類。
②計算樣本xi到其他類Cj所有樣本的平均距離bij,j=1, 2 ,..., K,j≠k,稱為樣本xi與類別Cj的不相似度。定義樣本xi的簇間不相似度:bi =min{bi1, bi2, ...,bij,..., biK},j≠k,bi越大,說明樣本xi越不屬於其他簇。
③根據樣本xi的簇內不相似度ai和簇間不相似度bi,定義樣本xi的輪廓係數si,作為樣本xi分類結果的合理性的度量。
④輪廓係數範圍在[-1,1]之間,該值越大,聚類結果越好。si接近1,則樣本xi被分配到類別Ck的結果比較合理;si接近0,說明樣本xi在兩個類的邊界上;si接近-1,說明樣本xi更應該被分配到其他類別。
⑤計算所有樣本的輪廓係數si的均值,得到聚類結果的輪廓係數S,作為聚類結果合理性的度量。輪廓係數越大,聚類結果越好。
⑥使用不同的K值進行K均值聚類,計算各自聚類結果的輪廓係數S,選擇較大的輪廓係數所對應的K值。
2、肘部法則
三、K均值聚類與高斯混合模型的關係
關於K均值聚類與EM演算法、高斯混合模型的關係,主要有以下三點:
1、K均值聚類是一種非概率的聚類演算法,屬於硬聚類方法,也就是一個樣本只能屬於一個類(類與類之間的交集為空)。相比之下,高斯混合模型(GMM)是一種基於概率的聚類演算法,屬於軟聚類方法,每個樣本按照一個概率分佈,屬於多個類。
2、K均值聚類在一次迭代中的兩個步驟,可以看做是EM演算法的E步和M步,而且K均值聚類可以看做是用EM演算法對⾼斯混合模型進行引數估計的⼀個特例,也就是高斯混合模型中分模型的方差σ2相等,為常數,且σ2→0時的極限情況。
3、K均值聚類和基於EM演算法的高斯混合模型,對引數的初始化值比較敏感,由於K均值聚類的計算量遠小於基於EM演算法的高斯混合模型,所以通常運⾏K均值演算法找到⾼斯混合模型的⼀個初始化值,再使用EM演算法進行調節。具體而言,用K均值聚類劃分的K個類別中,各類別中樣本所佔的比例,來初始化K個分模型的權重;用各類別中樣本的均值來初始化K個高斯分佈的期望;用各類別中樣本的方差來初始化K個高斯分佈的方差。
(一)從圖形來理解
為了理解以上這幾點,尤其是第2點,我們可以先從圖形來看。假設高斯混合模型由4個高斯分佈混合而成,高斯分佈的密度函式如下。這裡和《聚類之高斯混合模型與EM演算法》的符號表示一致,y為樣本。
令均值μ=[0,2,4,6],方差σ2=[1,2,3,1],則4個高斯分佈的概率密度函式的圖形如下。我們可以看到,4個圖形之間有重疊的部分,也就說明每個樣本可以按照一個概率分佈αk,屬於多個類,只是屬於某類的概率大些,屬於其他類的概率小些。這表明高斯混合模型是一種軟聚類方法。
然後令均值μ不變,方差σ2=[0.01, 0.01, 0.01, 0.01],也就是4個分模型的方差σk2相等,而且σk2→0,那麼4個高斯分佈的圖形如下。每個高斯分佈的圖形之間沒有交集,那麼每個樣本只能屬於一個類,變成了硬聚類。這也就是高斯混合模型的特例:K均值聚類。
(二)從公式來理解
用EM演算法對高斯混合模型進行極大似然估計,在E步,我們需要基於第i輪迭代的引數θ(i)=(αk, μk, σk)來計算γjk,γjk是第j個樣本yj來自於第k個高斯分佈分模型的概率,k=1,2,...,K。在高斯混合模型中,γj是一個K維的向量,也就是第j個樣本屬於K個類的概率。假設分模型的方差σk2都相等,且是一個常數,不需要再估計,那麼在EM演算法的E步我們計算γjk
考慮σ2→0時的極限情況,如果樣本yj屬於第k類的概率最大,那麼該樣本與第k類的中心點的距離非常近,(yj - μk)2將會趨於0,於是有:
也就是樣本yj屬於第k類的概率近似為1,屬於其他類別的概率近似為0,也就成為了一種硬聚類,也就是K均值聚類。
其實在σ2→0時的極限情況下,最大化高斯混合模型的完全資料的對數似然函式的期望,等價於最小化K均值聚類的目標函式J。
比如有4個高斯分佈,樣本yj的γj為[0.55, 0.15, 0.2, 0.1],那麼屬於第1類的概率γj1最大。而當分模型的方差σ2→0時,樣本yj的γj可能為[0.98, 0.01, 0.05, 0.05],也就是該樣本直接被分配到了第1類,成為了硬聚類。
附:4個高斯分佈的概率密度函式的圖形程式碼
import matplotlib.pyplot as plt import math import numpy as np # 均值 u1 = 0 u2 = 2 u3 = 4 u4 = 6 # 標準差 sig1 = math.sqrt(1) sig2 = math.sqrt(2) sig3 = math.sqrt(3) sig4 = math.sqrt(1) def x(u,sig): return np.linspace(u - 5*sig, u + 5*sig, 100) x1 = x(u1,sig1) x2 = x(u2,sig2) x3 = x(u3,sig3) x4 = x(u4,sig4) # 概率密度 def y(x,u,sig): return np.exp(-(x - u) ** 2 /(2* sig**2))/(math.sqrt(2*math.pi)*sig) y1 = y(x1,u1,sig1) y2 = y(x2,u2,sig2) y3 = y(x3,u3,sig3) y4 = y(x4,u4,sig4) plt.plot(x1,y1, "r-") plt.plot(x2, y2, "g-") plt.plot(x3, y3, "b-") plt.plot(x4, y4, "m-") plt.xticks(range(-6,16,2)) plt.show()
參考資料:
1、李航:《統計學習方法》(第二版)
2、《Pattern Recognition and Machine Learning》
3、https://www.jianshu.com/p/335b376174d4