Machine Learning: How to Build a Simple Decision Tree

weixin_33895657, posted 2017-05-10

Dataset download: http://archive.ics.uci.edu/ml/datasets/Adult

While writing this algorithm I ran into a puzzle:
When constructing the tree, the two snippets below for looping over the branches both run successfully, yet they build slightly different trees.
With the first way of iterating over the branches, traversal always starts from the right branch, even after changing branch_dict to {'right': right_split, 'left': left_split}.
With the second way, traversal is as expected (left branch first, then right branch). I haven't found the cause yet; if you know, please leave a comment.

[Figure: the two branch-traversal snippets]
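A plausible explanation (my guess; the original post leaves the question open): before CPython 3.6 / Python 3.7, a plain dict did not guarantee insertion order, so "for branch in branch_dict:" may visit the keys in either order no matter how the literal is written. Iterating over an explicit sequence of keys, using the names from the id3 implementation shown later, sidesteps the problem:

# Deterministic branch order: loop over an explicit tuple of keys instead of
# relying on dict iteration order (only guaranteed from Python 3.7 onward).
for branch in ('left', 'right'):
    tree[branch] = {}
    id3(branch_dict[branch], columns, target, tree[branch])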

The full code walkthrough follows:

Reading in the data

import pandas as pd
import numpy as np   # used below for np.bincount and np.median

columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num',
           'marital_status', 'occupation', 'relationship', 'race', 'sex',
           'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
           'high_income']
data = pd.read_table('./data/income.data', delimiter=',', names=columns)
data.head()
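One caveat worth noting (my addition, not from the original post): in the UCI file, fields are separated by a comma followed by a space, so string values arrive with a leading space. Passing skipinitialspace=True to the reader strips it:

data = pd.read_csv('./data/income.data', names=columns, skipinitialspace=True)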

Before building the decision tree, we need to convert the categorical columns in the dataset to numeric values. pandas.Categorical converts a string-valued column to the Categorical type; after the conversion, each category in the column is automatically mapped to an integer code.

cat_columns = ['workclass', 'education', 'marital_status', 'occupation',
               'relationship', 'race', 'sex', 'native_country', 'high_income']
for name in cat_columns:
    col = pd.Categorical(data[name])   # Categorical.from_array is deprecated; the constructor does the same
    data[name] = col.codes
data.head()
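To see what the mapping looks like, inspect one column before the loop above overwrites it (the column choice here is just an illustration):

col = pd.Categorical(data['workclass'])
print(col.categories)   # the distinct string values
print(col.codes[:5])    # the integer codes assigned to the first five rows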

Computing entropy and information gain:

Entropy measures how mixed the class labels in a set are:

    H(T) = - Σ_i p_i · log2(p_i)

where p_i is the proportion of rows belonging to class i.

Information gain

    IG(T, A) = H(T) - Σ_v (|T_v| / |T|) · H(T_v)

where splitting on attribute A partitions T into subsets T_v; here the split is binary, at the median of A.
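A quick worked example: for a parent target of [0, 0, 1, 1], H(T) = -(0.5·log2(0.5) + 0.5·log2(0.5)) = 1 bit. If a split sends both 0s left and both 1s right, each child has entropy 0, so IG = 1 - (0.5·0 + 0.5·0) = 1 bit, the maximum possible.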
def calc_entropy(target):
    # Entropy of a 0/1 (or small-integer) label array.
    counts = np.bincount(target)
    probabilities = counts / len(target)
    # Drop zero-probability classes: 0 * log2(0) would produce NaN.
    probabilities = probabilities[probabilities > 0]
    entropys = probabilities * np.log2(probabilities)
    return -sum(entropys)

def calc_information_gain(data, split_name, target):
    # Gain from splitting `data` on the median of column `split_name`.
    entropy = calc_entropy(data[target])
    median = np.median(data[split_name])
    left_split = data[data[split_name] <= median]
    right_split = data[data[split_name] > median]

    to_subtract = 0
    for subset in [left_split, right_split]:
        prob = subset.shape[0] / data.shape[0]
        to_subtract += prob * calc_entropy(subset[target])
    return entropy - to_subtract

# Compute the information gain of every candidate column and return the best
# split attribute (the one with the largest gain).
def find_best_column(data, columns, target):
    information_gains = []
    for name in columns:
        # use the target parameter rather than a hard-coded column name
        information_gains.append(calc_information_gain(data, name, target))
    information_index = information_gains.index(max(information_gains))
    best_column = columns[information_index]
    return best_column
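The worked example above doubles as a quick sanity check in code (the frame here is made up for illustration):

toy = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [0, 0, 1, 1]})
print(calc_entropy(toy['y']))                 # 1.0: a balanced binary target
print(calc_information_gain(toy, 'x', 'y'))   # 1.0: the median split (x <= 2.5) separates y perfectly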
ID3 algorithm with tree storage:

To store the tree while building it, each node can be saved as a dictionary containing the following keys:

  • left/right: the left and right child nodes
  • column: the best split attribute
  • median: the median of the split attribute
  • number: the node id
  • label: if the node is a leaf, it contains only the label key (value 0 or 1) and the number key
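For concreteness, a tiny tree in this representation might look like the following (all values made up for illustration):

{'number': 1, 'column': 'age', 'median': 37.0,
 'left':  {'number': 2, 'label': 0},
 'right': {'number': 3, 'label': 1}}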
The pseudocode is as follows:
def id3(data, columns, target, tree)
    1 Create a node for the tree
    2 Number the node
    3 If all of the values of the target attribute are 1, assign 1 to the label key in tree
    4 If all of the values of the target attribute are 0, assign 0 to the label key in tree
    5 Using information gain, find A, the column that splits the data best
    6 Find the median value in column A
    7 Assign the column and median keys in tree
    8 Split A into values less than or equal to the median (0), and values above the median (1)
    9 For each possible value (0 or 1), vi, of A,
   10    Add a new tree branch below Root that corresponds to rows of data where A = vi
   11    Let Examples(vi) be the subset of examples that have the value vi for A
   12    Create a new key with the name corresponding to the side of the split (0=left, 1=right).  The value of this key should be an empty dictionary.
   13    Below this new branch, add the subtree id3(data[A==vi], columns, target, tree[split_side])
   14 Return Root

Implementation:

tree = {}
nodes = []  # Key point: an int can't be incremented across recursive calls
            # (rebinding creates a new local), so we use a list's length as a counter.
def id3(data, columns, target, tree):

    nodes.append(len(nodes) + 1)
    tree['number'] = nodes[-1]
    unique_targets = pd.unique(data[target])
    if len(unique_targets) == 1:
        tree['label'] = unique_targets[0]
        return  # don't forget to return here

    # unique_targets contains both 0 and 1, so we need to split:
    best_column = find_best_column(data, columns, target)
    median = np.median(data[best_column])
    tree['column'] = best_column  # split key
    tree['median'] = median       # median key

    left_split = data[data[best_column] <= median]
    right_split = data[data[best_column] > median]
    branch_dict = {'left': left_split, 'right': right_split}
    for branch in branch_dict:  # see the note on dict iteration order near the top
        tree[branch] = {}
        id3(branch_dict[branch], columns, target, tree[branch])

id3(data, ["age", "marital_status"], "high_income", tree)
print(tree)
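As an aside (my suggestion, not from the original post): the list trick works because mutating a list is visible across recursive calls, whereas rebinding an int is not. A tidier alternative is a shared counter from itertools:

import itertools

node_counter = itertools.count(1)   # yields 1, 2, 3, ... on successive next() calls
# inside id3, instead of the nodes list:
# tree['number'] = next(node_counter)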

The result is a deeply nested dictionary:

[Figure: raw output of print(tree)]

To make the structure of the generated tree easier to inspect, we can write a node-printing function that prints the tree with indentation:

def print_with_depth(string, depth):
    # Indent `string` by four spaces per level of depth.
    prefix = "    " * depth
    print("{0}{1}".format(prefix, string))

def print_node(tree, depth):
    if 'label' in tree:
        print_with_depth('Leaf: label {0}'.format(tree['label']), depth)
        return
    print_with_depth('{0} > {1}'.format(tree['column'], tree['median']), depth)

    branches = [tree["left"], tree["right"]]

    for branch in branches:
        print_node(branch, depth + 1)

print_node(tree, 0)

Output:

[Figure: the indented tree printout]

Implementing prediction:

# Prediction function: walk the tree until a leaf is reached.
def predict(tree, row):
    if 'label' in tree:
        return tree['label']
    column = tree['column']
    median = tree['median']

    if row[column] <= median:   # note: row[column], not row['columns']
        return predict(tree['left'], row)
    else:
        return predict(tree['right'], row)

print(predict(tree, data.iloc[0]))

predictions = data.apply(lambda x: predict(tree, x), axis=1)
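As a rough sanity check (evaluated on the training data itself, so the number is optimistic), compare the predictions against the true labels:

accuracy = (predictions == data['high_income']).mean()
print(accuracy)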

Done.
