《機器學習實踐》程式清單3-7 plotTree函式

王明輝發表於2018-02-09

這個plotTree函式,比較聰明,比較簡化,比較抽象,作者一定是逐步優化和簡化到這個程度的。我是花了小兩天時間,斷斷續續看明白的,還是在參考了另一篇文章以後。這裡是連結http://www.cnblogs.com/fantasy01/p/4595902.html。現在嘗試講明白。

總體思想是,找出來需要畫圖形的座標,用函式畫圖。圖形一共有三類,一類是父節點,一類是線條,一類是葉子結點。其中“畫圖”這個動作不難,用matplotlib中的畫圖功能,非常簡單。難的是計算座標。就像那個著名的斯坦門茨的故事,畫線1美元,知道在哪裡畫線,9999美元。在這裡,matplotlib中的函式就是那粉筆,而我們要知道的是在哪裡畫線。

這裡作者有個大前提,就是“居中”,所有的計算都是圍繞著這個前提來進行的。每一步計算都是為了居中於節點的所有葉子節點,比如某個節點A有6個葉結點,那麼這個節點A就位於這6個節點的正中間。

下面這個函式容易理解,在指定座標處新增文字。如果父節點座標已知,子節點座標已知,找到中間的位置不難。

#在父子節點間填充文字資訊
def plotMidText(cntrPt, parentPt, txtString):
    #書上原式是這樣寫的,但是計算之後其實就是求中點的公式(parentPt[0] +  cntrPt[0]) / 2.0 
   #書上體現的是中點所在座標的真正意義,用原點遠端點的x座標減掉近端點的x座標,得到差值,除以2,就是中點距離兩點的絕對距離,再加上近端點的x座標,就是中點距離原點的距離,
  #即中點的x座標
  #y座標同理
#xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0] xMid = (parentPt[0] + cntrPt[0]) / 2.0 #yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1] yMid = (parentPt[1] + cntrPt[1]) / 2.0 #在(xMid,yMid)座標處增加文字 createPlot.ax1.text(xMid, yMid, txtString)

下面就是比較難理解的plotTree部分

