機器學習之決策樹原理和sklearn實踐

Fate0729發表於2019-06-24

1. 場景描述

時間:早上八點,地點:婚介所

‘閨女,我有給你找了個合適的物件,今天要不要見一面?’

‘多大?’ ‘26歲’

‘長的帥嗎?’ ‘還可以,不算太帥’

‘工資高嗎?’ ‘略高於平均水平’

‘會寫程式碼嗎?’ ‘人家是程式設計師,程式碼寫的棒著呢!’

‘好,把他的聯絡方式發過來吧,我抽空見一面’

上面的場景描述摘抄自,是一個典型的決策樹分類問題,通過年齡、長相、工資、是否會程式設計等特徵屬性對介紹物件進行是否約會進行分類

決策樹是一種自上而下,對樣本資料進行樹形分類的過程,由結點和有向邊組成,每個結點(葉結點除外)便是一個特徵或屬性,葉結點表示類別。從頂部根結點開始,所有樣本聚在儀器,經過根結點的劃分,樣本被分到不同的子結點中。再根據子結點的特徵進一步劃分,直至樣本都被分到某一類別(葉子結點)中

2. 決策樹原理

決策樹作為最基礎、最常見的有監督學習模型,常被用於分類問題和迴歸問題,將決策樹應用整合思想可以得到隨機森林、梯度提升決策樹等模型。其主要優點是模型具有可讀性,分類速度快。決策樹的學習通常包括三個步驟:特徵選擇、決策樹的生成和決策樹的修剪,下面對特徵選擇演算法進行描述和區別

2.1 ID3---最大資訊增益

在資訊理論與概率統計中,熵(entropy)是表示隨機變數不確定性的度量,設X是一個取有限個值的隨機變數,其概率分佈為:\[P(X=X_i)=P_i (i = 1,2,...,n)\],則隨機變數X的熵定義為:\[H(X) = -\sum_{i=1}^np_i\log{p_i}\]表示式中的對數以2為底或以e為底,這時熵的單位分別稱作bit或nat,從表示式可以看出X的熵與X的取值無關,所以X的熵也記作\(H(p)\),即\[H(p) = -\sum_{x=1}^np_i\log{p_i}\]熵取值越大,隨機變數的不確定性越大

條件熵:

條件熵H(Y|X)表示在已知隨機變數X的條件下,隨機變數Y的不確定性,隨機變數X給定的條件下隨機變數Y的條件熵定義為X給定條件下Y的條件概率分佈的熵對X的數學期望\[H(Y|X) = \sum_{i=1}^nP(X=X_i)H(Y|X=X_i)\]

資訊增益:\[g(D,A) = H(D) - H(D|A)\]

import pandas as pd
data = {
        '年齡':['老','年輕','年輕','年輕','年輕'],
        '長相':['帥','一般','醜','一般','一般'],
        '工資':['高','中等','高','高','低'],
        '寫程式碼':['不會','會','不會','會','不會'],
        '類別':['不見','見','不見','見','不見']}
frame = pd.DataFrame(data,index=['小A','小B','小C','小D','小L'])
print(frame)
    年齡  長相  工資 寫程式碼  類別
小A   老   帥   高  不會  不見
小B  年輕  一般  中等   會   見
小C  年輕   醜   高  不會  不見
小D  年輕  一般   高   會   見
小L  年輕  一般   低  不會  不見
import math
print(math.log(3/5))
print('H(D):',-3/5 *math.log(3/5,2) - 2/5*math.log(2/5,2))
print('H(D|年齡)',1/5*math.log(1,2)+4/5*(-1/2*math.log(1/2,2)-1/2*math.log(1/2,2)))
print('以同樣的方法計算H(D|長相),H(D|工資),H(D|寫程式碼)')
print('H(D|長相)',0.551)
print('H(D|工資)',0.551)
print('H(D|寫程式碼)',0)
-0.5108256237659907
H(D): 0.9709505944546686
H(D|年齡) 0.8
以同樣的方法計算H(D|長相),H(D|工資),H(D|寫程式碼)
H(D|長相) 0.551
H(D|工資) 0.551
H(D|寫程式碼) 0

