決策樹之CART演算法

acdreamers發表於2015-03-27

在之前介紹過決策樹的ID3演算法實現,今天主要來介紹決策樹的另一種實現,即CART演算法

 

Contents

 

   1. CART演算法的認識

   2. CART演算法的原理

   3. CART演算法的實現

 

 

1. CART演算法的認識

 

   Classification And Regression Tree,即分類迴歸樹演算法,簡稱CART演算法,它是決策樹的一種實現,通

   常決策樹主要有三種實現,分別是ID3演算法,CART演算法和C4.5演算法。

 

   CART演算法是一種二分遞迴分割技術,把當前樣本劃分為兩個子樣本,使得生成的每個非葉子結點都有兩個分支,

   因此CART演算法生成的決策樹是結構簡潔的二叉樹。由於CART演算法構成的是一個二叉樹,它在每一步的決策時只能

   是“是”或者“否”,即使一個feature有多個取值,也是把資料分為兩部分。在CART演算法中主要分為兩個步驟

 

   (1)將樣本遞迴劃分進行建樹過程

   (2)用驗證資料進行剪枝

 

 

2. CART演算法的原理

 

   上面說到了CART演算法分為兩個過程,其中第一個過程進行遞迴建立二叉樹,那麼它是如何進行劃分的 ?

 

   設代表單個樣本的個屬性,表示所屬類別。CART演算法通過遞迴的方式將維的空間劃分為不重

   疊的矩形。劃分步驟大致如下

 

   (1)選一個自變數,再選取的一個值維空間劃分為兩部分,一部分的所有點都滿足

       另一部分的所有點都滿足,對非連續變數來說屬性值的取值只有兩個,即等於該值或不等於該值。

   (2)遞迴處理,將上面得到的兩部分按步驟(1)重新選取一個屬性繼續劃分,直到把整個維空間都劃分完。

 

   在劃分時候有一個問題,它是按照什麼標準來劃分的 ? 對於一個變數屬性來說,它的劃分點是一對連續變數屬

   性值的中點。假設個樣本的集合一個屬性有個連續的值,那麼則會有個分裂點,每個分裂點為相鄰

   兩個連續值的均值。每個屬性的劃分按照能減少的雜質的量來進行排序,而雜質的減少量定義為劃分前的雜質減

   去劃分後的每個節點的雜質量劃分所佔比率之和。而雜質度量方法常用Gini指標,假設一個樣本共有類,那麼

   一個節點的Gini不純度可定義為

 

          

 

   其中表示屬於類的概率,當Gini(A)=0時,所有樣本屬於同類,所有類在節點中以等概率出現時,Gini(A)

   最大化,此時

 

   有了上述理論基礎,實際的遞迴劃分過程是這樣的:如果當前節點的所有樣本都不屬於同一類或者只剩下一個樣

   本,那麼此節點為非葉子節點,所以會嘗試樣本的每個屬性以及每個屬性對應的分裂點,嘗試找到雜質變數最大

   的一個劃分,該屬性劃分的子樹即為最優分支。

 

   下面舉個簡單的例子,如下圖

 

   

 

   在上述圖中,屬性有3個,分別是有房情況,婚姻狀況和年收入,其中有房情況和婚姻狀況是離散的取值,而年

   收入是連續的取值。拖欠貸款者屬於分類的結果。

 

   假設現在來看有房情況這個屬性,那麼按照它劃分後的Gini指數計算如下

 

   

 

   而對於婚姻狀況屬性來說,它的取值有3種,按照每種屬性值分裂後Gini指標計算如下

 

    

 

   最後還有一個取值連續的屬性,年收入,它的取值是連續的,那麼連續的取值採用分裂點進行分裂。如下

 

    

 

   根據這樣的分裂規則CART演算法就能完成建樹過程。

 

   建樹完成後就進行第二步了,即根據驗證資料進行剪枝。在CART樹的建樹過程中,可能存在Overfitting,許多

   分支中反映的是資料中的異常,這樣的決策樹對分類的準確性不高,那麼需要檢測並減去這些不可靠的分支。決策

   樹常用的剪枝有事前剪枝和事後剪枝,CART演算法採用事後剪枝,具體方法為代價複雜性剪枝法。可參考如下鏈

 

   剪枝參考:http://www.cnblogs.com/zhangchaoyang/articles/2709922.html

 

   

3. CART演算法的實現

 

   以下程式碼是網上找的CART演算法的MATLAB實現。

