大俠幸會,在下全網同名「演算法金」
0 基礎轉 AI 上岸,多個演算法賽 Top
「日更萬日,讓更多人享受智慧樂趣」
今天我們來聊聊達叔 6 大核心演算法之 —— k-means 演算法。最早由史丹佛大學的 J. B. MacQueen 於 1967 年提出,後來經過許多研究者的改進和發展,成為了一種經典的聚類方法。吳恩達:機器學習的六個核心演算法!
分幾部分,拿下:
- k-means 演算法的基本原理和工作步驟
- 相關的數學公式和程式碼示範
- k-means 演算法的優缺點
- 誤區和注意事項
- k-means 演算法的變種和改進
- k-means 演算法的實際應用
- k-means 演算法與其他聚類演算法的對比
1. k-means 演算法簡介
什麼是 k-means 演算法
k-means 演算法是一種用於聚類分析的非監督學習演算法。它透過將資料點劃分為 k 個簇,使得每個簇中的資料點儘可能相似,而不同簇之間的資料點儘可能不同。這個演算法的名稱來源於其中的 k 個簇(clusters)和每個簇的均值(mean)。
k-means 演算法的工作原理
k-means 演算法的工作原理可以概括為以下幾個步驟:
- 初始化中心點
- 分配樣本到最近的中心點
- 更新中心點
- 迭代直到收斂
下面我們來淺淺的感受一下,走你~
2. k-means 演算法的核心步驟
2.1 初始化中心點
在 k-means 演算法中,第一步是隨機選擇 k 個點作為初始中心點。這個步驟非常重要,因為初始中心點的選擇會影響最終聚類結果的好壞。如果初始中心點選擇不當,可能會導致演算法陷入區域性最優解。
2.2 分配樣本到最近的中心點
一旦初始中心點確定後,我們就可以開始分配樣本了。對於每個資料點,我們計算它到所有中心點的距離,並將其分配到距離最近的中心點所屬的簇中。通常情況下,我們使用歐氏距離來計算資料點之間的距離。
2.3 更新中心點
在所有資料點被分配到最近的中心點後,我們需要重新計算每個簇的中心點。新的中心點是簇中所有資料點的平均值。
2.4 迭代直到收斂
我們不斷重複分配樣本和更新中心點這兩個步驟,直到中心點不再發生變化或達到預設的迭代次數為止。這時,演算法就收斂了,簇的劃分結果也就確定了。
下面,我們用一個結合武俠元素的資料集來演示 k-means 演算法的核心步驟:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
# 生成武俠風格的資料集
np.random.seed(42)
data_A = np.random.normal(loc=[1, 1], scale=0.2, size=(50, 2))
data_B = np.random.normal(loc=[5, 5], scale=0.2, size=(50, 2))
data_C = np.random.normal(loc=[8, 1], scale=0.2, size=(50, 2))
data = np.vstack((data_A, data_B, data_C))
# 使用 k-means 演算法進行聚類
kmeans = KMeans(n_clusters=3, random_state=42)
kmeans.fit(data)
labels = kmeans.labels_
centroids = kmeans.cluster_centers_
# 資料視覺化
plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis', marker='o')
plt.scatter(centroids[:, 0], centroids[:, 1], c='red', marker='x')
plt.xlabel('武功')
plt.ylabel('內力')
plt.title('武俠資料集的聚類結果')
plt.show()
每天一個簡單通透的小案例,如果你對類似於這樣的文章感興趣。
歡迎關注、點贊、轉發~
3. 數學公式和程式碼示範
3.1 距離度量公式
在 k-means 演算法中,最常用的距離度量是歐氏距離。歐氏距離可以衡量兩個資料點之間的相似程度,計算公式如下:
其他距離度量方式有,曼哈頓距離 (Manhattan Distance),切比雪夫距離 (Chebyshev Distance),閔可夫斯基距離 (Minkowski Distance)等
更多細節,見往期微*公號文章:再見!!!KNN
3.2 損失函式(目標函式)
k-means 演算法的目標是最小化簇內資料點與中心點之間的總距離,即最小化下式:
感受一下
3.3 手撕 K means
下面我們用 Python 手動實現 k-means 演算法,並透過視覺化展示演算法的效果。
import numpy as np
import matplotlib.pyplot as plt
# 生成武俠風格的資料集
np.random.seed(7) # (彩蛋)42 是宇宙的答案
data_A = np.random.normal(loc=[1, 1], scale=0.2, size=(50, 2))
data_B = np.random.normal(loc=[5, 5], scale=0.2, size=(50, 2))
data_C = np.random.normal(loc=[8, 1], scale=0.2, size=(50, 2))
data = np.vstack((data_A, data_B, data_C))
def initialize_centroids(data, k):
"""隨機選擇 k 個點作為初始中心點"""
indices = np.random.choice(data.shape[0], k, replace=False)
return data[indices]
def assign_clusters(data, centroids):
"""將每個資料點分配到最近的中心點"""
distances = np.linalg.norm(data[:, np.newaxis] - centroids, axis=2)
return np.argmin(distances, axis=1)
def update_centroids(data, labels, k):
"""重新計算每個簇的中心點"""
new_centroids = np.array([data[labels == i].mean(axis=0) for i in range(k)])
return new_centroids
def kmeans(data, k, max_iters=100):
"""k-means 演算法"""
centroids = initialize_centroids(data, k)
for _ in range(max_iters):
labels = assign_clusters(data, centroids)
new_centroids = update_centroids(data, labels, k)
if np.all(centroids == new_centroids):
break
centroids = new_centroids
return labels, centroids
# 執行 k-means 演算法
k = 3
labels, centroids = kmeans(data, k)
# 資料視覺化
plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis', marker='o')
plt.scatter(centroids[:, 0], centroids[:, 1], c='red', marker='x')
plt.xlabel('武功')
plt.ylabel('內力')
plt.title('武俠資料集的聚類結果')
plt.show()
4. k-means 的優缺點
4.1 k-means 的優勢
- 簡單易懂:k-means 演算法的概念和實現都非常簡單,易於理解和應用。
- 計算效率高:由於演算法的時間複雜度較低,k-means 適合處理大規模資料集。
- 結果直觀:透過視覺化,k-means 聚類結果清晰明瞭,容易解釋。
4.2 k-means 的劣勢
- 需要預設簇數 k:k-means 需要使用者事先指定簇的數量 k,而在實際應用中,合適的 k 值往往很難確定。
- 對初始中心點敏感:k-means 對初始中心點的選擇非常敏感,不同的初始中心點可能導致不同的聚類結果,甚至區域性最優解。
- 只適用於凸形簇:k-means 假設簇是球形的,這使得它難以處理非凸形的簇結構。
- 受異常值影響大:異常值可能會顯著影響中心點的計算,從而影響聚類結果。
5. 誤區和注意事項
5.1 誤區:選擇 k 值的誤區
一個常見的誤區是隨意選擇 k 值。選擇合適的 k 值對於 k-means 演算法的效果至關重要。如果 k 過小,可能會導致欠擬合,無法捕捉資料中的全部資訊;如果 k 過大,可能會導致過擬合,使得模型對資料的細節過於敏感。常用的方法有肘部法(Elbow Method)和輪廓係數法(Silhouette Score)來選擇合適的 k 值。
肘部法(Elbow Method)
肘部法是一種常用的選擇 k 值的方法。其基本思想是透過計算不同 k 值下的總誤差平方和(SSE),繪製 SSE 隨 k 值變化的曲線,當曲線出現“肘部”時,對應的 k 值即為最佳選擇。SSE 隨 k 值增加而遞減,當 k 值達到某個臨界點後,SSE 的減小速度明顯減緩,這個臨界點對應的 k 值就是肘部。
肘部法的步驟如下:
- 執行 k-means 演算法,令 k 從 1 取到最大值。
- 計算每個 k 值對應的 SSE(誤差平方和)。
- 繪製 k 值與 SSE 的關係圖,找出肘部點。
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
def elbow_method(data, max_k):
sse = []
for k in range(1, max_k + 1):
kmeans = KMeans(n_clusters=k, random_state=42)
kmeans.fit(data)
sse.append(kmeans.inertia_)
plt.plot(range(1, max_k + 1), sse, marker='o')
plt.xlabel('簇數 (k)')
plt.ylabel('SSE')
plt.title('肘部法選擇最佳 k 值')
plt.show()
# 使用肘部法選擇最佳 k 值
elbow_method(data, 10)
輪廓係數法(Silhouette Score)
輪廓係數法透過計算資料點的輪廓係數來評估聚類結果的質量。輪廓係數介於 -1 和 1 之間,數值越大表示聚類效果越好。輪廓係數不僅考慮了同一簇內資料點的緊密程度,還考慮了不同簇之間的分離程度。
輪廓係數法的步驟如下:
- 執行 k-means 演算法,令 k 從 2 取到最大值。
- 計算每個 k 值對應的平均輪廓係數。
- 繪製 k 值與平均輪廓係數的關係圖,選擇平均輪廓係數最高的 k 值。
from sklearn.metrics import silhouette_score
def silhouette_method(data, max_k):
silhouette_scores = []
for k in range(2, max_k + 1):
kmeans = KMeans(n_clusters=k, random_state=42)
labels = kmeans.fit_predict(data)
score = silhouette_score(data, labels)
silhouette_scores.append(score)
plt.plot(range(2, max_k + 1), silhouette_scores, marker='o')
plt.xlabel('簇數 (k)')
plt.ylabel('平均輪廓係數')
plt.title('輪廓係數法選擇最佳 k 值')
plt.show()
# 使用輪廓係數法選擇最佳 k 值
silhouette_method(data, 10)
透過這兩種方法,我們可以更加科學和合理地選擇 k 值,從而提高 k-means 演算法的聚類效果。
5.2 注意事項:資料標準化
在使用 k-means 演算法之前,對資料進行標準化處理非常重要。由於不同特徵的量綱不同,直接使用未標準化的資料會導致距離計算時某些特徵的影響被放大。通常情況下,我們使用 z-score 標準化方法:
from sklearn.preprocessing import StandardScaler
# 資料標準化
scaler = StandardScaler()
data_standardized = scaler.fit_transform(data)
# 使用標準化後的資料進行聚類
kmeans = KMeans(n_clusters=3, random_state=42)
labels = kmeans.fit_predict(data_standardized)
centroids = kmeans.cluster_centers_
# 資料視覺化
plt.scatter(data_standardized[:, 0], data_standardized[:, 1], c=labels, cmap='viridis', marker='o')
plt.scatter(centroids[:, 0], centroids[:, 1], c='red', marker='x')
plt.xlabel('標準化後的武功')
plt.ylabel('標準化後的內力')
plt.title('標準化資料的聚類結果')
plt.show()
5.3 誤區:初始中心點選擇的重要性
有些鐵子們可能會忽略初始中心點的選擇,直接使用預設的隨機初始化。其實,初始中心點的選擇會顯著影響聚類結果。為了避免區域性最優解,我們可以使用 k-means++ 演算法進行初始化,這樣可以有效提高演算法的穩定性和收斂速度。
# 使用 k-means++ 初始化進行聚類
kmeans_pp = KMeans(n_clusters=3, init='k-means++', random_state=42)
labels_pp = kmeans_pp.fit_predict(data)
centroids_pp = kmeans_pp.cluster_centers_
# 資料視覺化
plt.scatter(data[:, 0], data[:, 1], c=labels_pp, cmap='viridis', marker='o')
plt.scatter(centroids_pp[:, 0], centroids_pp[:, 1], c='red', marker='x')
plt.xlabel('武功')
plt.ylabel('內力')
plt.title('k-means++ 初始化的聚類結果')
plt.show()
5.4 注意事項:避免區域性最優解
為了進一步避免陷入區域性最優解,可以多次執行 k-means 演算法,並選擇最優的聚類結果。這樣做可以顯著提高最終結果的穩定性和準確性。
# 多次執行 k-means 演算法並選擇最優結果
best_inertia = np.inf
best_labels = None
best_centroids = None
for _ in range(10):
kmeans = KMeans(n_clusters=3, random_state=42)
labels = kmeans.fit_predict(data)
if kmeans.inertia_ < best_inertia:
best_inertia = kmeans.inertia_
best_labels = labels
best_centroids = kmeans.cluster_centers_
# 資料視覺化
plt.scatter(data[:, 0], data[:, 1], c=best_labels, cmap='viridis', marker='o')
plt.scatter(best_centroids[:, 0], best_centroids[:, 1], c='red', marker='x')
plt.xlabel('武功')
plt.ylabel('內力')
plt.title('多次執行後的最佳聚類結果')
plt.show()
6. k-means 演算法的變種和改進
6.1 k-means++ 演算法
k-means++ 是 k-means 演算法的一種改進版本,旨在透過一種更巧妙的初始中心點選擇方法來提高演算法的穩定性和收斂速度。k-means++ 的核心思想是在選擇初始中心點時,讓新的中心點儘可能遠離已選擇的中心點,從而減少隨機初始化帶來的不穩定性。
k-means++ 初始化步驟:
- 隨機選擇一個資料點作為第一個中心點。
- 對於每一個資料點 𝑥𝑥,計算它到最近已選中心點的距離 𝐷(𝑥)𝐷(𝑥)。
- 根據 𝐷(𝑥)𝐷(𝑥) 的機率分佈隨機選擇下一個中心點,選擇機率與 𝐷(𝑥)𝐷(𝑥) 正相關。
- 重複步驟 2 和 3,直到選擇出 k 箇中心點。
6.2 Mini-Batch k-means
Mini-Batch k-means 是 k-means 的另一個改進版本,適用於大規模資料集。它透過使用小批次的資料進行迭代,減少了每次迭代的計算量,從而大大加快了聚類速度。Mini-Batch k-means 的核心思想是每次僅隨機選取一部分資料進行中心點的更新。
6.3 其他變種
除了 k-means++ 和 Mini-Batch k-means 之外,還有許多 k-means 的變種和改進演算法,例如:
- Bisecting k-means:透過遞迴地將資料集分成兩部分來進行聚類,適用於層次聚類。
- Fuzzy k-means:允許一個資料點屬於多個簇,透過模糊隸屬度來表示,適用於模糊聚類。
- Kernel k-means:透過使用核函式將資料對映到高維空間進行聚類,適用於非線性資料。
這些改進演算法在不同的應用場景中具有各自的優勢,可以根據具體需求選擇合適的演算法。
7. k-means 演算法的應用和案例
7.1 影像壓縮
k-means 演算法在影像壓縮中的應用非常廣泛。透過將影像中的畫素點聚類為 k 個顏色簇,可以有效減少影像的顏色數量,從而實現影像壓縮。下面是一個使用 k-means 進行影像壓縮的示例。
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np
from skimage import io
# 讀取影像
image = io.imread('https://example.com/image.jpg')
image = np.array(image, dtype=np.float64) / 255
# 將影像資料重塑為二維陣列
w, h, d = image.shape
image_array = np.reshape(image, (w * h, d))
# 使用 k-means 進行影像壓縮
kmeans = KMeans(n_clusters=16, random_state=42).fit(image_array)
labels = kmeans.predict(image_array)
compressed_image = kmeans.cluster_centers_[labels].reshape(w, h, d)
# 顯示原始影像和壓縮後的影像
fig, ax = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'xticks': [], 'yticks': []})
ax[0].imshow(image)
ax[0].set_title('原始影像')
ax[1].imshow(compressed_image)
ax[1].set_title('壓縮影像')
plt.show()
7.2 客戶分群
在市場營銷中,k-means 演算法可以用來對客戶進行分群,從而更好地制定營銷策略。透過分析客戶的消費行為、偏好等特徵,將客戶分成不同的群體,有助於企業針對不同客戶群體制定個性化的營銷方案。
7.3 其他實際應用
除了影像壓縮和客戶分群,k-means 演算法在其他領域也有廣泛的應用,例如:
- 文件分類:將文件聚類為不同的主題,有助於文件的自動歸檔和檢索。
- 城市規劃:根據居民的地理位置和人口密度,將城市劃分為不同的區域,最佳化城市資源配置。
- 基因表達分析:在生物資訊學中,k-means 用於分析基因表達資料,找出具有相似表達模式的基因群體。
透過這些實際應用的案例,可以看出 k-means 演算法在不同領域的強大實用性。
8. 橫向對比:k-means 與其他聚類演算法
8.1 k-means vs. 層次聚類
原理
- k-means:透過迭代最佳化中心點來最小化簇內平方誤差。
- 層次聚類:透過構建樹狀結構(樹狀圖)來逐步聚合或拆分資料點。
適用場景
- k-means:適用於大規模資料,且簇的形狀是球形的情況。
- 層次聚類:適用於小規模資料,且需要層次結構或簇的形狀不規則的情況。
優缺點對比
- k-means:計算速度快,但對初始點敏感,適合處理大資料。
- 層次聚類:無需預設簇數,但計算複雜度高,不適合大資料。
8.2 k-means vs. DBSCAN
原理
- k-means:基於均值和距離的聚類演算法。
- DBSCAN:基於密度的聚類演算法,透過尋找高密度區域形成簇。
適用場景
- k-means:適用於資料均勻分佈的情況。
- DBSCAN:適用於簇形狀不規則且有噪聲的資料。
優缺點對比
- k-means:需要預設簇數,對異常值敏感。
- DBSCAN:無需預設簇數,能識別噪聲,但引數選擇困難。
8.3 k-means vs. GMM
原理
- k-means:透過最小化簇內平方誤差進行聚類。
- GMM (高斯混合模型):假設資料由多個高斯分佈組成,透過期望最大化(EM)演算法進行聚類。
適用場景
- k-means:適用於簇形狀均勻的資料。
- GMM:適用於簇形狀複雜的資料,能夠處理機率歸屬問題。
優缺點對比
- k-means:簡單高效,但對簇形狀有假設限制。
- GMM:靈活性高,但計算複雜度高,需要選擇適當的高斯分佈數量。
透過這些對比,我們可以看到不同聚類演算法在不同應用場景下的優缺點,選擇合適的演算法可以更好地解決具體問題。
[ 抱個拳,總個結 ]
- 瞭解了 k-means 演算法的基本概念、工作原理和應用場景。
- 學習了 k-means 演算法的核心步驟,包括初始化中心點、分配樣本、更新中心點和迭代直到收斂。
- 掌握了 k-means 演算法的數學公式,如歐氏距離和損失函式,透過程式碼示例加深理解。
- 分析了 k-means 的優缺點,強調了選擇合適 k 值和資料標準化的重要性。
- 探討了 k-means 演算法的變種和改進,如 k-means++ 和 Mini-Batch k-means。
- 透過影像壓縮和客戶分群等案例展示了 k-means 的實際應用效果。
- 比較了 k-means 與其他聚類演算法(如層次聚類、DBSCAN 和 GMM),幫助理解不同演算法的適用場景和優缺點。
希望透過這篇文章,大家能對 k-means 演算法有一個全面的認識,並掌握其實際應用的方法。如果你有任何疑問或需要進一步探討,歡迎隨時留言交流。
吳恩達:機器學習的六個核心演算法!
迴歸演算法,邏輯迴歸,決策樹演算法, 神經網路,K-means(本文),梯度下降(TODO,催更請留言)
- 科研為國分憂,創新與民造福 -
日更時間緊任務急,難免有疏漏之處,還請大俠海涵
內容僅供學習交流之用,部分素材來自網路,侵聯刪
[ 演算法金,碎碎念 ]
全網同名,日更萬日,讓更多人享受智慧樂趣
如果覺得內容有價值,煩請大俠多多 分享、在看、點贊,助力演算法金又猛又持久、很黃很 BL 的日更下去;
同時邀請大俠 關注、星標 演算法金,圍觀日更萬日,助你功力大增、笑傲江湖