Decision Tree
①Aggregation Model
回顧上一篇文章講到的聚合模型,三個臭皮匠頂一個諸葛亮。於是出現了blending,bagging,boost,stacking。blending有uniform和non-uniform,stacking是屬於條件類的,而boost裡面的Adaboost是邊學習邊做linear,bagging也是屬於邊學習邊做uniform的。Decision Tree就是屬於邊做學習然後按照條件分的一種。如下圖,aggregation model就是是補全了:
②Decision Tree Hypothesis
決策樹是一種很傳統的演算法,出現的很早,例如下面,按照下班時間,是否約會,提交截止時間進行判斷,和人的處理方式很像:
上面的菱形就像是很簡單的分割平面,而箭頭就是判斷過程,其實就是學習過程,最後的Y和N就是分出來的結果。可以對應到下面的式子:
最後那些小小的Y,N就是g(x),和之前的SVM他們都不太一樣,這裡的g(x)通常就是一個常數了,也叫base hypothesis;箭頭就是q(x)判斷條件,紅色就是找到了最好split method的地方。
從另一個方面來看決策樹:
和上面理解是一樣的。
Strengths and Weaknesses
優點:
模型直觀,便於理解,應用很廣泛
簡單,容易實現。
訓練和預測的時候,時間短預測準確率高
缺點
缺少足夠的理論支援,後面給出的VC dimension沒有什麼太完備的道理。
對於找到合適的樹要花額外的時間。
決策樹代表性的演演算法比較少
③Decision Tree Algorithm
根據上面的公式,基本演算法:
按照決策樹執行流程,可以分成四個部分:
首先學習設定劃分不同分支的標準和條件是什麼;接著將整體資料集D根據分支個數C和條件,劃為不同分支下的子集Dc;然後對每個分支下的Dc進行訓練,得到相應的機器學習模型Gc;最後將所有分支下的Gc合併到一起,組成大矩G(x)。但值得注意的是,這種遞迴的形式需要終止條件,否則程式將一直進行下去。當滿足遞迴的終止條件之後,將會返回基本的hypothesis gt(x)。
所以,包含了四個基本演算法選擇:
分支個數
分支條件
終止條件
基本演算法
常用決策樹演算法模型——CART
CART演算法對決策樹演算法追加了一些限制:
①c = 2,分支的個數要等於2,和二叉樹有點想。
②本著g(x)simplify的原則,g(x)規定他就是一個常數,也就是類別。
③按照Ein最小化的原則每一次選擇condition。
其實決策樹的分類有點像Adaboost的stump分類。但是Adaboost的stump僅僅是按照準確率來了,而decision tree的標準是purity,純淨度。意思就是熵了。purifying的核心思想就是每次切割都儘可能讓左子樹和右子樹中同類樣本佔得比例最大或者yn都很接近(regression),即錯誤率最小。比如說classifiacation問題中,如果左子樹全是正樣本,右子樹全是負樣本,那麼它的純淨度就很大,說明該分支效果很好。
所以主要問題就變成了如何尋找純淨度最好的問題了。
④purifying
純淨度其實就是熵了。熵是代表混亂程度的。幾個比較常見的演算法:ID3,ID4.5,gini係數。
ID3
以資訊理論為基礎,以資訊熵和資訊增益為衡量標準,從而實現對資料的歸納分類。
資訊增益,就是指split前後熵的變化,選擇最好的一個,也就是說由於使用這個屬性分割樣例而導致的期望熵降低。資訊增益就是原有資訊熵與屬性劃分後資訊熵(需要對劃分後的資訊熵取期望值)的差值。
但是他的缺點也很明顯:
1.沒有剪枝過程,為了去除過渡資料匹配的問題,可通過裁剪合併相鄰的無法產生大量資訊增益的葉子節點。因為選擇的已經是最好的了,如果合併了肯定不夠之前的好。
2.資訊增益的方法偏向選擇具有大量值的屬性,也就是說某個屬性特徵索取的不同值越多,那麼越有可能作為分裂屬性,這樣是不合理的。比如前面的ID編號,1/N再來個log很小的。
3.只可以處理離散分佈的資料特徵。這個很明顯了,如果是連續型資料,很難分的。
基於以上缺點又改進了一下。
ID4.5
改進就是ID4.5了,這個就不是資訊增益了,是資訊增益率。
資訊增益率是資訊增益與資訊熵的比例
這樣的改進其實就是使得離散化可以連續化而已,二分就好了。
優點:
1.面對資料遺漏和輸入欄位很多的問題時非常穩健。
2.通常不需要很長的訓練次數進行估計。工作原理是基於產生最大資訊增益的欄位逐級分割樣本。
3.比一些其他型別的模型易於理解,模型推出的規則有非常直觀的解釋。
4.允許進行多次多於兩個子組的分割。目標欄位必須為分類欄位。
CART
Cart演算法裡面用的是gini係數,但是還是有必要說一下decision tree做擬合的時候Ein要怎麼optimal。
regression
對於regression問題,首先想到的肯定是均方差了:
y杆就是yn的平均。
classification
對於分類:
y表示類別最多的。
以上都是借鑑前面algorithm的思想推導的,現在回到純度。想要purity最小,那麼就是y要多了,最好全部都是了,所以classification error:
上面的只是考慮了分支最大的,我們需要把所有的都考慮進去,於是:
gini係數就出來了:
可以看到gini係數和熵差不了多少,一定程度上可以代表熵。
對於CART的Teminal condition,自然就是兩個條件:1.首先是yn只有一個種類,分不了了。2.其次就是Xn都是一樣的不能再分。
⑤Decision Tree Heuristics in CART
基本流程:
可以看到CART演算法在處理binary classification和regression問題時非常簡單實用,而且,處理muti-class classification問題也十分容易。
但是要注意一個問題,既然有錯誤就分,那麼到最後肯定是一個二分完全樹,Ein一定是0,這樣是有過擬合的。對於overfit,要引入的就是過擬合:
既然是過擬合了,這棵樹不要這麼大就行了,於是進行修剪,pruning,剪枝操作。比如,總共是10片葉子,我們取掉1片,剩下9片,9種情況,我們比較這9種情況哪種好。
這裡其實就是剛剛說的decision tree理論不是特別的完善,事實上NumberOfLeaves ≈ Ω其實我們在實踐中得到的。因為葉子越多複雜度越大。所以就直接把葉子數量當做是複雜度Ω了。
在決策樹中預測中,還會遇到一種問題,就是當某些特徵缺失的時候,沒有辦法進行切割和分支選擇。一種常用的方法就是surrogate branch,即尋找與該特徵相似的替代feature。如何確定是相似的feature呢?做法是在決策樹訓練的時候,找出與該特徵相似的feature,如果替代的feature與原feature切割的方式和結果是類似的,那麼就表明二者是相似的,就把該替代的feature也儲存下來。當預測時遇到原feature缺失的情況,就用替代feature進行分支判斷和選擇。
⑥Decision Tree in action
貌似和Adaboost很像啊!
最後在總結一下:
⑦程式碼實現Decision Tree
包括建立樹,預測,視覺化樹,這篇東西內容不多,程式碼講解多。
首先引入一個計算gini係數:
def cal_gini(data):
'''calculate the gini index
input:data(list)
output:gini(float)
'''
total_sample = len(data)
if total_sample == 0:
return 0
label_count = label_uniqueness(data)
gini = 0
for label in label_count:
gini = gini + pow(label_count[label] , 2)
gini = 1 - float(gini) / pow(total_sample , 2)
return gini
pass
傳進的是一個list,計算這個list裡面label數量,然後統計gini係數返回。
還有一個分別計算類別數量的函式,剛剛的gini係數用到的:
def label_uniqueness(data):
'''Counting the number of defferent labels in the dataset
input:dataset
output:Number of labels
'''
label_uniq = {}
for x in data:
label = x[len(x) - 1]
if label not in label_uniq:
label_uniq[label] = 0
label_uniq[label] += 1
return label_uniq
pass
這個就是tool檔案裡面的。
建立節點node:
class node:
'''Tree node
'''
def __init__(self , fea = -1, value = None, results = None, right = None, left = None):
'''
initialization function
:param fea:column index value
:param value:split value
:param results:The class belongs to
:param right:right side
:param left:left side
'''
self.fea = fea
self.value = value
self.results = results
self.right = right
self.left = left
pass
fea就是當前分割的維度,value就是分割的值,result就是label,right右子樹,left左子樹。
接下來就是主要建立樹的類了:
class decision_tree(object):
def build_tree(self,data):
'''Create decision tree
input:data
output:root
'''
if len(data) == 0:
return node()
currentGini = tool.cal_gini(data)
bestGain = 0.0
bestCriterria = None # store the optimal cutting point
bestSets = None # store two datasets which have been splited
feature_num = len(data[0]) - 1 # Number of features
for fea in range(0 , feature_num):
feature_values = {}
for sample in data:
feature_values[sample[fea]] = 1 # store the value in the demension fea possibly
for value in feature_values.keys():
(set_first, set_second) = self.split_tree(data, fea, value)
nowGini = float(len(set_first) * tool.cal_gini(set_first) + len(set_second) * tool.cal_gini(set_second)) / len(data)
gain = currentGini - nowGini
if gain > bestGain and len(set_first) > 0 and len(set_second) > 0:
bestGain = gain
bestCriterria = (fea , value)
bestSets = (set_first , set_second)
pass
if bestGain > 0:
right = self.build_tree(bestSets[0])
left = self.build_tree(bestSets[1])
return node(fea = bestCriterria[0], value = bestCriterria[1], right = right, left = left)
else:
return node(results=tool.label_uniqueness(data))
def split_tree(self , data , fea , value):
'''split the dataset according demension and value
input:data
output:two data
'''
set_first = []
set_second = []
for x in data:
if x[fea] >= value:
set_first.append(x)
else:
set_second.append(x)
return (set_first, set_second)
pass
def predict(self, sample, tree):
'''prediction
input:sample, the tree which we have been built
output:label
'''
if tree.results != None:
return tree.results
else:
val_sample = sample[tree.fea]
branch = None
if val_sample >= tree.value:
branch = tree.right
else:
branch = tree.left
return self.predict(sample, branch)
def predcit_samples(self, samples, tree):
predictions = []
for sample in samples:
predictions.append(self.predict(sample, tree))
return predictions
pass
其實很簡單,就是按照feature和value分類。忘了這個是前向還是後向了,我是看那個二叉樹跟著搞的,大一的時候學過,過了半年差不多忘光了。
看看預測效果吧!
使用的資料還是iris資料集,視覺化還得降維,麻煩,於是就是視覺化樹了,發現更麻煩:
if __name__ == '__main__':
print('load_data......')
dataSet = load_data()
data = dataSet.data
target = dataSet.target
dataframe = pd.DataFrame(data = data, dtype = np.float32)
dataframe.insert(4, 'label', target)
dataMat = np.mat(dataframe)
'''test and train
'''
X_train, X_test, y_train, y_test = train_test_split(dataMat[:, 0:-1], dataMat[:, -1], test_size=0.3, random_state=0)
data_train = np.hstack((X_train, y_train))
data_train = data_train.tolist()
X_test = X_test.tolist()
tree = decisionTree.decision_tree()
tree_root = tree.build_tree(data_train)
predictions = tree.predcit_samples(X_test, tree_root)
pres = []
for i in predictions:
pres.append(list(i.keys()))
y_test = y_test.tolist()
accuracy = 0
for i in range(len(y_test)):
if y_test[i] == pres[i]:
accuracy += 1
print('Accuracy : ', accuracy / len(y_test))
準確率還是蠻高的。
首先要求樹的葉子數:
一樣是遞迴。
def getNumLeafs(myTree):
if myTree == None:
return 0
elif myTree.right == None and myTree.left == None:
return 1
else:
return getNumLeafs(myTree.right) + getNumLeafs(myTree.left)
然後是求深度:
def getDepth(myTree):
if myTree == None:
return 0
right = getDepth(myTree.right)
left = getDepth(myTree.left)
return max(right+1, left+1)
之後就是畫節點了,求深度和葉子數只是想著可以按照深度把樹畫的分開點。
還有一個裝parent節點座標的:
class TreeNode(object):
def __init__(self, x, y, parentX = None, parentY = None):
self.x = x
self.y = y
self.parentX = parentX
self.parentY = parentY
pass
最後就是主要的畫圖了:
def drawNode(x, y ,parent,color, marker, myTree, position):
if myTree.results == None or len(list(myTree.results.keys())) > 1:
plt.scatter(x, y, c=color, marker=marker, s=200)
if myTree.right == None and myTree.left == None:
results = list(myTree.results.keys())
plt.annotate(s = 'label == ' + str(results[0]), xy=(x - 15, y))
if results[0] == 0.0:
plt.annotate(s='label == 0.0', xy=(x , y))
plt.scatter(x, y, c='orange', marker='H', s=100)
if results[0] == 1.0:
plt.scatter(x, y, c='pink', marker='8', s=100)
if results[0] == 2.0:
plt.scatter(x, y, c='r', marker='+', s=100)
if myTree.value != None and myTree.fea != None:
po = 5
if position == 'right':
plt.annotate(s = 'dimension' + str(myTree.fea) + '>' + str(round(myTree.value, 2)), xy = (x-25 - po, y))
else:
plt.annotate(s='dimension' + str(myTree.fea) + '>' + str(round(myTree.value, 2)), xy=(x - 25 + po, y))
if parent != None:
plt.plot([x, parent.x], [y, parent.y], color = 'gray', alpha = 0.5)
def draw(myTree, parent = None, x = 100, y = 100, color = 'r', marker = '^', position = None):
NumberLeaf = getNumLeafs(myTree)
Depth = getDepth(myTree)
delta = (NumberLeaf+Depth)
drawNode(x, y, parent, color, marker, myTree,position)
if myTree.right != None:
draw(myTree.right, parent=TreeNode(x, y) ,x=x+5*delta, y=y-5-delta,color='b', marker='x', position='right')
if myTree.left != None:
draw(myTree.left,parent=TreeNode(x, y) ,x=x-5*delta, y=y-2-delta, color='g', marker='o', position='left')
pass
加上這句 plt.annotate(s='label == 0.0', xy=(x , y))是因為那個註釋死活畫不出來,應該是擋住了。主要還是draw函式,drawNode只是畫而已,判斷都是為了加註釋的,來看看效果圖:
如果當時學資料結構用的是python多好!
所有程式碼在GitHub上:
https://github.com/GreenArrow2017/MachineLearning/tree/master/MachineLearning/DecisionTree
相關文章
- Decision tree——決策樹
- Machine Learning (10) - Decision TreeMac
- 決策樹(Decision Tree)
- 大資料————決策樹(decision tree)大資料
- 分類演算法-決策樹 Decision Tree演算法
- Machine Learning (11) - 關於 Decision Tree 的小練習Mac
- 機器學習之 決策樹(Decision Tree)python實現機器學習Python
- 人工智慧之機器學習基礎——決策樹(Decision Tree)人工智慧機器學習
- 林軒田機器學習技法課程學習筆記9 — Decision Tree機器學習筆記
- 林軒田機器學習技法課程學習筆記11 — Gradient Boosted Decision Tree機器學習筆記
- 機器學習演算法系列(十七)-決策樹學習演算法(Decision Tree Learning Algorithm)機器學習演算法Go
- SAP QM Auto Usage Decision
- MSDS 490: Healthcare Analytics and Decision Making
- 決策支援系統(Decision Support System,DSS)
- tree
- SAP UI5 Decision Table 的特性介紹UI
- Tree Compass
- A - Distance in Tree
- DSU on Tree
- Rebuild TreeRebuild
- 01 Tree
- 【MySQL(1)| B-tree和B+tree】MySql
- 多路查詢樹:B-tree/b+tree
- LeetCode#110.Balanced Binary Tree(Tree/Height/DFS/Recursion)LeetCode
- Root of AVL Tree
- Tree – Information TheoryORM
- mvn dependency:tree
- Traversals of binary tree
- Circular Spanning Tree
- B-tree
- B+tree
- segment tree beats
- tree-shaking
- Walking the File Tree
- Causal Inference理論學習篇-Tree Based-Causal Tree
- LeetCode C++ 968. Binary Tree Cameras【Tree/DFS】困難LeetCodeC++
- Trie tree實踐
- 100-Same Tree