決策樹之CART演算法
在之前介紹過決策樹的ID3演算法實現,今天主要來介紹決策樹的另一種實現,即CART演算法。
Contents
1. CART演算法的認識
2. CART演算法的原理
3. CART演算法的實現
1. CART演算法的認識
Classification And Regression Tree,即分類迴歸樹演算法,簡稱CART演算法,它是決策樹的一種實現,通
常決策樹主要有三種實現,分別是ID3演算法,CART演算法和C4.5演算法。
CART演算法是一種二分遞迴分割技術,把當前樣本劃分為兩個子樣本,使得生成的每個非葉子結點都有兩個分支,
因此CART演算法生成的決策樹是結構簡潔的二叉樹。由於CART演算法構成的是一個二叉樹,它在每一步的決策時只能
是“是”或者“否”,即使一個feature有多個取值,也是把資料分為兩部分。在CART演算法中主要分為兩個步驟
(1)將樣本遞迴劃分進行建樹過程
(2)用驗證資料進行剪枝
2. CART演算法的原理
上面說到了CART演算法分為兩個過程,其中第一個過程進行遞迴建立二叉樹,那麼它是如何進行劃分的 ?
設代表單個樣本的個屬性,表示所屬類別。CART演算法通過遞迴的方式將維的空間劃分為不重
疊的矩形。劃分步驟大致如下
(1)選一個自變數,再選取的一個值,把維空間劃分為兩部分,一部分的所有點都滿足,
另一部分的所有點都滿足,對非連續變數來說屬性值的取值只有兩個,即等於該值或不等於該值。
(2)遞迴處理,將上面得到的兩部分按步驟(1)重新選取一個屬性繼續劃分,直到把整個維空間都劃分完。
在劃分時候有一個問題,它是按照什麼標準來劃分的 ? 對於一個變數屬性來說,它的劃分點是一對連續變數屬
性值的中點。假設個樣本的集合一個屬性有個連續的值,那麼則會有個分裂點,每個分裂點為相鄰
兩個連續值的均值。每個屬性的劃分按照能減少的雜質的量來進行排序,而雜質的減少量定義為劃分前的雜質減
去劃分後的每個節點的雜質量劃分所佔比率之和。而雜質度量方法常用Gini指標,假設一個樣本共有類,那麼
一個節點的Gini不純度可定義為
其中表示屬於類的概率,當Gini(A)=0時,所有樣本屬於同類,所有類在節點中以等概率出現時,Gini(A)
最大化,此時。
有了上述理論基礎,實際的遞迴劃分過程是這樣的:如果當前節點的所有樣本都不屬於同一類或者只剩下一個樣
本,那麼此節點為非葉子節點,所以會嘗試樣本的每個屬性以及每個屬性對應的分裂點,嘗試找到雜質變數最大
的一個劃分,該屬性劃分的子樹即為最優分支。
下面舉個簡單的例子,如下圖
在上述圖中,屬性有3個,分別是有房情況,婚姻狀況和年收入,其中有房情況和婚姻狀況是離散的取值,而年
收入是連續的取值。拖欠貸款者屬於分類的結果。
假設現在來看有房情況這個屬性,那麼按照它劃分後的Gini指數計算如下
而對於婚姻狀況屬性來說,它的取值有3種,按照每種屬性值分裂後Gini指標計算如下
最後還有一個取值連續的屬性,年收入,它的取值是連續的,那麼連續的取值採用分裂點進行分裂。如下
根據這樣的分裂規則CART演算法就能完成建樹過程。
建樹完成後就進行第二步了,即根據驗證資料進行剪枝。在CART樹的建樹過程中,可能存在Overfitting,許多
分支中反映的是資料中的異常,這樣的決策樹對分類的準確性不高,那麼需要檢測並減去這些不可靠的分支。決策
樹常用的剪枝有事前剪枝和事後剪枝,CART演算法採用事後剪枝,具體方法為代價複雜性剪枝法。可參考如下鏈接
剪枝參考:http://www.cnblogs.com/zhangchaoyang/articles/2709922.html
3. CART演算法的實現
以下程式碼是網上找的CART演算法的MATLAB實現。
CART
function D = CART(train_features, train_targets, params, region)
% Classify using classification and regression trees
% Inputs:
% features - Train features
% targets - Train targets
% params - [Impurity type, Percentage of incorrectly assigned samples at a node]
% Impurity can be: Entropy, Variance (or Gini), or Missclassification
% region - Decision region vector: [-x x -y y number_of_points]
%
% Outputs
% D - Decision sufrace
[Ni, M] = size(train_features);
%Get parameters
[split_type, inc_node] = process_params(params);
%For the decision region
N = region(5);
mx = ones(N,1) * linspace (region(1),region(2),N);
my = linspace (region(3),region(4),N)' * ones(1,N);
flatxy = [mx(:), my(:)]';
%Preprocessing
[f, t, UW, m] = PCA(train_features, train_targets, Ni, region);
train_features = UW * (train_features - m*ones(1,M));;
flatxy = UW * (flatxy - m*ones(1,N^2));;
%Build the tree recursively
disp('Building tree')
tree = make_tree(train_features, train_targets, M, split_type, inc_node, region);
%Make the decision region according to the tree
disp('Building decision surface using the tree')
targets = use_tree(flatxy, 1:N^2, tree);
D = reshape(targets,N,N);
%END
function targets = use_tree(features, indices, tree)
%Classify recursively using a tree
if isnumeric(tree.Raction)
%Reached an end node
targets = zeros(1,size(features,2));
targets(indices) = tree.Raction(1);
else
%Reached a branching, so:
%Find who goes where
in_right = indices(find(eval(tree.Raction)));
in_left = indices(find(eval(tree.Laction)));
Ltargets = use_tree(features, in_left, tree.left);
Rtargets = use_tree(features, in_right, tree.right);
targets = Ltargets + Rtargets;
end
%END use_tree
function tree = make_tree(features, targets, Dlength, split_type, inc_node, region)
%Build a tree recursively
if (length(unique(targets)) == 1),
%There is only one type of targets, and this generates a warning, so deal with it separately
tree.right = [];
tree.left = [];
tree.Raction = targets(1);
tree.Laction = targets(1);
break
end
[Ni, M] = size(features);
Nt = unique(targets);
N = hist(targets, Nt);
if ((sum(N < Dlength*inc_node) == length(Nt) - 1) | (M == 1)),
%No further splitting is neccessary
tree.right = [];
tree.left = [];
if (length(Nt) ~= 1),
MLlabel = find(N == max(N));
else
MLlabel = 1;
end
tree.Raction = Nt(MLlabel);
tree.Laction = Nt(MLlabel);
else
%Split the node according to the splitting criterion
deltaI = zeros(1,Ni);
split_point = zeros(1,Ni);
op = optimset('Display', 'off');
for i = 1:Ni,
split_point(i) = fminbnd('CARTfunctions', region(i*2-1), region(i*2), op, features, targets, i, split_type);
I(i) = feval('CARTfunctions', split_point(i), features, targets, i, split_type);
end
[m, dim] = min(I);
loc = split_point(dim);
%So, the split is to be on dimention 'dim' at location 'loc'
indices = 1:M;
tree.Raction= ['features(' num2str(dim) ',indices) > ' num2str(loc)];
tree.Laction= ['features(' num2str(dim) ',indices) <= ' num2str(loc)];
in_right = find(eval(tree.Raction));
in_left = find(eval(tree.Laction));
if isempty(in_right) | isempty(in_left)
%No possible split found
tree.right = [];
tree.left = [];
if (length(Nt) ~= 1),
MLlabel = find(N == max(N));
else
MLlabel = 1;
end
tree.Raction = Nt(MLlabel);
tree.Laction = Nt(MLlabel);
else
%...It's possible to build new nodes
tree.right = make_tree(features(:,in_right), targets(in_right), Dlength, split_type, inc_node, region);
tree.left = make_tree(features(:,in_left), targets(in_left), Dlength, split_type, inc_node, region);
end
end
在Julia中的決策樹包:https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md
相關文章
- 決策樹模型(4)Cart演算法模型演算法
- ML《決策樹(三)CART》
- 【面試考】【入門】決策樹演算法ID3,C4.5和CART面試演算法
- 決策樹演算法演算法
- 《機器學習Python實現_09_02_決策樹_CART》機器學習Python
- 機器學習之決策樹演算法機器學習演算法
- 機器學習經典演算法之決策樹機器學習演算法
- 決策樹演算法-實戰篇演算法
- 決策樹演算法-理論篇演算法
- 決策樹
- 機器學習之決策樹機器學習
- 分類演算法-決策樹 Decision Tree演算法
- 決策樹模型(3)決策樹的生成與剪枝模型
- 決策樹示例
- Reinventing the wheel:決策樹演算法的實現演算法
- 決策樹演算法的推理與實現演算法
- 4. 決策樹
- Decision tree——決策樹
- 決策樹(Decision Tree)
- 演算法金 | 突破最強演算法模型,決策樹演算法!!演算法模型
- Python機器學習:決策樹001什麼是決策樹Python機器學習
- 遊戲AI之決策結構—行為樹遊戲AI
- 通俗地說決策樹演算法(二)例項解析演算法
- 最常用的決策樹演算法!Random Forest、Adaboost、GBDT 演算法演算法randomREST
- Cart迴歸樹、GBDT、XGBoost
- 分類——決策樹模型模型
- 機器學習:決策樹機器學習
- 關於決策樹的理解
- 決策樹學習總結
- 決策樹和隨機森林隨機森林
- 機器學習之 決策樹(Decision Tree)python實現機器學習Python
- 機器學習之使用sklearn構造決策樹模型機器學習模型
- 機器學習之決策樹原理和sklearn實踐機器學習
- 通俗地說決策樹演算法(一)基礎概念介紹演算法
- 機器學習之-決策樹演算法【人工智慧工程師--AI轉型必修課】機器學習演算法人工智慧工程師AI
- 通俗易懂--決策樹演算法、隨機森林演算法講解(演算法+案例)演算法隨機森林
- 演算法金 | 決策樹、隨機森林、bagging、boosting、Adaboost、GBDT、XGBoost 演算法大全演算法隨機森林
- 機器學習之決策樹在sklearn中的實現機器學習
- 機器學習之決策樹ID3(python實現)機器學習Python