之前有文章介紹過決策樹(ID3)。簡單回顧一下:ID3每次選取最佳特徵來分割資料,這個最佳特徵的判斷原則是通過資訊增益來實現的。按照某種特徵切分資料後,該特徵在以後切分資料集時就不再使用,因此存在切分過於迅速的問題。ID3演算法還不能處理連續性特徵。 下面簡單介紹一下其他演算法:
CART 分類迴歸樹
CART是Classification And Regerssion Trees的縮寫,既能處理分類任務也能做迴歸任務。
CART樹的典型代表時二叉樹,根據不同的條件將分類。CART樹構建演算法 與ID3決策樹的構建方法類似,直接給出CART樹的構建過程。首先與ID3類似採用字典樹的資料結構,包含以下4中元素:
- 待切分的特徵
- 待切分的特徵值
- 右子樹。當不再需要切分的時候,也可以是單個值
- 左子樹,類似右子樹。
過程如下:
- 尋找最合適的分割特徵
- 如果不能分割資料集,該資料集作為一個葉子節點。
- 對資料集進行二分割
- 對分割的資料集1重複1, 2,3 步,建立右子樹。
- 對分割的資料集2重複1, 2,3 步,建立左子樹。
明顯的遞迴演算法。
通過資料過濾的方式分割資料集,返回兩個子集。
def splitDatas(rows, value, column):
# 根據條件分離資料集(splitDatas by value, column)
# return 2 part(list1, list2)
list1 = []
list2 = []
if isinstance(value, int) or isinstance(value, float):
for row in rows:
if row[column] >= value:
list1.append(row)
else:
list2.append(row)
else:
for row in rows:
if row[column] == value:
list1.append(row)
else:
list2.append(row)
return list1, list2
複製程式碼
劃分資料點
建立二進位制決策樹本質上就是遞迴劃分輸入空間的過程。
程式碼如下:
# gini()
def gini(rows):
# 計算gini的值(Calculate GINI)
length = len(rows)
results = calculateDiffCount(rows)
imp = 0.0
for i in results:
imp += results[i] / length * results[i] / length
return 1 - imp
複製程式碼
構建樹
def buildDecisionTree(rows, evaluationFunction=gini):
# 遞迴建立決策樹, 當gain=0,時停止迴歸
# build decision tree bu recursive function
# stop recursive function when gain = 0
# return tree
currentGain = evaluationFunction(rows)
column_lenght = len(rows[0])
rows_length = len(rows)
best_gain = 0.0
best_value = None
best_set = None
# choose the best gain
for col in range(column_lenght - 1):
col_value_set = set([x[col] for x in rows])
for value in col_value_set:
list1, list2 = splitDatas(rows, value, col)
p = len(list1) / rows_length
gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
if gain > best_gain:
best_gain = gain
best_value = (col, value)
best_set = (list1, list2)
dcY = {'impurity': '%.3f' % currentGain, 'sample': '%d' % rows_length}
#
# stop or not stop
if best_gain > 0:
trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
return Tree(col=best_value[0], value = best_value[1], trueBranch = trueBranch, falseBranch=falseBranch, summary=dcY)
else:
return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)
複製程式碼
上面程式碼的功能是先找到資料集切分的最佳位置和分割資料集。之後通過遞迴構建出上面圖片的整棵樹。
剪枝
在決策樹的學習中,有時會造成決策樹分支過多,這是就需要去掉一些分支,降低過度擬合。通過決策樹的複雜度來避免過度擬合的過程稱為剪枝。 後剪枝需要從訓練集生成一棵完整的決策樹,然後自底向上對非葉子節點進行考察。利用測試集判斷是否將該節點對應的子樹替換成葉節點。 程式碼如下:
def prune(tree, miniGain, evaluationFunction=gini):
# 剪枝 when gain < mini Gain, 合併(merge the trueBranch and falseBranch)
if tree.trueBranch.results == None:
prune(tree.trueBranch, miniGain, evaluationFunction)
if tree.falseBranch.results == None:
prune(tree.falseBranch, miniGain, evaluationFunction)
if tree.trueBranch.results != None and tree.falseBranch.results != None:
len1 = len(tree.trueBranch.data)
len2 = len(tree.falseBranch.data)
len3 = len(tree.trueBranch.data + tree.falseBranch.data)
p = float(len1) / (len1 + len2)
gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data)
if gain < miniGain:
tree.data = tree.trueBranch.data + tree.falseBranch.data
tree.results = calculateDiffCount(tree.data)
tree.trueBranch = None
tree.falseBranch = None
複製程式碼
當節點的gain小於給定的 mini Gain時則合併這兩個節點.。
最後是構建樹的程式碼:
if __name__ == '__main__':
dataSet = loadCSV()
decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
prune(decisionTree, 0.4)
test_data = [5.9,3,4.2,1.5]
r = classify(test_data, decisionTree)
print(r)
複製程式碼
可以列印decisionTree可以構建出如如上的圖片中的決策樹。 後面找一組資料測試看能否得到正確的分類。
完整程式碼和資料集請檢視:
github:CART
總結:
- CART決策樹
- 分割資料集
- 遞迴建立樹
參考文章:
CART分類迴歸樹分析與python實現
CART決策樹(Decision Tree)的Python原始碼實現