[譯] Sklearn 中的樸素貝葉斯分類器

sisibeloved發表於2018-08-28

[譯] Sklearn 中的樸素貝葉斯分類器

用豆機實現的高斯分佈

這篇教程詳述了樸素貝葉斯分類器的演算法、它的原理優缺點,並提供了一個使用 Sklearn 庫的示例。

背景

以著名的泰坦尼克號遇難者資料集為例。它收集了泰坦尼克號的乘客的個人資訊以及是否從那場海難中生還。讓我們試著用乘客的船票費用來預測一下他能否生還。

[譯] Sklearn 中的樸素貝葉斯分類器

泰坦尼克號上的 500 名乘客

假設你隨機取了 500 名乘客。在這些樣本中,30% 的人倖存下來。倖存乘客的平均票價為 100 美元,而遇難乘客的平均票價為 50 美元。現在,假設你有了一個新的乘客。你不知道他是否倖存,但你知道他買了一張 30 美元的票穿越大西洋。請你預測一下這個乘客是否倖存。

原理

好吧,你可能回答說這個乘客沒能倖存。為什麼?因為根據上文所取的乘客的隨機子集中所包含的資訊,本來的生還機率就很低(30%),而窮人的生還機率則更低。你會把這個乘客放在最可能的組別(低票價組)。這就是樸素貝葉斯分類器所要實現的。

分析

樸素貝葉斯分類器利用條件概率來聚集資訊,並假設特徵之間相對獨立。這是什麼意思呢?舉個例子,這意味著我們必須假定泰坦尼克號的房間舒適度與票價無關。顯然這個假設是錯誤的,這就是為什麼我們將這個假設稱為樸素(Naive)的原因。樸素假設使得計算得以簡化,即使在非常大的資料集上也是如此。讓我們來一探究竟。

樸素貝葉斯分類器本質上是尋找能描述給定特徵條件下屬於某個類別的概率的函式,這個函式寫作 P(Survival | f1,…, fn)。我們使用貝葉斯定理來簡化計算:

[譯] Sklearn 中的樸素貝葉斯分類器

式 1:貝葉斯定理

P(Survival) 很容易計算,而我們構建分類器也不需要用到 P(f1,…, fn),因此問題回到計算 P(f1,…, fn | Survival) 上來。我們應用條件概率公式來再一次簡化計算:

[譯] Sklearn 中的樸素貝葉斯分類器

式 2:初步擴充

上式最後一行的每一項的計算都需要一個包含所有條件的資料集。為了計算 {Survival, f_1, …, f_n-1} 條件下 fn 的概率(即 P(fn | Survival, f_1, …, f_n-1)),我們需要有足夠多不同的滿足條件 {Survival, f_1, …, f_n-1} 的 fn 值。這會需要大量的資料,並導致維度災難。這時樸素假設(Naive Assumption)的好處就凸顯出來了。假設特徵是獨立的,我們可以認為條件 {Survival, f_1, …, f_n-1} 的概率等於 {Survival} 的概率,以此來簡化計算:

[譯] Sklearn 中的樸素貝葉斯分類器

式 3:應用樸素假設

最後,為了分類,新建一個特徵向量,我們只需要選擇是否生還的值(1 或 0),令 P(f1, …, fn|Survival) 最高,即為最終的分類結果:

[譯] Sklearn 中的樸素貝葉斯分類器

式 4:argmax 分類器

注意:常見的錯誤是認為分類器輸出的概率是對的。事實上,樸素貝葉斯被稱為差估計器,所以不要太認真地看待這些輸出概率。

找出合適的分佈函式

最後一步就是實現分類器。怎樣為概率函式 P(f_i| Survival) 建立模型呢?在 Sklearn 庫中有三種模型:

[譯] Sklearn 中的樸素貝葉斯分類器

正態分佈

[譯] Sklearn 中的樸素貝葉斯分類器

二項式分佈

Python 程式碼

接下來,基於泰坦尼克遇難者資料集,我們實現了一個經典的高斯樸素貝葉斯。我們將使用船艙等級、性別、年齡、兄弟姐妹數目、父母/子女數量、票價和登船口岸這些資訊。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB

# 匯入資料集
data = pd.read_csv("data/train.csv")

# 將分類變數轉換為數字
data["Sex_cleaned"]=np.where(data["Sex"]=="male",0,1)
data["Embarked_cleaned"]=np.where(data["Embarked"]=="S",0,
                                  np.where(data["Embarked"]=="C",1,
                                           np.where(data["Embarked"]=="Q",2,3)
                                          )
                                 )
# 清除資料集中的非數字值(NaN)
data=data[[
    "Survived",
    "Pclass",
    "Sex_cleaned",
    "Age",
    "SibSp",
    "Parch",
    "Fare",
    "Embarked_cleaned"
]].dropna(axis=0, how='any')

# 將資料集拆分成訓練集和測試集
X_train, X_test = train_test_split(data, test_size=0.5, random_state=int(time.time()))
複製程式碼
# 例項化分類器
gnb = GaussianNB()
used_features =[
    "Pclass",
    "Sex_cleaned",
    "Age",
    "SibSp",
    "Parch",
    "Fare",
    "Embarked_cleaned"
]

# 訓練分類器
gnb.fit(
    X_train[used_features].values,
    X_train["Survived"]
)
y_pred = gnb.predict(X_test[used_features])

# 列印結果
print("Number of mislabeled points out of a total {} points : {}, performance {:05.2f}%"
      .format(
          X_test.shape[0],
          (X_test["Survived"] != y_pred).sum(),
          100*(1-(X_test["Survived"] != y_pred).sum()/X_test.shape[0])
))
複製程式碼

Number of mislabeled points out of a total 357 points: 68, performance 80.95%

