【火爐煉AI】機器學習022-使用均值漂移聚類演算法構建模型
(本文所使用的Python庫和版本號: Python 3.5, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2 )
無監督學習演算法有很多種,前面已經講解過了K-means聚類演算法,並用該演算法對圖片進行向量量化壓縮。下面我們來學習第二種無監督學習演算法----均值漂移演算法。
1. 均值漂移演算法簡介
均值漂移演算法是一種基於密度梯度上升的非引數方法,它經常被應用在影象識別中的目標跟蹤,資料聚類,分類等場景。
其核心思想是:首先隨便選擇一箇中心點,然後計算該中心點一定範圍之內所有點到中心點的距離向量的平均值,計算該平均值得到一個偏移均值,然後將中心點移動到偏移均值位置,通過這種不斷重複的移動,可以使中心點逐步逼近到最佳位置。這種思想類似於梯度下降方法,通過不斷的往梯度下降的方向移動,可以到達梯度上的區域性最優解或全域性最優解。
如下是漂移均值演算法的思想呈現,首先隨機選擇一箇中心點(綠色點),然後計算該點一定範圍內所有點到這個點的距離均值,然後將該中心點移動距離均值,到黃色點處,同理,再計算該黃色點一定範圍內的所有點到黃點的距離均值,經過多次計算均值--移動中心點等方式,可以使得中心點逐步逼近最佳中心點位置,即圖中紅色點處。
1.1 均值漂移演算法的基礎公式
從上面核心思想可以看出,均值漂移的過程就是不斷的重複計算距離均值,移動中心點的過程,故而計算偏移均值和移動距離便是非常關鍵的兩個步驟,如下為計算偏移均值的基礎公式。
其中Sh:以x為中心點,半徑為h的高維球區域; k:包含在Sh範圍內點的個數; xi:包含在Sh範圍內的點
第二個步驟是計算移動一定距離之後的中心點位置,其計算公式為:
其中,Mt為t狀態下求得的偏移均值; xt為t狀態下的中心
很顯然,移動之後的中心點位置是移動前位置加上偏移均值。
1.2 引入核函式的偏移均值演算法
上述雖然介紹了均值漂移演算法的基礎公式,但是該公式存在一定的問題,我們知道,高維球區域內的所有樣本點對求解的貢獻是不一樣的,而基礎公式卻當做貢獻一樣來處理,即所有點的權重一樣,這是不符合邏輯的,那麼怎麼改進了?我們可以引入核函式,用來求出每個樣本點的貢獻權重。當然這種求解權重的核函式有很多種,高斯函式就是其中的一種,如下公式是引入高斯核函式後的偏移均值的計算公式:
上面就是核函式內部的樣子。
1.3 均值漂移演算法的運算步驟
均值漂移演算法的應用非常廣泛,比如在聚類,影象分割,目標跟蹤等領域,其運算步驟往往包含有如下幾個步驟:
1,在資料點中隨機選擇一個點作為初始中心點。
2,找出離該中心點距離在頻寬之內的所有點,記做集合M,認為這些點屬於簇C.
3,計算從中心點開始到集合M中每個元素的向量,將這些向量相加,得到偏移向量。
4,將該中心點沿著偏移的方向移動,移動距離就是該偏移向量的模。
5,重複上述步驟2,3,4,直到偏移向量的大小滿足設定的閾值要求,記住此時的中心點。
6,重複上述1,2,3,4,5直到所有的點都被歸類。
7,分類:根據每個類,對每個點的訪問頻率,取訪問頻率最大的那個類,作為當前點集的所屬類。
1.4 均值漂移演算法的優勢
均值漂移演算法用於叢集資料點時,把資料點的分佈看成是概率密度函式,希望在特徵空間中根據函式分佈特徵找出資料點的模式,這些模式就對應於一群群區域性最密集分佈的點。
雖然我們前面講解了K-means演算法,但K-means演算法在實際應用時,需要知道我們要把資料劃分為幾個類別,如果類別數量出錯,則往往難以得到令人滿意的分類結果,而要劃分的類別往往很難事先確定。這就是K-means演算法的應用難點。
而均值漂移演算法卻不需要事先知道要叢集的數量,這種演算法可以在我們不知道要尋找多少叢集的情況下自動劃分最合適的族群,這就是均值漂移演算法的一個很明顯優勢。
以上部分內容來源於部落格文章,在此表示感謝。
2. 構建均值漂移模型來聚類資料
本文所使用的資料集和讀取資料集的方式與上一篇文章【火爐煉AI】機器學習020-使用K-means演算法對資料進行聚類分析一模一樣,故而此處省略。
下面是構建MeanShift物件的程式碼,使用MeanShift之前,我們需要評估頻寬,頻寬就是上面所講到的距離中心點的一定距離,我們要把所有包含在這個距離之內的點都放入一個集合M中,用於計算偏移向量。
# 構建MeanShift物件,但需要評估頻寬
from sklearn.cluster import MeanShift, estimate_bandwidth
bandwidth=estimate_bandwidth(dataset_X,quantile=0.1,
n_samples=len(dataset_X))
meanshift=MeanShift(bandwidth=bandwidth,bin_seeding=True) # 構建物件
meanshift.fit(dataset_X) # 並用MeanShift物件來訓練該資料集
centroids=meanshift.cluster_centers_ # 質心的座標,對應於feature0, feature1
print(centroids) # 可以看出有4行,即4個質心
labels=meanshift.labels_ # 資料集中每個資料點對應的label
# print(labels)
cluster_num=len(np.unique(labels)) # label的個數,即自動劃分的族群的個數
print('cluster num: {}'.format(cluster_num))
複製程式碼
-------------------------------------輸---------出----------------
[[ 8.22338235 1.34779412]
[ 4.10104478 -0.81164179]
[ 1.18820896 2.10716418]
[ 4.995 4.99967742]]
cluster num: 4
--------------------------------------------完--------------------
可以看出,此處我們得到了四個質心,這四個質心的座標位置可以通過meanshift.cluster_centers_獲取,而meanshift.labels_ 得到的就是原來樣本資料的label,也就是我們通過均值漂移演算法自己找到的label,這就是無監督學習的優勢所在:雖然沒有給樣本資料指定label,但是該演算法能自己找到其對應的label。
同樣的,該怎麼檢視該MeanShift演算法的好壞了,可以通過下面的函式直接觀察資料集劃分的效果。
def visual_meanshift_effect(meanshift,dataset):
assert dataset.shape[1]==2,'only support dataset with 2 features'
X=dataset[:,0]
Y=dataset[:,1]
X_min,X_max=np.min(X)-1,np.max(X)+1
Y_min,Y_max=np.min(Y)-1,np.max(Y)+1
X_values,Y_values=np.meshgrid(np.arange(X_min,X_max,0.01),
np.arange(Y_min,Y_max,0.01))
# 預測網格點的標記
predict_labels=meanshift.predict(np.c_[X_values.ravel(),Y_values.ravel()])
predict_labels=predict_labels.reshape(X_values.shape)
plt.figure()
plt.imshow(predict_labels,interpolation='nearest',
extent=(X_values.min(),X_values.max(),
Y_values.min(),Y_values.max()),
cmap=plt.cm.Paired,
aspect='auto',
origin='lower')
# 將資料集繪製到圖表中
plt.scatter(X,Y,marker='v',facecolors='none',edgecolors='k',s=30)
# 將中心點繪製到圖中
centroids=meanshift.cluster_centers_
plt.scatter(centroids[:,0],centroids[:,1],marker='o',
s=100,linewidths=2,color='k',zorder=5,facecolors='b')
plt.title('MeanShift effect graph')
plt.xlim(X_min,X_max)
plt.ylim(Y_min,Y_max)
plt.xlabel('feature_0')
plt.ylabel('feature_1')
plt.show()
visual_meanshift_effect(meanshift,dataset_X)
複製程式碼
########################小**********結###################
1,MeanShift的構建和訓練方法和K-means的方式幾乎一樣,但是MeanShift可以自動計算出資料集的族群數量,而不需要人為事先指定,這使得MeanShift比K-means要好用一些。
2, 訓練之後的MeanShift物件中包含有該資料集的質心座標,資料集的各個樣本對應的label資訊,這些資訊可以很方便的獲取。
#######################################################
注:本部分程式碼已經全部上傳到(我的github)上,歡迎下載。
參考資料:
1, Python機器學習經典例項,Prateek Joshi著,陶俊傑,陳小莉譯