CART
  
function D = CART(train_features, train_targets, params, region)
  
% Classify using classification and regression trees
% Inputs:
% features - Train features
% targets     - Train targets
% params - [Impurity type, Percentage of incorrectly assigned samples at a node]
%                   Impurity can be: Entropy, Variance (or Gini), or Missclassification
% region     - Decision region vector: [-x x -y y number_of_points]
%
% Outputs
% D - Decision sufrace
  
  
[Ni, M]    = size(train_features);
  
%Get parameters
[split_type, inc_node] = process_params(params);
  
%For the decision region
N           = region(5);
mx          = ones(N,1) * linspace (region(1),region(2),N);
my          = linspace (region(3),region(4),N)' * ones(1,N);
flatxy      = [mx(:), my(:)]';
  
%Preprocessing
[f, t, UW, m]   = PCA(train_features, train_targets, Ni, region);
train_features  = UW * (train_features - m*ones(1,M));;
flatxy          = UW * (flatxy - m*ones(1,N^2));;
  
%Build the tree recursively
disp('Building tree')
tree        = make_tree(train_features, train_targets, M, split_type, inc_node, region);
  
%Make the decision region according to the tree
disp('Building decision surface using the tree')
targets = use_tree(flatxy, 1:N^2, tree);
  
D = reshape(targets,N,N);
%END
  
function targets = use_tree(features, indices, tree)
%Classify recursively using a tree
  
if isnumeric(tree.Raction)
   %Reached an end node
   targets = zeros(1,size(features,2));
   targets(indices) = tree.Raction(1);
else
   %Reached a branching, so:
   %Find who goes where
   in_right    = indices(find(eval(tree.Raction)));
   in_left     = indices(find(eval(tree.Laction)));
     
   Ltargets = use_tree(features, in_left, tree.left);
   Rtargets = use_tree(features, in_right, tree.right);
     
   targets = Ltargets + Rtargets;
end
%END use_tree
  
function tree = make_tree(features, targets, Dlength, split_type, inc_node, region)
%Build a tree recursively
  
if (length(unique(targets)) == 1),
   %There is only one type of targets, and this generates a warning, so deal with it separately
   tree.right      = [];
   tree.left       = [];
   tree.Raction    = targets(1);
   tree.Laction    = targets(1);
   break
end
  
[Ni, M] = size(features);
Nt      = unique(targets);
N       = hist(targets, Nt);
  
if ((sum(N < Dlength*inc_node) == length(Nt) - 1) | (M == 1)),
   %No further splitting is neccessary
   tree.right      = [];
   tree.left       = [];
   if (length(Nt) ~= 1),
      MLlabel   = find(N == max(N));
   else
      MLlabel   = 1;
   end
   tree.Raction    = Nt(MLlabel);
   tree.Laction    = Nt(MLlabel);
     
else
   %Split the node according to the splitting criterion
   deltaI = zeros(1,Ni);
   split_point = zeros(1,Ni);
   op = optimset('Display', 'off'); 
   for i = 1:Ni,
      split_point(i) = fminbnd('CARTfunctions', region(i*2-1), region(i*2), op, features, targets, i, split_type);
      I(i) = feval('CARTfunctions', split_point(i), features, targets, i, split_type);
   end
     
   [m, dim] = min(I);
   loc = split_point(dim);
      
   %So, the split is to be on dimention 'dim' at location 'loc'
   indices = 1:M;
   tree.Raction= ['features(' num2str(dim) ',indices) >  ' num2str(loc)];
   tree.Laction= ['features(' num2str(dim) ',indices) <= ' num2str(loc)];
   in_right    = find(eval(tree.Raction));
   in_left     = find(eval(tree.Laction));
     
   if isempty(in_right) | isempty(in_left)
      %No possible split found
   tree.right      = [];
   tree.left       = [];
   if (length(Nt) ~= 1),
      MLlabel   = find(N == max(N));
   else
      MLlabel = 1;
   end
   tree.Raction    = Nt(MLlabel);
   tree.Laction    = Nt(MLlabel);
   else
   %...It's possible to build new nodes
   tree.right = make_tree(features(:,in_right), targets(in_right), Dlength, split_type, inc_node, region);
   tree.left  = make_tree(features(:,in_left), targets(in_left), Dlength, split_type, inc_node, region);    
   end
     
end


在Julia中的決策樹包:https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md

 

相關文章