這個分類器的正確率為 80.95%

使用單個特徵說明

讓我們試著只使用票價資訊來約束分類器。下面我們計算 P(Survival = 1) 和 P(Survival = 0) 的概率:

mean_survival=np.mean(X_train["Survived"])
mean_not_survival=1-mean_survival
print("Survival prob = {:03.2f}%, Not survival prob = {:03.2f}%"
      .format(100*mean_survival,100*mean_not_survival))
複製程式碼

Survival prob = 39.50%, Not survival prob = 60.50%

然後,根據式 3,我們只需要得出概率分佈函式 P(fare| Survival = 0) 和 P(fare| Survival = 1)。我們選用高斯樸素貝葉斯分類器,因此,必須假設資料按高斯分佈。

[譯] Sklearn 中的樸素貝葉斯分類器

式 5:高斯公式(σ:標準差 / μ:均值)

然後,我們需要算出是否生還值不同的情況下,票價資料集的均值和標準差。我們得到以下結果:

mean_fare_survived = np.mean(X_train[X_train["Survived"]==1]["Fare"])
std_fare_survived = np.std(X_train[X_train["Survived"]==1]["Fare"])
mean_fare_not_survived = np.mean(X_train[X_train["Survived"]==0]["Fare"])
std_fare_not_survived = np.std(X_train[X_train["Survived"]==0]["Fare"])

print("mean_fare_survived = {:03.2f}".format(mean_fare_survived))
print("std_fare_survived = {:03.2f}".format(std_fare_survived))
print("mean_fare_not_survived = {:03.2f}".format(mean_fare_not_survived))
print("std_fare_not_survived = {:03.2f}".format(std_fare_not_survived))
複製程式碼
mean_fare_survived = 54.75
std_fare_survived = 66.91
mean_fare_not_survived = 24.61
std_fare_not_survived = 36.29
複製程式碼

讓我們看看關於生還未生還的直方圖的結果分佈:

[譯] Sklearn 中的樸素貝葉斯分類器

圖 1:各個是否生還值的票價直方圖和高斯分佈(縮放等級並不對應)

可以發現,分佈與資料集並沒有很好地擬合。在實現模型之前,最好驗證特徵分佈是否遵循上述三種模型中的一種。如果連續特徵不具有正態分佈,則應使用變換或不同的方法將其轉換成正態分佈。為了便於說明,這我們將分佈看作是正態的。應用式 1 貝葉斯定理,可得以下這個分類器:

[譯] Sklearn 中的樸素貝葉斯分類器

圖 2:高斯分類器

如果票價分類器的值超過 78(classifier(Fare) ≥ ~78),則 P(fare| Survival = 1) ≥ P(fare| Survival = 0),我們將這個人歸類為生還。否則我們就將他歸為未生還。我們得到了一個正確率為 64.15% 的分類器。

如果我們在同一資料集上訓練 Sklearn 高斯樸素貝葉斯分類器,將會得到完全相同的結果:

from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
used_features =["Fare"]
y_pred = gnb.fit(X_train[used_features].values, X_train["Survived"]).predict(X_test[used_features])
print("Number of mislabeled points out of a total {} points : {}, performance {:05.2f}%"
      .format(
          X_test.shape[0],
          (X_test["Survived"] != y_pred).sum(),
          100*(1-(X_test["Survived"] != y_pred).sum()/X_test.shape[0])
))
print("Std Fare not_survived {:05.2f}".format(np.sqrt(gnb.sigma_)[0][0]))
print("Std Fare survived: {:05.2f}".format(np.sqrt(gnb.sigma_)[1][0]))
print("Mean Fare not_survived {:05.2f}".format(gnb.theta_[0][0]))
print("Mean Fare survived: {:05.2f}".format(gnb.theta_[1][0]))
複製程式碼
Number of mislabeled points out of a total 357 points: 128, performance 64.15%
Std Fare not_survived 36.29
Std Fare survived: 66.91
Mean Fare not_survived 24.61
Mean Fare survived: 54.75
複製程式碼

樸素貝葉斯分類器的優缺點

優點:

  • 計算迅速
  • 實現簡單
  • 在小資料集上表現良好
  • 在高維度資料上表現良好
  • 即使樸素假設沒有完全滿足,也能表現良好。在許多情況下,建立一個好的分類器只需要近似的資料就夠了。

缺點:

  • 需要移除相關特徵,因為它們會在模型中被計算兩次,這將導致該特徵的重要性被高估。
  • 如果測試集中,某分類變數的一個類別沒有在訓練集中出現過,那麼模型會把這種情況設為零概率。它將無法做出預測。這通常被稱為『零位頻率』。我們可以使用平滑技術來解決這個問題。最簡單的平滑技術之一稱為拉普拉斯平滑。當你訓練一個樸素貝葉斯分類器時,Sklearn 會預設使用拉普拉斯平滑演算法。

結語

非常感謝你閱讀這篇文章。我希望它能幫助你理解樸素貝葉斯分類器的概念以及它的優點

致謝 Antoine ToubhansFlavian HautboisAdil BaajRaphaël Meudec

如果發現譯文存在錯誤或其他需要改進的地方,歡迎到 掘金翻譯計劃 對譯文進行修改並 PR,也可獲得相應獎勵積分。文章開頭的 本文永久連結 即為本文在 GitHub 上的 MarkDown 連結。


掘金翻譯計劃 是一個翻譯優質網際網路技術文章的社群,文章來源為 掘金 上的英文分享文章。內容覆蓋 AndroidiOS前端後端區塊鏈產品設計人工智慧等領域,想要檢視更多優質譯文請持續關注 掘金翻譯計劃官方微博知乎專欄

相關文章