計算資訊增益:g(D,寫程式碼)=0.971最大,可以先按照寫程式碼來拆分決策樹

2.2 C4.5---最大資訊增益比

以資訊增益作為劃分訓練資料集的特徵,存在偏向於選擇取值較多的問題,使用資訊增益比可以對對著問題進行校正,這是特徵選擇的另一標準
資訊增益比定義為其資訊增益g(D,A)與訓練資料集D關於特徵A的值的熵\(H_A(D)\)之比:\[g_R(D,A) = \frac{g(D,A)}{H_A(D)}\]

\[H_A(D) = -\sum_{i=1}^n\frac{|D_i|}{|D|}\log\frac{|D_i|}{|D|}\]

拿上面ID3的例子說明:
\[H_年齡(D) = -1/5*math.log(1/5,2)-4/5*math.log(4/5,2)\]

\[g_R(D,年齡) = H_{年齡}(D)/g(D,年齡) = 0.171/0.722 = 0.236 \]

2.3 CART----最大基尼指數(Gini)

Gini描述的是資料的純度,與資訊熵含義類似,分類問題中,假設有K個類,樣本點資料第k類的概率為\(P_k\),則概率分佈的基尼指數定義為:
\[Gini(p) = 1- \sum_{k=1}^Kp_k(1-p_k) = 1 - \sum_{k=1}^Kp_{k}^2\]
對於二分類問題,弱樣本點屬於第1個類的概率是p,則概率分佈的基尼指數為\[Gini(p) = 2p(1-p)\],對於給定的樣本幾何D,其基尼指數為\[Gini(D) = 1 - \sum_{k=1}^K[\frac{|C_k|}{|D|}]^2\]注意這裡\(C_k\)是D種屬於第k類的樣本子集,K是類的個數,如果樣本幾個D根據特徵A是否取某一可能指a被分割成D1和D2兩部分,則在特徵A的條件下,集合D的基尼指數定義為\[Gini(D,A) = \frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)\]
\[Gini(D|年齡=老)=1/5*(1-1)+4/5*[1-(1/2*1/2+1/2*1/2)] = 0.4\]

CART在每一次迭代種選擇基尼指數最小的特徵及其對應的切分點進行分類

2.4 ID3、C4.5與Gini的區別

2.4.1 從樣本型別角度

從樣本型別角度,ID3只能處理離散型變數,而C4.5和CART都處理連續性變數,C4.5處理連續性變數時,通過對資料排序之後找到類別不同的分割線作為切割點,根據切分點把連續型數學轉換為bool型,從而將連續型變數轉換多個取值區間的離散型變數。而對於CART,由於其構建時每次都會對特徵進行二值劃分,因此可以很好地適合連續性變數。

2.4.2 從應用角度

ID3和C4.5只適用於分類任務,而CART既可以用於分類也可以用於迴歸

2.4.3 從實現細節、優化等角度

ID3對樣本特徵缺失值比較敏感,而C4.5和CART可以對缺失值進行不同方式的處理,ID3和C4.5可以在每個結點熵產生出多叉分支,且每個特徵在層級之間不會複用,而CART每個結點只會產生兩個分支,因此會形成一顆二叉樹,且每個特徵可以被重複使用;ID3和C4.5通過剪枝來權衡樹的準確性和泛化能力,而CART直接利用全部資料發現所有可能的樹結構進行對比。

3. 決策樹的剪枝

3.1 為什麼要進行剪枝?

對決策樹進行剪枝是為了防止過擬合

根據決策樹生成演算法通過訓練資料集生成了複雜的決策樹,導致對於測試資料集出現了過擬合現象,為了解決過擬合,就必須考慮決策樹的複雜度,對決策樹進行剪枝,剪掉一些枝葉,提升模型的泛化能力

