聚類
聚類主要內容是將樣本進行歸類,同種類別的樣本放到一起,所有樣本最終會形成K個簇,它屬於無監督學習。
核心思想
根據給定的K值和K個初始質心將樣本中每個點都分到距離最近的類簇中,當所有點分配完後根據每個類簇的所有點重新計算質心,一般是通過平均值計算,然後再將每個點分到距離最近的新類簇中,不斷迴圈此操作,直到質心不再變化或達到一定的迭代次數。數學上可以證明k-means是收斂的。
虛擬碼
隨機選擇k個質心,即為簇數
while(true){
計算每個點到最近距離的質心,歸為該類。
重新計算每個類的質心。
if(質心與上一次質心一樣or達到最大迭代次數)
break;
}複製程式碼
缺點
- 需要事先確定類簇的數量。
- 質心的選取會影響最終的聚類結果。
程式碼實現
from numpy import *
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
def kmeans(dataSet, k):
sampleNum, col = dataSet.shape
cluster = mat(zeros((sampleNum, 2)))
centroids = zeros((k, col))
##choose centroids
for i in range(k):
index = int(random.uniform(0, sampleNum))
centroids[i, :] = dataSet[index, :]
clusterChanged = True
while clusterChanged:
clusterChanged = False
for i in range(sampleNum):
minDist = sqrt(sum(power(centroids[0, :] - dataSet[i, :], 2)))
minIndex = 0
for j in range(1,k):
distance = sqrt(sum(power(centroids[j, :] - dataSet[i, :], 2)))
if distance < minDist:
minDist = distance
minIndex = j
if cluster[i, 0] != minIndex:
clusterChanged = True
cluster[i, :] = minIndex, minDist**2
for j in range(k):
pointsInCluster = dataSet[nonzero(cluster[:, 0].A == j)[0]]
centroids[j, :] = mean(pointsInCluster, axis = 0)
return centroids, cluster
dataSet = [[1,1],[3,1],[1,4],[2,5],[11,12],[14,11],[13,12],[11,16],[17,12],[28,10],[26,15],[27,13],[28,11],[29,15]]
dataSet = mat(dataSet)
k = 3
centroids, cluster = kmeans(dataSet, k)
sampleNum, col = dataSet.shape
mark = ['or', 'ob', 'og']
for i in range(sampleNum):
markIndex = int(cluster[i, 0])
plt.plot(dataSet[i, 0], dataSet[i, 1], mark[markIndex])
mark = ['+r', '+b', '+g']
for i in range(k):
plt.plot(centroids[i, 0], centroids[i, 1], mark[i], markersize=12)
plt.show()複製程式碼
結果:
直接用機器學習庫更加方便
from numpy import *
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
dataSet = [[1,1],[3,1],[1,4],[2,5],[11,12],[14,11],[13,12],[11,16],[17,12],[28,10],[26,15],[27,13],[28,11],[29,15]]
dataSet=mat(dataSet)
k = 3
markers = ['^', 'o', 'x']
cls =KMeans(k).fit(dataSet)
for i in range(k):
members=cls.labels_==i
plt.scatter(dataSet[members,0],dataSet[members,1],marker=markers[i])
plt.show()複製程式碼
========廣告時間========
鄙人的新書《Tomcat核心設計剖析》已經在京東銷售了,有需要的朋友可以到 item.jd.com/12185360.ht… 進行預定。感謝各位朋友。
=========================
歡迎關注: