9.1.6 DBSCAN聚類演算法————機器學習實戰第二版

_qz發表於2020-11-18

一、工作原理

  1. 對於每個例項,該演算法都會計算在它一小段距離內 ε \varepsilon ε 內有多少個例項。該區域稱為例項的 ε − \varepsilon- ε 鄰域。
  2. 如果一個例項在其 ε \varepsilon ε 鄰域內至少包含 min_samples 個例項(包含自身),則該例項為核心例項。
  3. 核心例項附近的所有例項都屬於同一叢集。這個鄰域可能包括其他核心例項。因此,一長串相鄰的核心例項形成一個叢集。
  4. 任何不是核心例項且鄰居中沒有核心例項的例項都被視為異常

二、引數

sklearn中引數詳解:詳解
兩個重要引數:

  • eps: ε \varepsilon ε 的大小
  • min_samples : 核心例項中至少包含的例項個數

三、變數

sklearn.dataset中的make_moons()函式:連結
make_circles()函式與make_moons()函式相似

from sklearn.cluster import DBSCAN
from sklearn.datasets import make_moons,make_circles
X,y = make_moons(n_samples = 1000,noise = 0.05)

make_moons()生成資料為:
在這裡插入圖片描述
make_circles()生成資料為:
在這裡插入圖片描述

DBSCAN物件的變數:

  • labels_ : 每個例項的叢集標籤。異常例項的叢集標籤為-1
  • core_sample_indices_ : 包含每個例項的索引
  • components_ : 可以得到核心例項本身
X,y = make_moons(n_samples = 100 ,noise = 0.1)
dbscan = DBSCAN(eps = 0.2,min_samples =2).fit(X) 

labels_ = dbscan.labels_
print("標籤為:{}".format(labels_))

len_ = len(dbscan.core_sample_indices_)
print("核心例項個數為:{}".format(len_))

data = dbscan.components_
print("核心例項:{}".format(data))

結果為:

在這裡插入圖片描述

四、程式碼

1.難點

1. 導包

from sklearn.cluster import DBSCAN
from sklearn.datasets import make_moons,make_circles
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

2. 資料集

X,y = make_moons(n_samples = 1000 ,noise = 0.1)
X2,y2 = make_circles(n_samples = 1000,noise = 0.025)

3. 函式

#引數:eps_為鄰域大小,min_sample_為核心例項鄰域中最小例項數目,X_為資料集;本次用了兩個資料集
def DBscan(eps_,min_sample_,X_):
	#建立模型並訓練
    dbscan = DBSCAN(eps = eps_,min_samples =min_sample_).fit(X_) 
    #繪製散點圖時的c引數,大小和點的個數一樣,即dbscan的變數core_sample_indices的len:len(dbscan.core_sample_indices_)
    mask = np.arange(len(dbscan.core_sample_indices_))
    #mask內為每個非異常例項的叢集索引,異常例項的叢集索引為-1;
    #在dbscan的core_sample_indices和components變數中並沒有出現異常例項,直接被演算法過濾掉了。
    for idx,i in enumerate(dbscan.core_sample_indices_):
        mask[idx]  = dbscan.labels_[i]
    #畫散點圖
    plt.scatter(dbscan.components_[:,0],dbscan.components_[:,1],c = mask)
    #標題
    plt.title("eps = {},min_samples = {}".format(eps_,min_sample_))

4. 呼叫函式

對每個資料集分別除錯兩組引數,第二組引數效果較好,也就是第三列

plt.figure(figsize = (12,8))

plt.subplot(231)
plt.scatter(X[:,0],X[:,1],c = y)
plt.title("Original data")

plt.subplot(232)
DBscan(0.05,5,X)

plt.subplot(233)
DBscan(0.1,5,X)

plt.subplot(234)
plt.scatter(X2[:,0],X2[:,1],c = y2)
plt.title("Original data2")

plt.subplot(235)
DBscan(0.05,7,X2)

plt.subplot(236)
DBscan(0.07,7,X2)

plt.show()

結果為:
在這裡插入圖片描述

相關文章