決策樹的剪枝通常由兩種方法,預剪枝和後剪枝

3.2 預剪枝

預剪枝的核心思想是在樹中結點進行擴充套件之前,先計算當前的劃分是否能帶來模型泛化能力的提升,如果不能,則不再繼續生長子樹。此時可能存在不同類別的樣本同時存於結點中,按照多數投票的原則判斷該結點所屬類別。預剪枝對於何時停止決策樹的生長有以下幾種方法

  • (1)當樹達到一定深度的時候,停止樹的生長
  • (2)當葉結點數到達某個閾值的時候,停止樹的生長
  • (3)當到達結點的樣本數量少於某個閾值的時候,停止樹的生長
  • (4)計算每次分裂對測試集的準確度提升,當小於某個閾值的時候,不再繼續擴充套件

預剪枝思想直接,演算法簡單,效率高特點,適合解決大規模問題。但如何準確地估計何時停止樹的生長,針對不同問題會有很大差別,需要一定的經驗判斷。且預剪枝存在一定的侷限性,有欠擬合的風險

3.3 後剪枝

後剪枝的核心思想是讓演算法生成一顆完全生長的決策樹,然後從底層向上計算是否剪枝。剪枝過程將子樹刪除,用一個葉結點代替,該結點的類別同樣按照多數投票原則進行判斷。同樣地,後剪枝葉可以通過在測試集上的準確率進行判斷,如果剪枝過後的準確率有所提升,則進行剪枝,後剪枝方法通常可以得到泛化能力更強的決策樹,但時間開銷更大

損失函式

\[C_a(T) = \sum_{t=1}^{|T|}N_tH_t(T) + a|T|\]

\(其中|T|為葉結點個數,N_t為結點t的樣本個數,H_t(T)為結點t的資訊熵,a|T|為懲罰項,a>=0\)

\[C_a(T) = \sum_{t=1}^{|T|}N_tH_t(T) + a|T| = -\sum_{t=1}^{|T|}\sum_{k=1}^KN_{tk}\log \frac{N_{tk}}{N_t} + a|T|\]

注意:上面的公式中是\(N_{tk}\log \frac{N_{tk}}{N_t}\),而不是\(\frac{N_{tk}}{N_t} \log \frac{N_{tk}}{N_t}\)

令:\[C_a(T) = C(T) + a|T|\]

\(C(T)\)表示模型對訓練資料的預測誤差,即模型與訓練資料的擬合程度,|T|表示模型複雜度,引數a>=0控制兩者的影響力,較大的a促使選擇較簡單的模型,較小的a促使選擇複雜的模型,a=0意味著只考慮模型與訓練資料的擬合程度,不考慮模型的複雜度

4. 使用sklearn庫為衛星資料集訓練並微調一個決策樹

4.1 需求

  • a.使用make_moons(n_samples=10000,noise=0.4)生成一個衛星資料集
  • b.使用train_test_split()拆分訓練集和測試集
  • c.使用交叉驗證的網格搜尋為DecisionTreeClassifier找到合適的超引數,提示:嘗試max_leaf_nodes的多種值
  • d.使用超引數對整個訓練集進行訓練,並測量模型測試集上的效能

程式碼實現

from sklearn.datasets import make_moons
import numpy as np
import pandas as pd
dataset = make_moons(n_samples=10000,noise=0.4)
print(type(dataset))
print(dataset)
<class 'tuple'>
(array([[ 0.24834453, -0.11160162],
       [-0.34658051, -0.43774172],
       [-0.25009951, -0.80638312],
       ...,
       [ 2.3278198 ,  0.39007769],
       [-0.77964208,  0.68470383],
       [ 0.14500963,  1.35272533]]), array([1, 1, 1, ..., 1, 0, 0], dtype=int64))
