前言
本文詳細介紹機器學習分類演算法中的決策樹演算法,並全面詳解如何構造,表示,儲存決策樹,以及如何使用決策樹進行分類等等問題。
為了全面的理解學習決策樹,本文篇幅較長,請耐心閱讀。
演算法原理
每次依據不同的特徵資訊對資料集進行劃分,劃分的最終結果是一棵樹。
該樹的每個子樹存放一個劃分集,而每個葉節點則表示最終分類結果,這樣一棵樹被稱為決策樹。
決策樹建好之後,帶著目標物件按照一定規則遍歷這個決策樹就能得到最終的分類結果。
該演算法可以分為兩大部分:
1. 構建決策樹部分
2. 使用決策樹分類部分
其中,第一部分是重點難點。
決策樹構造虛擬碼
1 # ============================================== 2 # 輸入: 3 # 資料集 4 # 輸出: 5 # 構造好的決策樹(也即訓練集) 6 # ============================================== 7 def 建立決策樹: 8 '建立決策樹' 9 10 if (資料集中所有樣本分類一致): 11 建立攜帶類標籤的葉子節點 12 else: 13 尋找劃分資料集的最好特徵 14 根據最好特徵劃分資料集 15 for 每個劃分的資料集: 16 建立決策子樹(遞迴方式)
核心問題一:依據什麼劃分資料集
可採用ID3演算法思路:如果以某種特種特徵來劃分資料集,會導致資料集發生最大程度的改變,那麼就使用這種特徵值來劃分。
那麼又該如何衡量資料集的變化程度呢?
可採用熵來進行衡量。這個字讀作shang,第一聲,不要讀成di啊,哈哈!
它用來衡量資訊集的無序程度,其計算公式如下:
其中:
1. x是指分類。要注意決策樹的分類是離散的。
2. P(x)是指任一樣本為該分類的概率
顯然,與原資料集相比,熵差最大的劃分集就是最優劃分集。
對資料集求熵的程式碼如下:
1 # ============================================== 2 # 輸入: 3 # dataSet: 資料集檔名(含路徑) 4 # 輸出: 5 # shannonEnt: 輸入資料集的夏農熵 6 # ============================================== 7 def calcShannonEnt(dataSet): 8 '計算夏農熵' 9 10 # 資料集個數 11 numEntries = len(dataSet) 12 # 標籤集合 13 labelCounts = {} 14 for featVec in dataSet: # 行遍歷資料集 15 # 當前標籤 16 currentLabel = featVec[-1] 17 # 加入標籤集合 18 if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 19 labelCounts[currentLabel] += 1 20 21 # 計算當前資料集的夏農熵並返回 22 shannonEnt = 0.0 23 for key in labelCounts: 24 prob = float(labelCounts[key])/numEntries 25 shannonEnt -= prob * log(prob,2) 26 27 return shannonEnt
可用如下函式建立測試資料集並對其求熵:
1 # ============================================== 2 # 輸入: 3 # 空 4 # 輸出: 5 # dataSet: 測試資料集列表 6 # ============================================== 7 def createDataSet(): 8 '建立測試資料集' 9 10 dataSet = [[1, 1, 'yes'], 11 [1, 1, 'yes'], 12 [1, 0, 'no'], 13 [0, 1, 'no'], 14 [0, 1, 'no']] 15 16 return dataSet 17 18 def test(): 19 '測試' 20 21 # 建立測試資料集 22 myDat = createDataSet() 23 # 求出其熵並列印 24 print calcShannonEnt(myDat)
執行結果如下:
如果我們修改測試資料集的某些資料,讓其看起來顯得混亂點,則得到的熵的值會更大。
還有其他描述集合無序程度的方法,比如說基尼不純度等等,這裡就不再討論了。
核心問題二:如何劃分資料集
這涉及到一些細節上面的問題了,比如:每次劃分是否需要剔除某些欄位?如何對各種劃分所得的熵差進行比較並進行最優劃分等等。
首先是具體劃分函式:
1 # ============================================== 2 # 輸入: 3 # dataSet: 訓練集檔名(含路徑) 4 # axis: 用於劃分的特徵的列數 5 # value: 劃分值 6 # 輸出: 7 # retDataSet: 劃分後的子列表 8 # ============================================== 9 def splitDataSet(dataSet, axis, value): 10 '資料集劃分' 11 12 # 劃分結果 13 retDataSet = [] 14 for featVec in dataSet: # 逐行遍歷資料集 15 if featVec[axis] == value: # 如果目標特徵值等於value 16 # 抽取掉資料集中的目標特徵值列 17 reducedFeatVec = featVec[:axis] 18 reducedFeatVec.extend(featVec[axis+1:]) 19 # 將抽取後的資料加入到劃分結果列表中 20 retDataSet.append(reducedFeatVec) 21 22 return retDataSet
然後是選擇最優劃分函式:
1 # =============================================== 2 # 輸入: 3 # dataSet: 資料集 4 # 輸出: 5 # bestFeature: 和原資料集熵差最大劃分對應的特徵的列號 6 # =============================================== 7 def chooseBestFeatureToSplit(dataSet): 8 '選擇最佳劃分方案' 9 10 # 特徵個數 11 numFeatures = len(dataSet[0]) - 1 12 # 原資料集夏農熵 13 baseEntropy = calcShannonEnt(dataSet) 14 # 暫存最大熵增量 15 bestInfoGain = 0.0; 16 # 和原資料集熵差最大的劃分對應的特徵的列號 17 bestFeature = -1 18 19 for i in range(numFeatures): # 逐列遍歷資料集 20 # 獲取該列所有特徵值 21 featList = [example[i] for example in dataSet] 22 # 將特徵值列featList的值唯一化並儲存到集合uniqueVals 23 uniqueVals = set(featList) 24 25 # 新劃分法夏農熵 26 newEntropy = 0.0 27 # 計算該特徵劃分下所有劃分子集的夏農熵,併疊加。 28 for value in uniqueVals: # 遍歷該特徵列所有特徵值 29 subDataSet = splitDataSet(dataSet, i, value) 30 prob = len(subDataSet)/float(len(dataSet)) 31 newEntropy += prob * calcShannonEnt(subDataSet) 32 33 # 儲存所有劃分法中,和原資料集熵差最大劃分對應的特徵的列號。 34 infoGain = baseEntropy - newEntropy 35 if (infoGain > bestInfoGain): 36 bestInfoGain = infoGain 37 bestFeature = i 38 39 return bestFeature
得到的結果是0:
而上面的程式碼也看到,測試資料集為:
1 dataSet = [[1, 1, 'yes'], 2 [1, 1, 'yes'], 3 [1, 0, 'no'], 4 [0, 1, 'no'], 5 [0, 1, 'no']]
顯然,按照第0列特徵劃分會更加合理,區分度更大。
核心問題三:如何具體實現樹結構
通過對前面兩個問題的分析,劃分資料集這一塊已經清楚明瞭了。
那麼如何用這些多層次的劃分子集搭建出一個樹結構呢?這部分更多涉及到程式設計技巧,某種程度上來說,就是用Python實現樹的問題。
在Python中,可以用字典來具體實現樹:字典的鍵存放節點資訊,值存放分支及子樹/葉子節點資訊。
比如說對於下面這個樹,用Python的字典表述就是:{'no surfacing' : {0, 'no', 1 : {'flippers' : {0 : 'no', 1 : 'yes'}}}}
如下構建樹部分程式碼。該函式呼叫後將形成決策樹:
1 # =============================================== 2 # 輸入: 3 # classList: 類標籤集 4 # 輸出: 5 # sortedClassCount[0][0]: 出現次數最多的標籤 6 # =============================================== 7 def majorityCnt(classList): 8 '採用多數表決的方式求出classList中出現次數最多的類標籤' 9 10 classCount={} 11 for vote in classList: 12 if vote not in classCount.keys(): classCount[vote] = 0 13 classCount[vote] += 1 14 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 15 16 return sortedClassCount[0][0] 17 18 # =============================================== 19 # 輸入: 20 # dataSet: 資料集 21 # labels: 劃分標籤集 22 # 輸出: 23 # myTree: 生成的決策樹 24 # =============================================== 25 def createTree(dataSet,labels): 26 '建立決策樹' 27 28 # 獲得類標籤列表 29 classList = [example[-1] for example in dataSet] 30 31 # 遞迴終止條件一:如果資料集內所有分類一致 32 if classList.count(classList[0]) == len(classList): 33 return classList[0] 34 35 # 遞迴終止條件二:如果所有特徵都劃分完畢 36 if len(dataSet[0]) == 1: 37 # 將它們都歸為一類並返回 38 return majorityCnt(classList) 39 40 # 選擇最佳劃分特徵 41 bestFeat = chooseBestFeatureToSplit(dataSet) 42 # 最佳劃分對應的劃分標籤。注意不是分類標籤 43 bestFeatLabel = labels[bestFeat] 44 # 構建字典空樹 45 myTree = {bestFeatLabel:{}} 46 # 從劃分標籤列表中刪掉劃分後的元素 47 del(labels[bestFeat]) 48 # 獲取最佳劃分對應特徵的所有特徵值 49 featValues = [example[bestFeat] for example in dataSet] 50 # 對特徵值列表featValues唯一化,結果存於uniqueVals。 51 uniqueVals = set(featValues) 52 53 for value in uniqueVals: # 逐行遍歷特徵值集合 54 # 儲存所有劃分標籤資訊並將其夥同劃分後的資料集傳遞進下一次遞迴 55 subLabels = labels[:] 56 myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels) 57 58 return myTree
如下程式碼可用於測試函式是否正常執行:
1 # ============================================== 2 # 輸入: 3 # 空 4 # 輸出: 5 # 用於測試的資料集和劃分標籤集 6 # ============================================== 7 def createDataSet(): 8 '建立測試資料集' 9 10 dataSet = [[1, 1, 'yes'], 11 [1, 1, 'yes'], 12 [1, 0, 'no'], 13 [0, 1, 'no'], 14 [0, 1, 'no']] 15 labels = ['no surfacing', 'flippers'] 16 17 return dataSet, labels 18 19 def test(): 20 '測試' 21 22 myDat, labels = createDataSet() 23 myTree = createTree(myDat, labels) 24 print myTree
執行結果:
使用Matplotlib繪製樹形圖
當決策樹構建好了以後,自然需要用一種方式來顯示給開發人員。僅僅是一個字典表示式很難讓人滿意。
因此,可採用Matplotlib來繪製樹形圖。
這涉及到兩方面的知識:
1. 遍歷樹獲取樹的高度,葉子數等資訊。
2. Matplotlib繪製影象的一些API
對於第一部分的任務,可以用遞迴的方式遍歷字典樹,從而獲得樹的相關資訊。
下面給出求樹的葉子樹及樹高的函式:
1 # =============================================== 2 # 輸入: 3 # myTree: 決策樹 4 # 輸出: 5 # numLeafs: 決策樹的葉子數 6 # =============================================== 7 def getNumLeafs(myTree): 8 '計算決策樹的葉子數' 9 10 # 葉子數 11 numLeafs = 0 12 # 節點資訊 13 firstStr = myTree.keys()[0] 14 # 分支資訊 15 secondDict = myTree[firstStr] 16 17 for key in secondDict.keys(): # 遍歷所有分支 18 # 子樹分支則遞迴計算 19 if type(secondDict[key]).__name__=='dict': 20 numLeafs += getNumLeafs(secondDict[key]) 21 # 葉子分支則葉子數+1 22 else: numLeafs +=1 23 24 return numLeafs 25 26 # =============================================== 27 # 輸入: 28 # myTree: 決策樹 29 # 輸出: 30 # maxDepth: 決策樹的深度 31 # =============================================== 32 def getTreeDepth(myTree): 33 '計算決策樹的深度' 34 35 # 最大深度 36 maxDepth = 0 37 # 節點資訊 38 firstStr = myTree.keys()[0] 39 # 分支資訊 40 secondDict = myTree[firstStr] 41 42 for key in secondDict.keys(): # 遍歷所有分支 43 # 子樹分支則遞迴計算 44 if type(secondDict[key]).__name__=='dict': 45 thisDepth = 1 + getTreeDepth(secondDict[key]) 46 # 葉子分支則葉子數+1 47 else: thisDepth = 1 48 49 # 更新最大深度 50 if thisDepth > maxDepth: maxDepth = thisDepth 51 52 return maxDepth
對於第二部分的任務 - 畫樹,其實本質就是畫點和畫線,下面給出基本的線畫法:
1 import matplotlib.pyplot as plt 2 3 decisionNode = dict(boxstyle="sawtooth", fc="0.8") 4 leafNode = dict(boxstyle="round4", fc="0.8") 5 arrow_args = dict(arrowstyle="<-") 6 7 # ================================================== 8 # 輸入: 9 # nodeTxt: 終端節點顯示內容 10 # centerPt: 終端節點座標 11 # parentPt: 起始節點座標 12 # nodeType: 終端節點樣式 13 # 輸出: 14 # 在圖形介面中顯示輸入引數指定樣式的線段(終端帶節點) 15 # ================================================== 16 def plotNode(nodeTxt, centerPt, parentPt, nodeType): 17 '畫線(末端帶一個點)' 18 19 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) 20 21 def createPlot(): 22 '繪製有向線段(末端帶一個節點)並顯示' 23 24 # 新建一個圖物件並清空 25 fig = plt.figure(1, facecolor='white') 26 fig.clf() 27 # 設定1行1列個圖區域,並選擇其中的第1個區域展示資料。 28 createPlot.ax1 = plt.subplot(111, frameon=False) 29 30 # 畫線(末端帶一個節點) 31 plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode) 32 plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode) 33 34 # 顯示繪製結果 35 plt.show()
呼叫 createPlot 函式即可顯示繪製結果:
下面,將這兩部分內容整合起來,寫出最終繪製樹的程式碼:
1 import matplotlib.pyplot as plt 2 3 decisionNode = dict(boxstyle="sawtooth", fc="0.8") 4 leafNode = dict(boxstyle="round4", fc="0.8") 5 arrow_args = dict(arrowstyle="<-") 6 7 # ================================================== 8 # 輸入: 9 # nodeTxt: 終端節點顯示內容 10 # centerPt: 終端節點座標 11 # parentPt: 起始節點座標 12 # nodeType: 終端節點樣式 13 # 輸出: 14 # 在圖形介面中顯示輸入引數指定樣式的線段(終端帶節點) 15 # ================================================== 16 def plotNode(nodeTxt, centerPt, parentPt, nodeType): 17 '畫線(末端帶一個點)' 18 19 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) 20 21 # ================================================================= 22 # 輸入: 23 # cntrPt: 終端節點座標 24 # parentPt: 起始節點座標 25 # txtString: 待顯示文字內容 26 # 輸出: 27 # 在圖形介面指定位置(cntrPt和parentPt中間)顯示文字內容(txtString) 28 # ================================================================= 29 def plotMidText(cntrPt, parentPt, txtString): 30 '在指定位置新增文字' 31 32 # 中間位置座標 33 xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] 34 yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] 35 36 createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) 37 38 # =================================== 39 # 輸入: 40 # myTree: 決策樹 41 # parentPt: 根節點座標 42 # nodeTxt: 根節點座標資訊 43 # 輸出: 44 # 在圖形介面繪製決策樹 45 # =================================== 46 def plotTree(myTree, parentPt, nodeTxt): 47 '繪製決策樹' 48 49 # 當前樹的葉子數 50 numLeafs = getNumLeafs(myTree) 51 # 當前樹的節點資訊 52 firstStr = myTree.keys()[0] 53 # 定位第一棵子樹的位置(這是蛋疼的一部分) 54 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) 55 56 # 繪製當前節點到子樹節點(含子樹節點)的資訊 57 plotMidText(cntrPt, parentPt, nodeTxt) 58 plotNode(firstStr, cntrPt, parentPt, decisionNode) 59 60 # 獲取子樹資訊 61 secondDict = myTree[firstStr] 62 # 開始繪製子樹,縱座標-1。 63 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD 64 65 for key in secondDict.keys(): # 遍歷所有分支 66 # 子樹分支則遞迴 67 if type(secondDict[key]).__name__=='dict': 68 plotTree(secondDict[key],cntrPt,str(key)) 69 # 葉子分支則直接繪製 70 else: 71 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW 72 plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) 73 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) 74 75 # 子樹繪製完畢,縱座標+1。 76 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD 77 78 # ============================== 79 # 輸入: 80 # myTree: 決策樹 81 # 輸出: 82 # 在圖形介面顯示決策樹 83 # ============================== 84 def createPlot(inTree): 85 '顯示決策樹' 86 87 # 建立新的影象並清空 - 無橫縱座標 88 fig = plt.figure(1, facecolor='white') 89 fig.clf() 90 axprops = dict(xticks=[], yticks=[]) 91 createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 92 93 # 樹的總寬度 高度 94 plotTree.totalW = float(getNumLeafs(inTree)) 95 plotTree.totalD = float(getTreeDepth(inTree)) 96 97 # 當前繪製節點的座標 98 plotTree.xOff = -0.5/plotTree.totalW; 99 plotTree.yOff = 1.0; 100 101 # 繪製決策樹 102 plotTree(inTree, (0.5,1.0), '') 103 104 plt.show() 105 106 def test(): 107 '測試' 108 109 myDat, labels = createDataSet() 110 myTree = createTree(myDat, labels) 111 createPlot(myTree)
執行結果如下圖:
關於決策樹的儲存
這部分也很重要。
生成一個決策樹比較耗時間,誰也不想每次啟動程式都重新進行機器學習吧。那麼能否將學習結果 - 決策樹儲存到硬碟中去呢?
答案是肯定的,以下兩個函式分別實現了決策樹的儲存與開啟:
1 # ====================== 2 # 輸入: 3 # myTree: 決策樹 4 # 輸出: 5 # 決策樹檔案 6 # ====================== 7 def storeTree(inputTree,filename): 8 '儲存決策樹' 9 10 import pickle 11 fw = open(filename,'w') 12 pickle.dump(inputTree,fw) 13 fw.close() 14 15 # ======================== 16 # 輸入: 17 # filename: 決策樹檔名 18 # 輸出: 19 # pickle.load(fr): 決策樹 20 # ======================== 21 def grabTree(filename): 22 '開啟決策樹' 23 24 import pickle 25 fr = open(filename) 26 return pickle.load(fr)
使用決策樹進行分類
終於到了這一步,也是最終一步了。
拿到需要分類的資料後,遍歷決策樹直至葉子節點,即可得到分類結果,是不是很簡單呢?
下面給出遍歷及測試程式碼:
1 # ======================== 2 # 輸入: 3 # inputTree: 決策樹檔名 4 # featLabels: 分類標籤集 5 # testVec: 待分類物件 6 # 輸出: 7 # classLabel: 分類結果 8 # ======================== 9 def classify(inputTree,featLabels,testVec): 10 '使用決策樹分類' 11 12 # 當前分類標籤 13 firstStr = inputTree.keys()[0] 14 secondDict = inputTree[firstStr] 15 # 找到當前分類標籤在分類標籤集中的下標 16 featIndex = featLabels.index(firstStr) 17 # 獲取待分類物件中當前分類的特徵值 18 key = testVec[featIndex] 19 20 # 遍歷 21 valueOfFeat = secondDict[key] 22 23 # 子樹分支則遞迴 24 if isinstance(valueOfFeat, dict): 25 classLabel = classify(valueOfFeat, featLabels, testVec) 26 # 葉子分支則返回結果 27 else: classLabel = valueOfFeat 28 29 return classLabel 30 31 def test(): 32 '測試' 33 34 myDat, labels = createDataSet() 35 myTree = createTree(myDat, labels) 36 # 再建立一次資料的原因是建立決策樹函式會將labels值改動 37 myDat, labels = createDataSet() 38 print classify(myTree, labels, [1,1])
執行結果如下:
OK,一個完整的決策樹使用例子就實現了。
小結
1. 本文演示的是最經典ID3決策樹,但它在實際應用中存在過度匹配的問題。在以後的文章中會學習如何對決策樹進行裁剪。
2. 本文采用的ID3決策樹演算法只能用於標稱型資料。對於數值型資料,需要使用Cart決策樹構造演算法。這個演算法將在以後進行深入學習。