KMeans演算法與GMM混合高斯聚類

LeonYi發表於2023-04-16

一、K-Means

K-Means是GMM的特例(硬聚類,基於原型的聚類)。假設多元高斯分佈的協方差為0,方差相同。
 
K-Means演算法思想
對於給定的樣本集,按照樣本之間的距離大小,將樣本集劃分為K個簇。讓簇內的點儘量緊密的連在一起,而讓簇間的距離儘量的大。
0
N個d維樣本,時間複雜度 O(kLNd)
  1. 初始K個類(簇心)
  2. E步:對每個樣本,計算到K個類的歐式距離,並分配類標籤 O(kNd)
  3. M步:基於類內的樣本,以樣本均值更新類(均值最小化,類到類內樣本的誤差) O(Nd)
  4. 重複2-3步,直到聚類結果不變化或收斂
迭代次數為L
 
收斂性證明:
 
聚類處理:
特徵歸一化,缺失值,異常值
 
K-Means的主要優點有:
  1)基於原型的聚類,實現簡單收斂速度快。
  2)聚類效果較優。
  3)演算法的可解釋度比較強。
  4)主要需要調參的引數僅僅是簇數k。
K-Means的主要缺點有:
  1)K值的選取不好把握
  2)對於不是凸的資料集比較難收斂
  3)如果各隱含類別的資料不平衡,比如各隱含類別的資料量嚴重失衡,或者各隱含類別的方差不同,則聚類效果不佳。
  4) 採用迭代方法,得到的結果只是區域性最優(本身是個NP-hard問題,組合最佳化,多項式係數)
  5) 對噪音和異常點比較的敏感。
 
# 基於Cursor生成的程式碼
import numpy as np

def k_means(X, k, max_iters=100):
    # randomly initialize centroids
    centroids = X[np.random.choice(range(len(X)), k, replace=False)]
    
    for i in range(max_iters):
        # calculate distances between each point and each centroid
        distances = np.sqrt(((X - centroids[:, np.newaxis])**2).sum(axis=2))
        
        # assign each point to the closest centroid
        labels = np.argmin(distances, axis=0)
        
        # update centroids to be the mean of the points assigned to them
        for j in range(k):
            centroids[j] = X[labels == j].mean(axis=0)
    
    return centroids, labels

d = 3
k = 3
X = np.random.rand(100, 3)
centroids, labels = k_means(X, k, max_iters=100)

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=labels, cmap='viridis')
ax.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2], marker='*', s=300, c='r')

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')

plt.show()

二、GMM

⾼斯分佈的線性組合可以給出相當複雜的機率密度形式。
透過使⽤⾜夠多的⾼斯分佈,並且調節它們的均值和⽅差以及線性組合的係數,⼏乎所有的連續機率密度都能夠以任意的精度近似。
0
對3個高斯分佈的機率密度函式進行加權。考慮K個⾼斯機率密度的疊加,形式為:
0
0
混合⾼斯(mixture of Gaussians),每⼀個⾼斯機率密度N (x | µk, Σk)被稱為混合分佈的⼀個成分(component),並且有⾃⼰的均值µk和協⽅差Σk。
0
具有3個成分的混合⾼斯分佈的輪廓線。引數πk被稱為混合係數。GMM
 
可把πk = p(k)看成選擇第k個成分的先驗機率, 把 密度N (x | µk, Σk) = p(x | k)看成以k為條件的x的機率。
⾼斯混合分佈的形式由引數π, µ和Σ控制,其中令π ≡ {π1, . . . , πK}, µ ≡
{µ1, . . . , µK}且Σ ≡ {Σ1, . . . , Σk}。⼀種確定這些引數值的⽅法是使⽤最⼤似然法。根據公式),對數似然函式為:
0
因為對數中存在⼀個求和式,導致引數的最⼤似然解不再有⼀個封閉形式的解析解:
  • ⼀種最⼤化這個似然函式的⽅法是使⽤迭代數值最佳化⽅法。
  • 另⼀種是使⽤EM期望最⼤化演算法(對包含隱變數的似然進行迭代最佳化)。
 
樣本x為觀測資料,混合係數為隱變數,高斯分佈的引數。
當成分為多元高斯分佈時(d維),相當於從混合多元高斯分佈中生成了樣本,透過EM演算法迭代地學習模型引數(均值和方差以及混合係數)。
  1. 期望:根據引數,更新樣本關於類的響應度(隸屬度,相當於分別和K個類計算距離並歸一化)。確定響應度,就可以確定EM演算法的Q函式(完全資料的對數似然關於 分佈的期望),原始似然的下界。
  2. 最大化:根據響應度,計算均值、方差。
EM演算法收斂後,直接求每個樣本關於成分的響應度即可得到聚類結果(可軟,可硬argmax)
 
當多元高斯分佈的方差相同時,且每個樣本只能指定給一個類時(one-hot響應度,argmax),GMM退化成K-means演算法。
0
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.cluster import KMeans

# 建立資料,並視覺化
X, y = datasets.make_blobs(n_samples=1500,
                             cluster_std=[1.0, 2.5, 0.5],
                             random_state=170)
plt.figure(figsize=(12,4))
plt.rcParams['font.family'] = 'STKaiti'
plt.rcParams['font.size'] = 20
plt.subplot(1,3,1)
plt.scatter(X[:,0],X[:,1],c = y)
plt.title('原始資料',pad = 20)
Kmeans聚類
kmeans = KMeans(3)
kmeans.fit(X)
y_ = kmeans.predict(X)
plt.subplot(1,3,2)
plt.scatter(X[:,0],X[:,1],c = y_)
plt.title('KMeans聚類效果',pad = 20)
GMM高斯混合模型聚類
gmm = GaussianMixture(n_components=3)
y_ = gmm.fit_predict(X)
plt.subplot(1,3,3)
plt.scatter(X[:,0],X[:,1],c = y_)
plt.title('GMM聚類效果',pad = 20)
 
plt.figtext(x = 0.51,y = 1.1,s = 'KMeans VS GMM',ha = 'center',fontsize = 30)
plt.savefig('./GMM高斯混合模型.png',dpi = 200)
0
優點:
  • 可以完成大部分形狀的聚類
  • 大資料集時,對噪聲資料不敏感
  • 對於距離或密度聚類,更適合高維特徵
缺點:
  • 計算複雜高,速度較慢
  • 難以對圓形資料聚類
  • 需要在測試前知道類別的個數(成分個數,超引數)
  • 初始化引數會對聚類結果產生影響
參考
2. PRML

相關文章