樹迴歸|理論與演算法實現

weixin_33850890發表於2018-11-05
4263204-6961900320d42cfa.png

上一篇文章中,我們比較全面地學習了線性迴歸的原理是實現,今天我們還是留在迴歸板塊,針對樹迴歸進行學習和實踐。

01 樹迴歸原理

相比於線性迴歸,樹迴歸更適合對複雜、非線性的資料進行迴歸建模。

原理

回想一下決策樹,樹迴歸的原理就是決策樹(人家都叫”樹“迴歸了……),在決策樹的學習中,有三種演算法,ID3, C4.5, CART,前兩種演算法只能處理離散型資料,因此只能用於迴歸,而CART演算法由於採用二分法構建樹,可以處理連續性資料,因此也可以用於迴歸,樹迴歸的基本原理,就是CART演算法。

4263204-c0c3d24b66a10db8.png

混亂度衡量

說完演算法原理,我們再來說說對於連續性資料的混亂度度量。我們知道,對於離散型資料,可以使用資訊增益、資訊增益比、基尼指數這些指標來衡量資料的混亂程度,那麼對於連續型資料,怎麼衡量呢?

可以使用總方差來衡量連續性資料的混亂程度:(各資料-資料均值)**2,即方差*樣本數,稱為總方差。

兩種樹迴歸

樹迴歸有兩種方式:迴歸樹和模型樹,其中,

  • 迴歸樹:葉節點是一個值:當前葉子所有樣本標籤均值
  • 模型樹:葉節點是一個線性迴歸模型:當前葉子所有樣本的線性迴歸模型

下面我們分別實現迴歸樹和模型樹。


02 迴歸樹實現

  • 葉節點是一個值:當前葉子所有樣本標籤均值
  • 誤差衡量:總方差,表示一組資料的混亂度,是本組所有資料與這組資料均值之差的平方和

迴歸樹的構建邏輯:二分法,每次選擇一個最佳特徵,並找到最佳切分特徵值(使資料混亂度減少最多的[特徵,特徵值])進行切分,得到左右子樹,然後對左右子樹遞迴呼叫createTree方法,直到沒有最佳特徵為止。(實踐中,在選擇最佳特徵時,進行了預剪枝)

#資料讀取
def loadDataSet(filename):
    dataMat=[]
    fr=open(filename,'r')
    for line in fr.readlines():
        curLine=line.strip().split('\t')
        fltLine=list(map(float,curLine)) #將curLine各元素轉換為float型別
        dataMat.append(fltLine)
    return mat(dataMat)

#二分資料
def binSplitDataSet(dataset,feat,val):
    mat0=dataset[nonzero(dataset[:,feat]>val)[0],:] #陣列過濾選擇特徵大於指定值的資料
    mat1=dataset[nonzero(dataset[:,feat]<=val)[0],:] #陣列過濾選擇特徵小於指定值的資料
    return mat0,mat1

#定義迴歸樹的葉子(該葉子上各樣本標籤的均值)
def regLeaf(dataset):
    return mean(dataset[:,-1])

#定義連續資料的混亂度(總方差,即連續資料的混亂度=(該組各資料-該組資料均值)**2,即方差*樣本數))
def regErr(dataset):
    return var(dataset[:-1])*shape(dataset)[0]

"""最佳特徵以及最佳特徵值選擇函式"""
#leafType為葉節點取值,預設為regleaf,即取樣本標籤均值,對於模型樹,葉節點是一個線性模型
#errType為資料誤差(混亂度)計算方式,預設為regErr,總方差
#ops[0]為以最佳特徵及特徵值切分資料前後,資料混亂度的變化閾值,若小於該閾值,不切分
#ops[1]為切分後兩塊資料的最少樣本數,若少於該值,不切分
#可以預想,迴歸樹形狀對ops[0],ops[1]很敏感,若這兩個值過小,迴歸樹會很臃腫,過擬合
def chooseBestSplit(dataset,leafType=regLeaf,errType=regErr,ops=(1,4)):
    tolS=ops[0];tolN=ops[1];m,n=shape(dataset)
    S=errType(dataset);bestS=inf;beatIndex=0;bestVal=0
    if len(set(dataset[:,-1].T.tolist()[0]))==1: #若只有一個類別
        return None,leafType(dataset)
    for featIndex in range(n-1):
        for splitVal in set(dataset[:,featIndex].T.tolist()[0]):
            mat0,mat1=binSplitDataSet(dataset,featIndex,splitVal)
            #若切分後兩塊資料的最少樣本數少於設定值,不切分
            if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN): 
                continue
            newS=errType(mat0)+errType(mat1)
            if newS<bestS:
                bestIndex=featIndex;bestVal=splitVal;bestS=newS
    #若以最佳特徵及特徵值切分後的資料混亂度與原資料混亂度差值小於閾值,不切分
    if (S-bestS)<tolS:
        return None,leafType(dataset)
    mat0,mat1=binSplitDataSet(dataset,bestIndex,bestVal)
    #若以最佳特徵及特徵值切分後兩塊資料的最少樣本數少於設定值,不切分
    if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):
        return None,leafType(dataset)
    return bestIndex,bestVal

"""構建迴歸樹"""
def createTree(dataset,leafType=regLeaf,errType=regErr,ops=(1,4)):
    feat,val=chooseBestSplit(dataset,leafType,errType,ops)
    if feat==None:
        return val
    regTree={}
    regTree['spFeat']=feat
    regTree['spVal']=val
    lSet,rSet=binSplitDataSet(dataset,feat,val)
    regTree['left']=createTree(lSet,leafType,errType,ops)
    regTree['right']=createTree(rSet,leafType,errType,ops)
    return regTree

好了好了,寫了這麼多程式碼,我們來測試一下,原始資料分佈如下圖,訓練結果如下圖。可以看到,迴歸樹模型將這組資料分到了5個葉節點上,目前看起來還過得去。

4263204-4801cbc13d770297.png
4263204-90362ebe792b6950.png

03 迴歸樹剪枝

當我們設定的最小分離葉節點樣本數、最小混亂度減小值等引數過小,可能產生過擬合,直觀的現象就是,訓練出來非常多的葉子,其實是沒有必要的,此時就需要剪枝了(很形象嘛)。

剪枝分為預剪枝和後剪枝,

  • 預剪枝:在chooseBestSplit函式中的幾個提前終止條件(切分樣本小於閾值、混亂度減弱小於閾值),都是預剪枝(引數敏感)。
  • 後剪枝:使用測試集對訓練出的迴歸樹進行剪枝(由於不需要使用者指定,後剪枝是一種更為理想化的剪枝方法)

後剪枝邏輯:對訓練好的迴歸樹,自上而下找到葉節點,用測試集來判斷將這些葉節點合併是否能降低測試誤差,若能,則合併。

#判斷是否是一棵樹(字典)
def isTree(obj):
    return (type(obj).__name__=='dict')

#得到樹所有葉節點的均值
def getMean(tree):
    #若子樹仍然是樹,則遞迴呼叫getMeant直到葉節點
    if isTree(tree['left']):
        tree['left']=getMean(tree['left'])
    if isTree(tree['right']):
        tree['right']=getMean(tree['right'])
    return (tree['left']+tree['right'])/2.0

"""剪枝函式:對訓練好的迴歸樹,自上而下找到葉節點,用測試集來判斷將這些葉節點合併是否能降低測試誤差,若能則合併"""
def prune(tree,testData):
    #若無測試資料,則直接返回樹所有葉節點的均值(塌陷處理)
    if shape(testData)[0]==0:
        return getMean(tree)
    #若存在任意子集是樹,則將測試集按當前樹的最佳切分特徵和特徵值切分(子集剪枝用)
    if isTree(tree['left']) or isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
    #若存在任意子集是樹,則該子集遞迴呼叫剪枝過程(利用剛才切分好的訓練集)
    if isTree(tree['left']):
        tree['left']=prune(tree['left'],lSet)
    if isTree(tree['right']):
        tree['right']=prune(tree['right'],rSet)
    #若當前子集都是葉節點,則計算該二葉節點合併前後的誤差,決定是否合併
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        errNotMerge=sum(power(lSet[:,-1].T.tolist()[0]-tree['left'],2))+sum(power(rSet[:,-1].T.tolist()[0]-tree['right'],2))
        treeMean=(tree['left']+tree['right'])/2.0
        errMerge=sum(power(testData[:,-1].T.tolist()[0]-treeMean,2))
        if errMerge<errNotMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

測試一下

#構建迴歸樹,可以看到,該回歸樹非常臃腫,過擬合
dataMat3=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\ex2.txt')
regTree3=createTree(dataMat3,ops=(1,2))
testData=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\ex2test.txt')
prune(regTree3,testData)

結果如下,

4263204-385361e75c9e3790.png

可以看到,雖然有6個葉節點被剪掉了,但仍然有很多葉節點保留->後剪枝可能不如預剪枝有效,因此一般為了尋求最佳模型,會同時使用兩種剪枝技術。


04 模型樹實現

  • 葉節點是一個線性迴歸模型:當前葉子所有樣本的線性迴歸模型
  • 誤差衡量:平方誤差類比線性迴歸誤差,用線性模型對資料擬合,計算真實值與擬合值之差,求差值的平方和
  • 比迴歸樹有更好的可解釋性、更高的預測準確度

模型樹構建邏輯:通用函式,演算法邏輯與createTree()一致,只需改變其中的葉節點計算方法leafType()和誤差計算方法errType()

#葉節點計算方法:該葉節點所有樣本的標準線性迴歸模型,演算法與linearRegression()一致
def linearSolve(dataset):
    m,n=shape(dataset)
    X=mat(ones((m,n)));Y=mat(ones((m,1)))
    X[:,1:n]=dataset[:,0:n-1] #X第一列為常數項1
    Y=dataset[:,-1]
    xTx=X.T*X
    if linalg.det(xTx)==0.0:
        raise NameError("矩陣為奇異矩陣,不可逆,嘗試增大ops的第二個引數")
    ws=xTx.I*(X.T*Y)
    return ws,X,Y