dataset_array = np.array(dataset[0])
label_array = np.array(dataset[1])
print(dataset_array.shape,label_array.shape)
(10000, 2) (10000,)
# 拆分資料集
from sklearn.model_selection import train_test_split
x_train,x_test = train_test_split(dataset_array,test_size=0.2,random_state=42)
print(x_train.shape,x_test.shape)
y_train,y_test = train_test_split(label_array,test_size=0.2,random_state=42)
print(y_train.shape,y_test.shape)
(8000, 2) (2000, 2)
(8000,) (2000,)
# 使用交叉驗證的網格搜尋為DecisionTreeClassifier找到合適的超引數
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

decisionTree = DecisionTreeClassifier(criterion='gini')
param_grid = {'max_leaf_nodes': [i for i in range(2,10)]}
gridSearchCV = GridSearchCV(decisionTree,param_grid=param_grid,cv=3,verbose=2)
gridSearchCV.fit(x_train,y_train)
Fitting 3 folds for each of 8 candidates, totalling 24 fits
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  24 out of  24 | elapsed:    0.0s finished

GridSearchCV(cv=3, error_score='raise-deprecating',
       estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best'),
       fit_params=None, iid='warn', n_jobs=None,
       param_grid={'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8, 9]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=2)
print(gridSearchCV.best_params_)
decision_tree = gridSearchCV.best_estimator_
{'max_leaf_nodes': 4}
# 使用測試集對模型進行評估
from sklearn.metrics import accuracy_score
y_prab = gridSearchCV.predict(x_test)
print('accuracy_score:',accuracy_score(y_test,y_prab))
accuracy_score: 0.8455
# 視覺化模型
from sklearn.tree import export_graphviz

export_graphviz(decision_tree,
               out_file='./tree.dot',
               rounded = True,
               filled = True)

生成tree.dot檔案,然後使用dot命令\[dot -Tpng tree.dot -o decisontree_moons.png\]

機器學習之決策樹原理和sklearn實踐

5. 附錄

5.1 sklearn.tree.DecisionTreeClassifier類說明

5.1.1 DecsisionTreeClassifier類引數說明

  • criterion: 特徵選擇方式,string,('gini' or 'entropy'),default='gini'
  • splitter: 每個結點的拆分策略,('best' or 'random'),string,default='best'
  • max_depth: int,default=None
  • min_samples_split: int,float,default=2,分割前所需的最小樣本數
  • min_samples_leaf:
  • min_weight_fraction_leaf:
  • max_features:
  • random_state:
  • max_leaf_nodes:
  • min_impurity_decrease:
  • min_impurity_split:
  • class_weight:
  • presort: bool,default=False,對於小型資料集(幾千個以內)設定presort=True通過對資料預處理來加快訓練,但對於較大訓練集而言,可能會減慢訓練速度

5.1.2 DecisionTreeClassifier屬性說明

  • classes_:
  • feature_importances_:
  • max_features_:
  • n_classes_:
  • n_features_:
  • n_outputs_:
  • tree_:

5.2 GridSearchCV類說明

5.2.1 GridSearchCV引數說明

  • estimator: 估算器,繼承於BaseEstimator
  • param_grid: dict,鍵為引數名,值為該引數需要測試值選項
  • scoring: default=None
  • fit_params:
  • n_jobs: 設定要並行執行的作業數,取值為None或1,None表示1 job,1表示all processors,default=None
  • cv: 交叉驗證的策略數,None或integer,None表示預設3-fold, integer指定“(分層)KFold”中的摺疊數
  • verbose: 輸出日誌型別

5.2.2 GridSearchCV屬性說明

  • cv_results_: dict of numpy(masked) ndarray
  • best_estimator_:
  • best_score_: Mean cross-validated score of the best_estimator
  • best_params_:
  • best_index_: int,The index (of the ``cv_results_`` arrays) which corresponds to the best candidate parameter setting
  • scorer_:
  • n_splits_: The number of cross-validation splits (folds/iterations)
  • refit_time: float

參考資料:

  • (1)
  • (2)
  • (3)李航

相關文章