機器學習之分類迴歸樹(python實現CART)

swensun發表於2018-03-04

之前有文章介紹過決策樹(ID3)。簡單回顧一下:ID3每次選取最佳特徵來分割資料,這個最佳特徵的判斷原則是通過資訊增益來實現的。按照某種特徵切分資料後,該特徵在以後切分資料集時就不再使用,因此存在切分過於迅速的問題。ID3演算法還不能處理連續性特徵。 下面簡單介紹一下其他演算法:

螢幕快照 2018-03-03 14.05.44.png

CART 分類迴歸樹

CART是Classification And Regerssion Trees的縮寫,既能處理分類任務也能做迴歸任務。

image.png
CART樹的典型代表時二叉樹,根據不同的條件將分類。
image.png

CART樹構建演算法 與ID3決策樹的構建方法類似,直接給出CART樹的構建過程。首先與ID3類似採用字典樹的資料結構,包含以下4中元素:

  • 待切分的特徵
  • 待切分的特徵值
  • 右子樹。當不再需要切分的時候,也可以是單個值
  • 左子樹,類似右子樹。

過程如下:

  1. 尋找最合適的分割特徵
  2. 如果不能分割資料集,該資料集作為一個葉子節點。
  3. 對資料集進行二分割
  4. 對分割的資料集1重複1, 2,3 步,建立右子樹。
  5. 對分割的資料集2重複1, 2,3 步,建立左子樹。

明顯的遞迴演算法。

通過資料過濾的方式分割資料集,返回兩個子集。

def splitDatas(rows, value, column):
    # 根據條件分離資料集(splitDatas by value, column)
    # return 2 part(list1, list2)

    list1 = []
    list2 = []

    if isinstance(value, int) or isinstance(value, float):
        for row in rows:
            if row[column] >= value:
                list1.append(row)
            else:
                list2.append(row)
    else:
        for row in rows:
            if row[column] == value:
                list1.append(row)
            else:
                list2.append(row)
    return list1, list2
複製程式碼

劃分資料點

建立二進位制決策樹本質上就是遞迴劃分輸入空間的過程。

image.png

程式碼如下:

# gini()
def gini(rows):
    # 計算gini的值(Calculate GINI)

    length = len(rows)
    results = calculateDiffCount(rows)
    imp = 0.0
    for i in results:
        imp += results[i] / length * results[i] / length
    return 1 - imp
複製程式碼

構建樹

def buildDecisionTree(rows, evaluationFunction=gini):
    # 遞迴建立決策樹, 當gain=0,時停止迴歸
    # build decision tree bu recursive function
    # stop recursive function when gain = 0
    # return tree
    currentGain = evaluationFunction(rows)
    column_lenght = len(rows[0])
    rows_length = len(rows)

    best_gain = 0.0
    best_value = None
    best_set = None

    # choose the best gain
    for col in range(column_lenght - 1):
        col_value_set = set([x[col] for x in rows])
        for value in col_value_set:
            list1, list2 = splitDatas(rows, value, col)
            p = len(list1) / rows_length
            gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
            if gain > best_gain:
                best_gain = gain
                best_value = (col, value)
                best_set = (list1, list2)
    dcY = {'impurity': '%.3f' % currentGain, 'sample': '%d' % rows_length}
    #
    # stop or not stop

    if best_gain > 0:
        trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
        falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
        return Tree(col=best_value[0], value = best_value[1], trueBranch = trueBranch, falseBranch=falseBranch, summary=dcY)
    else:
        return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)
複製程式碼

上面程式碼的功能是先找到資料集切分的最佳位置和分割資料集。之後通過遞迴構建出上面圖片的整棵樹。

剪枝

在決策樹的學習中,有時會造成決策樹分支過多,這是就需要去掉一些分支,降低過度擬合。通過決策樹的複雜度來避免過度擬合的過程稱為剪枝。 後剪枝需要從訓練集生成一棵完整的決策樹,然後自底向上對非葉子節點進行考察。利用測試集判斷是否將該節點對應的子樹替換成葉節點。 程式碼如下:

def prune(tree, miniGain, evaluationFunction=gini):
    # 剪枝 when gain < mini Gain, 合併(merge the trueBranch and falseBranch)
    if tree.trueBranch.results == None:
        prune(tree.trueBranch, miniGain, evaluationFunction)
    if tree.falseBranch.results == None:
        prune(tree.falseBranch, miniGain, evaluationFunction)

    if tree.trueBranch.results != None and tree.falseBranch.results != None:
        len1 = len(tree.trueBranch.data)
        len2 = len(tree.falseBranch.data)
        len3 = len(tree.trueBranch.data + tree.falseBranch.data)

        p = float(len1) / (len1 + len2)

        gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data)

        if gain < miniGain:
            tree.data = tree.trueBranch.data + tree.falseBranch.data
            tree.results = calculateDiffCount(tree.data)
            tree.trueBranch = None
            tree.falseBranch = None
複製程式碼

當節點的gain小於給定的 mini Gain時則合併這兩個節點.。

最後是構建樹的程式碼:

if __name__ == '__main__':
    dataSet = loadCSV()
    decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
    prune(decisionTree, 0.4)
    test_data = [5.9,3,4.2,1.5]
    r = classify(test_data, decisionTree)
    print(r)
複製程式碼

可以列印decisionTree可以構建出如如上的圖片中的決策樹。 後面找一組資料測試看能否得到正確的分類。

完整程式碼和資料集請檢視:
github:CART

總結:

  • CART決策樹
  • 分割資料集
  • 遞迴建立樹

參考文章:
CART分類迴歸樹分析與python實現
CART決策樹(Decision Tree)的Python原始碼實現

相關文章