Machine Learning - Tree Regression

Posted by 塵世中迷途小碼農 on 2020-12-25

In the earlier article Machine Learning - Regression, we discussed ordinary linear regression. Those methods are powerful and very practical, but they have some shortcomings:

  1. They must fit all of the sample points (except for locally weighted linear regression), which is computationally expensive
  2. Many real-world problems are nonlinear and cannot be handled by a linear model

This article introduces a nonlinear regression model, tree regression, whose model is built with the CART (Classification And Regression Tree) algorithm. The algorithm can be used for classification (finding groups) as well as for regression (prediction). CART is prone to overfitting, which can be addressed with tree pruning.

This article covers the following topics:

  • Regression trees
  • Tree pruning
  • Model trees
  • Classification and forecasting
  • Summary

Parts of this article are adapted from Machine Learning in Action.


Regression Trees

Regression trees are generally used to partition data into groups. The basic idea is to binary-split the original data set on some feature value, much like building a binary search tree. A leaf node holds either a single data point's value or the mean of several data points, which is controlled by parameters.

In Machine Learning - Decision Trees, we used information gain to find the best split. Here we find the best split by minimizing the total squared error. For example:

Suppose the original data set is S, and some feature value a splits S into S1 and S2. We then compute the sample variances of S, S1, and S2. If

var(S1) * N(S1) + var(S2) * N(S2) < var(S) * N(S), where var(S1) is the sample variance of S1 and N(S1) is its number of sample points (likewise for S2 and S),

then splitting on feature value a reduces the dispersion of the original data set S. Performing this computation for every feature value in S yields the best a, which becomes the value to split on.

Note: the variance multiplied by the number of sample points is the sum of squared deviations of the samples from their mean.
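To make the criterion concrete, here is a tiny sanity check (a sketch with made-up numbers, not from the article's data sets): the total squared error of a set is its variance times its size, and a good split makes the combined error of the two halves much smaller than the error of the whole.

import numpy as np

s = np.array([1.0, 1.1, 0.9, 5.0, 5.2, 4.8])  # toy target values with two obvious groups
s1, s2 = s[:3], s[3:]  # candidate binary split


def total_sq_err(a):
    # var(a) * len(a) equals the sum of squared deviations from the mean
    return np.var(a) * len(a)


print(total_sq_err(s))  # error of the whole set: about 24.1
print(total_sq_err(s1) + total_sq_err(s2))  # combined error after the split: about 0.1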

Now let's implement the regression tree in code.

Create a module reg_tree.py with the following code:

import numpy as np


def load_data_set(file_name):
    data_mat = []
    with open(file_name) as f:
        for line in f.readlines():
            current_line = line.strip().split('\t')
            float_line = list(map(float, current_line))
            data_mat.append(float_line)
    return data_mat


def bin_split_data_set(data_set, feature, value):
    # binary-split data_set on the given feature: rows above value vs. the rest
    mat0 = data_set[np.nonzero(data_set[:, feature] > value)[0], :]
    mat1 = data_set[np.nonzero(data_set[:, feature] <= value)[0], :]
    return mat0, mat1


def reg_leaf(data_set):
    # regression-tree leaf: the mean of the target column
    return np.mean(data_set[:, -1])


def reg_err(data_set):
    # total squared error: variance of the target times the sample count
    return np.var(data_set[:, -1]) * np.shape(data_set)[0]


def choose_best_split(data_set, leaf_type=reg_leaf, err_type=reg_err, ops=(1, 4)):
    tol_S = ops[0]  # minimum error reduction required to accept a split
    tol_N = ops[1]  # minimum number of samples allowed in each split subset
    if len(set(data_set[:, -1].T.tolist()[0])) == 1:  # exit cond 1: all target values are equal
        return None, leaf_type(data_set)
    m, n = np.shape(data_set)
    S = err_type(data_set)
    best_S = np.inf
    best_index = 0
    best_value = 0
    for feat_index in range(n - 1):
        for split_val in set(data_set[:, feat_index].T.tolist()[0]):
            mat0, mat1 = bin_split_data_set(data_set, feat_index, split_val)
            if (np.shape(mat0)[0] < tol_N) or (np.shape(mat1)[0] < tol_N):
                continue
            new_S = err_type(mat0) + err_type(mat1)
            if new_S < best_S:
                best_index = feat_index
                best_value = split_val
                best_S = new_S
    if (S - best_S) < tol_S:  # exit cond 2: the best split does not reduce the error enough
        return None, leaf_type(data_set)
    mat0, mat1 = bin_split_data_set(data_set, best_index, best_value)
    if (np.shape(mat0)[0] < tol_N) or (np.shape(mat1)[0] < tol_N):  # exit cond 3: a subset is too small
        return None, leaf_type(data_set)
    return best_index, best_value


def create_tree(data_set, leaf_type=reg_leaf, err_type=reg_err,
                ops=(1, 4)):
    feat, val = choose_best_split(data_set, leaf_type, err_type, ops)
    if feat is None:
        return val
    ret_tree = {}
    ret_tree['spInd'] = feat
    ret_tree['spVal'] = val
    lSet, rSet = bin_split_data_set(data_set, feat, val)
    ret_tree['left'] = create_tree(lSet, leaf_type, err_type, ops)
    ret_tree['right'] = create_tree(rSet, leaf_type, err_type, ops)
    return ret_tree


def get_reg_tree_values(tree):
    # collect all split values (the x positions of the split lines) in the tree
    result = []
    sp_val = tree['spVal']
    left = tree['left']
    right = tree['right']
    result.append(sp_val)
    if type(left) is dict:
        result.extend(get_reg_tree_values(left))
    if type(right) is dict:
        result.extend(get_reg_tree_values(right))
    return result


def get_reg_tree_leaf_values(tree):
    # collect the leaf values (group means) of a regression tree
    result = []
    left = tree['left']
    right = tree['right']
    if type(left) is dict:
        result.extend(get_reg_tree_leaf_values(left))
    else:
        result.append(left)
    if type(right) is dict:
        result.extend(get_reg_tree_leaf_values(right))
    else:
        result.append(right)
    return result


def get_model_tree_values(tree):
    # collect the [intercept, slope] pairs stored at a model tree's leaves
    result = []
    left = tree['left']
    right = tree['right']
    if type(left) is dict:
        result.extend(get_model_tree_values(left))
    else:
        left_data = np.array(left)
        result.append([left_data[0][0], left_data[1][0]])
    if type(right) is dict:
        result.extend(get_model_tree_values(right))
    else:
        right_data = np.array(right)
        result.append([right_data[0][0], right_data[1][0]])
    return result


if __name__ == '__main__':
    data = load_data_set('ex00.txt')
    mat = np.mat(data)
    tree = create_tree(mat)
    print(tree)

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/reg_tree.py
{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}

Process finished with exit code 0

As you can see, the data set was split in two by the feature value (feature index: 0, value: 0.48813). The left subtree's mean is 1.0180967672413792 and the right subtree's mean is -0.04465028571428572.
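As a quick sanity check (a sketch reusing load_data_set from above), the two leaf values are simply the means of y on either side of the split value:

import numpy as np
import tree_regression.reg_tree as rt

data = np.array(rt.load_data_set('ex00.txt'))
left_y = data[data[:, 0] > 0.48813, 1]  # samples on the 'left' branch (feature 0 above the split)
right_y = data[data[:, 0] <= 0.48813, 1]
print(np.mean(left_y), np.mean(right_y))  # reproduces 1.01809... and -0.04465...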

Let's plot the data to show these relationships more intuitively. Create a module reg_tree_plot.py with the following code:

import numpy as np
import matplotlib.pyplot as plt

import tree_regression.reg_tree as reg_tree


def test_dataset1():
    data = reg_tree.load_data_set('ex00.txt')
    mat = np.mat(data)
    tree = reg_tree.create_tree(mat)
    print(tree)

    x = np.array(data)[:, 0]
    y = np.array(data)[:, 1]
    plt.plot(x, y, 'o', label='Original Values')

    line_x_arr = reg_tree.get_reg_tree_values(tree)
    plot_lines(line_x_arr, np.min(y), np.max(y))

    plot_mean_points(tree, np.min(x), np.max(x))

    plt.title('Regression Tree')
    plt.legend()

    plt.show()


def test_dataset2():
    data = reg_tree.load_data_set('ex0.txt')
    mat = np.mat(data)
    tree = reg_tree.create_tree(mat)
    # tree = reg_tree.create_tree(mat, ops=(0.1, 10))
    print(tree)

    x = np.array(data)[:, 1]
    y = np.array(data)[:, 2]
    plt.plot(x, y, 'o', label='Original Values')

    line_x_arr = reg_tree.get_reg_tree_values(tree)
    plot_lines(line_x_arr, np.min(y), np.max(y))

    plot_mean_points(tree, np.min(x), np.max(x))

    plt.title('Regression Tree')
    plt.legend()

    plt.show()


def plot_lines(x_arr, y_min, y_max):
    for x in x_arr:
        line_x = (x, x)
        line_y = (y_min, y_max)
        plt.plot(line_x, line_y, label='Split Line')


def plot_mean_points(tree, min_x, max_x):
    line_x_arr = reg_tree.get_reg_tree_values(tree)
    mean_y_values = reg_tree.get_reg_tree_leaf_values(tree)
    mean_y_values.sort()
    print(mean_y_values)
    mean_x_values = []
    tmp_x = [min_x]
    tmp_x.extend(line_x_arr)
    tmp_x.append(max_x)
    tmp_x.sort()
    print(tmp_x)
    index = 0
    while index < len(tmp_x) - 1:
        mean_x_values.append((tmp_x[index] + tmp_x[index + 1]) / 2)
        index += 1
    plt.plot(mean_x_values, mean_y_values, 'or', label='Mean Values')


if __name__ == '__main__':
    test_dataset1()
    #test_dataset2()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/reg_tree_plot.py
{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}
[-0.04465028571428572, 1.0180967672413792]
[0.000234, 0.48813, 0.996757]

Figure:

Note that the two red points represent the mean values of the two groups, and the vertical line is the split line used to partition the data.

Next, switch the code above to run test_dataset2() to see a more complex scenario:

if __name__ == '__main__':
    #test_dataset1()
    test_dataset2()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/reg_tree_plot.py
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}
[-0.023838155555555553, 1.0289583666666666, 1.980035071428571, 2.9836209534883724, 3.9871632]
[0.004327, 0.197834, 0.39435, 0.582002, 0.797583, 0.998709]

Figure:

This data set was partitioned into five segments, giving five leaf nodes.

As mentioned above, the parameters control the number of leaf nodes (a leaf node is a single data point's value or the mean of several data points). Modify test_dataset2() to pass the argument ops=(0.1, 10):

def test_dataset2():
    data = reg_tree.load_data_set('ex0.txt')
    mat = np.mat(data)
    # tree = reg_tree.create_tree(mat)
    tree = reg_tree.create_tree(mat, ops=(0.1, 10))
    print(tree)

    x = np.array(data)[:, 1]
    y = np.array(data)[:, 2]
    plt.plot(x, y, 'o', label='Original Values')

    line_x_arr = reg_tree.get_reg_tree_values(tree)
    plot_lines(line_x_arr, np.min(y), np.max(y))

    plot_mean_points(tree, np.min(x), np.max(x))

    plt.title('Regression Tree')
    plt.legend()

    plt.show()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/reg_tree_plot.py
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': {'spInd': 1, 'spVal': 0.486698, 'left': 2.0409245, 'right': 1.8810897500000001}}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left': {'spInd': 1, 'spVal': 0.316465, 'left': 0.9437193846153846, 'right': 1.094141117647059}, 'right': {'spInd': 1, 'spVal': 0.148654, 'left': 0.07189454545454545, 'right': -0.054810500000000005}}}
[-0.054810500000000005, 0.07189454545454545, 0.9437193846153846, 1.094141117647059, 1.8810897500000001, 2.0409245, 2.9836209534883724, 3.9871632]
[0.004327, 0.148654, 0.197834, 0.316465, 0.39435, 0.486698, 0.582002, 0.797583, 0.998709]

Figure:

As you can see, the data is now partitioned into more segments and the resulting tree is more complex, grouping the data at a finer granularity. The model is very sensitive to these two parameters: looking at choose_best_split, the first (tol_S) is the minimum reduction in total squared error a split must achieve to be accepted, and the second (tol_N) is the minimum number of data points allowed in each split subset. Together these two parameters bound the size of the regression tree; this technique is called prepruning. The sketch below illustrates the sensitivity.
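One way to see this sensitivity (a sketch using the modules above; the exact counts depend on the data) is to count the leaf nodes for a few ops settings:

import numpy as np
import tree_regression.reg_tree as reg_tree

mat = np.mat(reg_tree.load_data_set('ex0.txt'))
for ops in [(1, 4), (0.1, 10), (0.01, 4)]:
    tree = reg_tree.create_tree(mat, ops=ops)
    # (1, 4) produced 5 leaves above and (0.1, 10) produced 8; smaller tolerances give more
    print(ops, len(reg_tree.get_reg_tree_leaf_values(tree)))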

Tree Pruning

The previous section introduced prepruning, which prunes while the regression tree is being built. Another technique, called postpruning, works on an already-built regression tree together with a test set. This section discusses postpruning.

The basic idea of postpruning is to descend to the leaf nodes, and for each pair of sibling leaves compute the test-set error before merging them and after merging them. If the merged error is smaller than the unmerged error, the two leaves are merged; otherwise they are left alone.

Let's demonstrate this with code.

Create a module tree_pruning.py with the following code:

import numpy as np
import tree_regression.reg_tree as rt


def is_tree(obj):
    return type(obj) is dict


def get_mean(tree):
    # recursively collapse a subtree into a single value (the mean of its children)
    if is_tree(tree['right']):
        tree['right'] = get_mean(tree['right'])
    if is_tree(tree['left']):
        tree['left'] = get_mean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0


def prune(tree, test_data):
    if np.shape(test_data)[0] == 0:
        return get_mean(tree)  # if we have no test data collapse the tree
    if is_tree(tree['right']) or is_tree(tree['left']):  # if either branch is a subtree, split the test data for the recursive calls
        l_set, r_set = rt.bin_split_data_set(test_data, tree['spInd'], tree['spVal'])
    if is_tree(tree['left']):
        tree['left'] = prune(tree['left'], l_set)
    if is_tree(tree['right']):
        tree['right'] = prune(tree['right'], r_set)
    # if both branches are now leaves, see if we can merge them
    if not is_tree(tree['left']) and not is_tree(tree['right']):
        l_set, r_set = rt.bin_split_data_set(test_data, tree['spInd'], tree['spVal'])
        error_no_merge = np.sum(np.power(l_set[:, -1] - tree['left'], 2)) + np.sum(
            np.power(r_set[:, -1] - tree['right'], 2))
        tree_mean = (tree['left'] + tree['right']) / 2.0
        error_merge = np.sum(np.power(test_data[:, -1] - tree_mean, 2))
        if error_merge < error_no_merge:
            print("merging...")
            return tree_mean
        else:
            return tree
    else:
        return tree


def test_pruning():
    print("Before pruning:")
    data = rt.load_data_set('ex2.txt')
    mat = np.mat(data)
    tree = rt.create_tree(mat)
    print(tree)
    print("After pruning:")
    test_mat = np.mat(rt.load_data_set('ex2test.txt'))
    prune_tree = prune(tree, test_mat)
    print(prune_tree)


if __name__ == '__main__':
    test_pruning()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/tree_pruning.py
Before pruning:
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': 105.24862350000001, 'right': 112.42895575000001}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': 87.3103875, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.910975, 'left': 96.452867, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': 104.825409, 'right': {'spInd': 0, 'spVal': 0.872883, 'left': 95.181793, 'right': 102.25234449999999}}}, 'right': 95.27584316666666}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 81.110152, 'right': 88.78449880000001}}, 'right': 102.35780185714285}, 'right': 78.08564325}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.666452, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': 114.554706, 'right': {'spInd': 0, 'spVal': 0.698472, 'left': 104.82495374999999, 'right': 108.92921799999999}}, 'right': 114.1516242857143}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': 93.67344971428572, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': 123.2101316, 'right': {'spInd': 0, 'spVal': 0.553797, 'left': 97.20018024999999, 'right': {'spInd': 0, 'spVal': 0.51915, 'left': {'spInd': 0, 'spVal': 0.543843, 'left': 109.38961049999999, 'right': 110.979946}, 'right': 101.73699325000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': {'spInd': 0, 'spVal': 0.467383, 'left': 12.50675925, 'right': 3.4331330000000007}, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': -12.558604833333334, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': 14.38417875, 'right': {'spInd': 0, 'spVal': 0.385021, 'left': -0.8923554999999995, 'right': 3.6584772500000016}}}, 'right': {'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.350725, 'left': -15.08511175, 'right': -22.693879600000002}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': 15.05929075, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -19.9941552, 'right': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.202161, 'left': {'spInd': 0, 'spVal': 0.217214, 'left': {'spInd': 0, 'spVal': 0.228473, 'left': {'spInd': 0, 'spVal': 0.25807, 'left': 0.40377471428571476, 'right': -13.070501}, 'right': 6.770429}, 'right': -11.822278500000001}, 'right': 3.4496025}, 'right': {'spInd': 0, 'spVal': 0.156067, 'left': -12.1079725, 'right': -6.247900000000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 6.509843285714284, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': -2.544392714285715, 'right': 4.091626}}}}}
After pruning:
merging...
merging...
merging...
merging...
merging...
merging...
merging...
merging...
merging...
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': 105.24862350000001, 'right': 112.42895575000001}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': 87.3103875, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.910975, 'left': 96.452867, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': 104.825409, 'right': {'spInd': 0, 'spVal': 0.872883, 'left': 95.181793, 'right': 102.25234449999999}}}, 'right': 95.27584316666666}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 81.110152, 'right': 88.78449880000001}}, 'right': 102.35780185714285}, 'right': 78.08564325}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.666452, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': 114.554706, 'right': 106.87708587499999}, 'right': 114.1516242857143}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': 93.67344971428572, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': 123.2101316, 'right': 101.580533}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': 7.969946125, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': -12.558604833333334, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': 14.38417875, 'right': 1.383060875000001}}, 'right': {'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.350725, 'left': -15.08511175, 'right': -22.693879600000002}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': 15.05929075, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -19.9941552, 'right': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.202161, 'left': -5.801872785714286, 'right': 3.4496025}, 'right': {'spInd': 0, 'spVal': 0.156067, 'left': -12.1079725, 'right': -6.247900000000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 6.509843285714284, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': -2.544392714285715, 'right': 4.091626}}}}}

Process finished with exit code 0

As you can see, nine merges were performed, which reduces the size of the regression tree to some extent. The sketch below counts the leaves before and after.
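To quantify the effect, one can count leaves before and after pruning (a sketch reusing get_reg_tree_leaf_values from reg_tree.py, assuming the pruned result is still a tree rather than a single collapsed value; note that prune also modifies the tree in place):

import numpy as np
import tree_regression.reg_tree as rt
import tree_regression.tree_pruning as tp

train_mat = np.mat(rt.load_data_set('ex2.txt'))
test_mat = np.mat(rt.load_data_set('ex2test.txt'))
tree = rt.create_tree(train_mat)
print(len(rt.get_reg_tree_leaf_values(tree)))  # leaf count before pruning
pruned = tp.prune(tree, test_mat)
print(len(rt.get_reg_tree_leaf_values(pruned)))  # each merge removes one leaf, so nine fewer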

Model Trees

A model tree is very similar to a regression tree; it differs in only two respects:

  1. A regression tree's leaf node is a single data point's value or the mean of several data points, whereas a model tree's leaf node holds the linear regression coefficients fitted to a group of data points
  2. When computing the error, a regression tree measures deviation from the mean (the expected value), whereas a model tree measures deviation from the linear regression's predicted values

Put simply, once the data has been partitioned, a regression tree approximates all sample points in a subset by the subset's mean, while a model tree describes the subset's sample points with a regression equation. Regression trees are therefore better suited to grouping, and model trees to prediction.
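To make difference 2 concrete, here is a small comparison (a sketch on made-up data, using reg_err from reg_tree.py above and model_err from the model_tree.py module defined just below): reg_err scores a subset by its squared deviation from the mean, model_err by its squared residual from a fitted line.

import numpy as np
import tree_regression.reg_tree as rt
import tree_regression.model_tree as mt

# toy data lying exactly on the line y = 3x + 1: x in column 0, y in column 1
x = np.linspace(0, 1, 20)
data = np.mat(np.column_stack((x, 3 * x + 1)))
print(rt.reg_err(data))  # large: a single mean fits a sloped line poorly
print(mt.model_err(data))  # essentially 0: a linear model fits it perfectly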

Let's demonstrate this with code.

Create a module model_tree.py with the following code:

import numpy as np
import matplotlib.pyplot as plt
import tree_regression.reg_tree as rt


def linear_solve(data_set):  # helper function used in two places
    m, n = np.shape(data_set)
    X = np.mat(np.ones((m, n)))  # copy of the data with 1 in the 0th position (bias term)
    X[:, 1:n] = data_set[:, 0:n - 1]
    Y = data_set[:, -1]  # strip out Y
    xTx = X.T * X
    if np.linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse, try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y


def model_leaf(data_set):
    ws, X, Y = linear_solve(data_set)
    return ws


def model_err(data_set):
    ws, X, Y = linear_solve(data_set)
    y_hat = X * ws
    return np.sum(np.power(Y - y_hat, 2))  # total squared residual of the linear fit


def cal_line_y(line_value, x):
    return line_value[0] + line_value[1] * x


if __name__ == "__main__":
    my_mat = np.mat(rt.load_data_set('exp2.txt'))
    tree = rt.create_tree(my_mat, model_leaf, model_err)
    print(tree)

    line_values = rt.get_model_tree_values(tree)
    print(line_values)

    for line_value in line_values:
        x = np.linspace(-0.1, 1.2)
        y = cal_line_y(line_value, x)
        plt.plot(x, y, label="Regression Line: f(x)=%f * x + %f" % (line_value[1], line_value[0]))

    x = np.array(my_mat)[:, 0]
    y = np.array(my_mat)[:, 1]
    plt.plot(x, y, 'o', label='Original Values')

    plt.title('Model Tree')
    plt.legend()
    plt.show()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/model_tree.py
{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
        [1.19647739e+01]]), 'right': matrix([[3.46877936],
        [1.18521743]])}
[[0.0016985569360628006, 11.964773944277027], [3.4687793552577872, 1.1852174309188115]]

Figure:

As you can see, this model tree has two leaf nodes, each representing a regression equation, so samples in different segments are predicted piecewise.

Classification and Forecasting

By traversing a regression tree or a model tree, we can determine which group a sample point belongs to, or compute its regression value. Create a module forecast.py with the following code:

import numpy as np
import tree_regression.tree_pruning as tp
import tree_regression.reg_tree as rt
import tree_regression.model_tree as mt


def reg_tree_eval(model, in_dat):
    # a regression-tree leaf is just a number; in_dat is unused
    return float(model)


def model_tree_eval(model, in_dat):
    # a model-tree leaf holds regression coefficients; prepend the bias term and evaluate
    n = np.shape(in_dat)[1]
    X = np.mat(np.ones((1, n + 1)))
    X[:, 1:n + 1] = in_dat
    return float(X * model)


def tree_forecast(tree, in_data, model_eval=reg_tree_eval):
    if not tp.is_tree(tree):
        return model_eval(tree, in_data)
    if in_data[tree['spInd']] > tree['spVal']:
        if tp.is_tree(tree['left']):
            return tree_forecast(tree['left'], in_data, model_eval)
        else:
            return model_eval(tree['left'], in_data)
    else:
        if tp.is_tree(tree['right']):
            return tree_forecast(tree['right'], in_data, model_eval)
        else:
            return model_eval(tree['right'], in_data)


def create_forecast(tree, test_data, model_eval=reg_tree_eval):
    m = len(test_data)
    y_hat = np.mat(np.zeros((m, 1)))
    for i in range(m):
        y_hat[i, 0] = tree_forecast(tree, np.mat(test_data[i]), model_eval)
    return y_hat


def test_reg_tree():
    print("Test regression tree:")
    data = rt.load_data_set('ex0.txt')
    tree = rt.create_tree(np.mat(data))
    print(tree)
    in_dat = (1.000000, 0.558918)
    y = 1.719148
    y_hat = tree_forecast(tree, in_dat)
    print("Real Y: %f" % y)
    print("Hat Y: %f" % y_hat)


def test_model_tree():
    print("Test model tree:")
    data = rt.load_data_set('exp2.txt')
    tree = rt.create_tree(np.mat(data), mt.model_leaf, mt.model_err)
    print(tree)
    in_dat = np.array([(0.010767,)])
    y = 3.565835
    y_hat = tree_forecast(tree, in_dat, model_eval=model_tree_eval)
    print("Real Y: %f" % y)
    print("Hat Y: %f" % y_hat)


if __name__ == '__main__':
    test_reg_tree()
    test_model_tree()

Output:

D:\work\python_workspace\machine_learning\venv\Scripts\python.exe D:/work/python_workspace/machine_learning/tree_regression/forecast.py
Test regression tree:
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}
Real Y: 1.719148
Hat Y: 1.980035
Test model tree:
{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
        [1.19647739e+01]]), 'right': matrix([[3.46877936],
        [1.18521743]])}
Real Y: 3.565835
Hat Y: 3.481541

Process finished with exit code 0
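To judge forecast quality over many points rather than a single one, a common check is the correlation coefficient between predicted and actual values (a sketch; for simplicity it reuses the training file exp2.txt as the test data):

import numpy as np
import tree_regression.reg_tree as rt
import tree_regression.model_tree as mt
import tree_regression.forecast as fc

mat = np.mat(rt.load_data_set('exp2.txt'))
tree = rt.create_tree(mat, mt.model_leaf, mt.model_err)
y_hat = fc.create_forecast(tree, mat[:, 0], fc.model_tree_eval)
print(np.corrcoef(y_hat, mat[:, 1], rowvar=0)[0, 1])  # closer to 1.0 is better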

Summary

Data sets often contain complex interactions that make the relationship between the data and the target variable nonlinear. An effective way to model such complex data sets is to use a tree to make piecewise predictions, either piecewise constants or piecewise linear fits. A piecewise constant is the mean of a data subset and is generally used for grouping, corresponding to the regression tree. A piecewise linear fit is the regression equation of a data subset and is generally used for prediction, corresponding to the model tree.

CART splits the data set with a binary tree. If overfitting occurs, tree pruning can remove redundant leaf nodes. Pruning comes in two forms: prepruning removes redundant leaves while the tree is being built and requires the user to supply two parameters, whereas postpruning works on an already-built regression or model tree together with a test set.