def plotTree(myTree, parentPt, nodeTxt):

    numLeafs = getNumLeafs(myTree) #遞迴取葉結點數
    depth = getTreeDepth(myTree) #遞迴取樹的深度(層數)
    
    print("葉子數:", numLeafs)
    print("層數:", depth)

    print("xOff:", plotTree.xOff)
    #這一步的結果是一個座標,(0.5,1.0),子節點的所在位置,為什麼要這樣計算?
    #這一步跳過了中間的很多步驟,此式是大量過程化簡的結果
    cntrPt =  (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    print("cntrPt",cntrPt)

    #在父節點和子節點確定之後,在父子之間做文字標記,即nodeTxt
  #第一層節點的父座標與節點座標相同,其實畫了一個長度為0的線,nodeText是空,如果想試驗,可以在下面的createTree函式裡設定
  #plotTree(inTree, parentPt, '中華人民共和國中華人民共和國')
  #它就原形畢露了
plotMidText(cntrPt, parentPt, nodeTxt) firstStr = list(myTree.keys())[0] #每層樹的首節點名稱 plotNode(firstStr , cntrPt, parentPt, decisionNode) #plotNode(firstStr + "[" + str(round(cntrPt[0],2)) + "," + str(round(cntrPt[1],2)) + "]", cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD for key in secondDict.keys(): print("secondDict.keys()中的key:", list(secondDict.keys())[key]) if (type(secondDict[key])).__name__ == 'dict': #字典的值是也是字典(樹),繼續遞迴 plotTree(secondDict[key], cntrPt, str(key)) else: #如果字典的值不是字典(是葉子),則直接輸出葉子 # plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW #plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) a = (plotTree.xOff, plotTree.yOff) plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

 

下面是主程式

def createPlot(inTree):
    fig = plt.figure(1, facecolor = 'white') #這句可以不寫,如果不寫,會預設建立一個
    #fig = plt.figure(1, facecolor = 'red') #此句會啟用1號figure,facecolor白色
    #fig = plt.figure(2, facecolor = 'red') #此句會重新建立一個facecolor為紅色的figure
    
    fig = clf() #清空圖形區(plot區、工作區),可能是clear figure的縮寫

    axprops = dict(xticks=[], yticks=[]) #此參數列示座標刻度,[]表示不顯示刻度,可以作為引數傳入,也可以用plt.xticks([1,3,4])單獨設定

    createPlot.ax1 = plt.subplot(111, frameon = True)#, **axprops) # **表示此引數是字典引數
    #plt.xticks([1,3,4],"a,b") #單獨設定刻度
    #print(axprops)

    #================================================================================================

    plotTree.totalW = float(getNumLeafs(inTree)) #全域性變數plotTree.totalW用於儲存樹的寬度,葉子數
    print("總葉子數(寬度):", plotTree.totalW)
    plotTree.totalD = float(getTreeDepth(inTree)) #全域性變數plotTree.totalD用於儲存樹的深度
    print("總層數:", plotTree.totalD)

    #追蹤已經繪製的節點位置,x軸上的偏移量。這只是用於方便計算的一個偏移量,沒有實際意義,設定這樣一個值以後,後面的只需要加上葉節點的個數就可以了。
    #如果0.5不太容易理解,(1/2)*(1/plotTree.totalW),也就是把x軸分為plotTree.totalW份後,其中的1份的一半。
    plotTree.xOff = -0.5 / plotTree.totalW; 
    #追蹤已經繪製的節點位置,y軸上的偏移量
    plotTree.yOff = 1.0

    parentPt = (0.5,1.0) #頂層節點的座標

    plotTree(inTree, parentPt, '')
    #plt.axis([0,10,0,10])
    plt.show()

下面是sublime中的呼叫程式碼

def testCreatePlot():

    inTree = retrieveTree(0)
    createPlot(inTree)

 

上面程式碼中,最核心的是座標的計算過程。圖形在一個x軸和y軸的長度各為1的一個座標系中繪製。首先計算出葉子節點的數量(為什麼要計算這個數量?是因為葉節點需要展開,它們所需要的總寬度是最大的),因為x軸的長度是1,所以用1去除以葉節點的數量,得到每個葉節點所需要的長度,如果x軸的總長度是10,那就用10去除以葉節點的數量,總之這步是在求每個葉子在x軸上所需要的長度。求解思路如下(參考上面所提到的文章):

1、其中方形為非葉子節點的位置,@是葉子節點的位置,因此每份(即上圖的一個單元格)的長度應該為1/plotTree.totalW,但是葉子節點的位置應該為@所在位置,則在開始的時候plotTree.xOff的賦值為-0.5/plotTree.totalW,即意為開始x位置為第一個表格左邊的半個表格距離位置,這樣作的好處是:在以後確定@位置時候可以直接加整數倍的1/plotTree.totalW。

這一步一定是經過了作者的逐步優化才得到的。如果不這樣做,那麼每次取@所在的座標時,都需要減掉左側第一個@左邊至原點這半個格, 所以作者設定了一個偏移量,以後只需要直接加1個完整的份數,即1/plotTree.totalW,就是下一個葉節點的x座標,聰明。

2、對於本演算法的核心,plotTree函式中的紅色部分即如下:

cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

這一步的cntrPt求的是節點所在座標(x,y)。plotTree.xOff 即為最近繪製的一個葉子節點的x座標,在確定當前節點位置時每次只需確定當前節點有幾個葉子節點,因此其葉子節點所佔的總距離就確定了即為float(numLeafs)*(1/plotTree.totalW)(因為總長度為1,如果是總長度是10就用10作分子),比如有4個葉節點,總共有6份,那麼所佔距離就是4*(1/6),因此當前節點的位置即為其所有葉子節點所佔距離的中間,即一半,(float(numLeafs)/2.0)(1/plotTree.totalW),但是由於開始plotTree.xOff賦值並非從0開始,而是左移了半個單元格,因此還需加回來半個單元格距離,即(1/2)(1/plotTree.totalW),計算結果就是(1.0 + float(numLeafs))/2.0/plotTree.totalW*1,因此偏移量確定,則x位置變為plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW。

3、對於plotTree函式引數賦值為(0.5, 1.0)的解釋

因為開始的根節點並不用劃線,因此父節點和當前節點的位置需要重合,利用2中的確定當前節點的位置便為(0.5, 1.0)

總結:利用這樣的逐漸增加x的座標,以及逐漸降低y的座標能能夠很好的將樹的葉子節點數和深度考慮進去,因此圖的邏輯比例就很好的確定了,這樣不用去關心輸出圖形的大小,一旦圖形發生變化,函式會重新繪製,但是假如利用畫素為單位來繪製圖形,這樣縮放圖形就比較有難度了

 

相關文章