def modelLeaf(dataset):
    ws,X,Y=linearSolve(dataset)
    return ws

#誤差計算方法:用線性模型對資料擬合,計算真實值與擬合值之差,求差值的平方和
def modelErr(dataset):
    ws,X,Y=linearSolve(dataset)
    yPred=X*ws
    return sum(power(Y-yPred,2))

測試一下,訓練集如圖藍點所示,訓練模型如圖紅線所示,可以看到,模型樹對資料的預測更準確合理。

dataMat4=loadDataSet(r'D:\DM\python\data\MLiA_SourceCode\machinelearninginaction\Ch09\exp2.txt')
modelTree1=createTree(dataMat4,leafType=modelLeaf,errType=modelErr,ops=(1,10))
4263204-2e70e224b6849dca.png

05 模型樹剪枝

同樣地,模型樹也會出現過擬合,也需要剪枝,原理與迴歸樹剪枝一樣,只需要替換其中的誤差計算方式,然後微調一下剪枝程式碼,讓每次遞迴時,對訓練資料也遞迴切分。

#判斷是否是一棵樹(字典)
def isTree(obj):
    return (type(obj).__name__=='dict')

#剪枝函式:對訓練好的模型樹,自上而下找到葉節點,用測試集來判斷將這些葉節點合併是否能降低測試誤差,若能則合併
def modelPrune(tree,trainData,testData):
    m,n=shape(testData)
    #若無測試資料,則直接返回樹所有葉節點的均值(塌陷處理)
    if m==0:
        return tree
    #若存在任意子集是樹,則將測試集按當前樹的最佳切分特徵和特徵值切分(子集剪枝用)
    #同時將訓練集也按當前樹的最佳切分特徵和特徵值切分(子集剪枝用)
    if isTree(tree['left']) or isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        lTrain,rTrain=binSplitDataSet(trainData,tree['spFeat'],tree['spVal'])
    #若存在任意子集是樹,則該子集遞迴呼叫剪枝過程(利用剛才切分好的訓練集)
    if isTree(tree['left']):
        tree['left']=modelPrune(tree['left'],lTrain,lSet)
    if isTree(tree['right']):
        tree['right']=modelPrune(tree['right'],rTrain,rSet)
        
    #若當前子集都是葉節點,則計算該二葉節點合併前後的誤差,決定是否合併
    """
    模型樹,兩個葉節點合併前的誤差=((左葉子真實值-擬合值)的平方和+(右葉子真實值-擬合值)的平方和)
    模型樹,兩個葉節點合併後的誤差=(左右真實值-左右擬合值)的平方和
    難點在於如何求左右擬合值,即求上層節點的迴歸係數wsMerge:用上層節點的traindata,通過linearSolve(traindata)求得
    上層節點的traindata在lTrain,rTrain的遞迴中已經求好了
    """
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet,rSet=binSplitDataSet(testData,tree['spFeat'],tree['spVal'])
        lSetX=mat(ones((shape(lSet)[0],n)));rSetX=mat(ones((shape(rSet)[0],n)))
        lSetX[:,1:n]=lSet[:,0:n-1];rSetX[:,1:n]=rSet[:,0:n-1]
        errNotMerge=sum(power(array(lSet[:,-1].T.tolist()[0])-lSetX*tree['left'],2))+sum(power(array(rSet[:,-1].T.tolist()[0])-rSetX*tree['right'],2))
        #難點在於求上層節點的迴歸係數wsMerge:用上層節點的traindata,通過linearSolve(traindata)求得
        wsMerge=modelLeaf(trainData)  
        testDataX=mat(ones((m,n)));testDataX[:,1:n]=testData[:,0:n-1]
        errMerge=sum(power(array(testData[:,-1].T.tolist()[0])-testDataX*wsMerge,2))
        if errMerge<errNotMerge:
            print("merging")
            return wsMerge
        else:
            return tree
    else:
        return tree

測試一下,

4263204-f368389f1e4d9980.png

06 模型預測效果對比

本次我們構建了迴歸樹和模型樹,順便構建了一個線性迴歸函式,我們先來看看對於同一組資料,這三個模型的預測效果吧,這裡使用R2值來評估預測效果。

4263204-787da472a5ef5901.png

結果如下,

4263204-2a1c6d7af2b31a11.png
  • 可以看到,這此資料集上,模型樹表現比迴歸樹好,線性迴歸表現最差
  • 說明樹迴歸相比於線性迴歸,可以更好地處理複雜、非線性的資料集

07 總結

至此,我們基本上學習了迴歸任務中80%以上的演算法模型,主要分為線性迴歸模型和樹迴歸模型(再回憶一下,KNN也可以用於迴歸),它們各有優劣,針對具有不同特點的資料集,要選擇合適的演算法。


08 參考

  • 《機器學習實戰》 Peter Harrington Chapter9

相關文章