決策樹ID3演算法python實現 -- 《機器學習實戰》

劉川楓發表於2017-11-13
 1 from math import log
 2 import numpy as np
 3 import matplotlib.pyplot as plt
 4 import operator
 5 
 6 #計算給定資料集的夏農熵
 7 def calcShannonEnt(dataSet):
 8     numEntries = len(dataSet)
 9     labelCounts = {}
10     for featVec in dataSet:                         #|
11         currentLabel = featVec[-1]                  #|
12         if currentLabel not in labelCounts.keys():  #|獲取標籤類別取值空間(key)及出現的次數(value)
13             labelCounts[currentLabel] = 0           #|
14         labelCounts[currentLabel] += 1              #|
15     shannonEnt = 0.0
16     for key in labelCounts:                         #|
17         prob = float(labelCounts[key])/numEntries   #|計算夏農熵
18         shannonEnt -= prob * log(prob, 2)           #|
19     return shannonEnt
20 
21 #建立資料集
22 def createDataSet():
23     dataSet = [[1,1,'yes'],
24                [1,1,'yes'],
25                [1,0,'no'],
26                [0,1,'no'],
27                [0,1,'no']]
28     labels = ['no surfacing', 'flippers']
29     return dataSet, labels
30 
31 #按照給定特徵劃分資料集
32 def splitDataSet(dataSet, axis, value):
33     retDataSet = []
34     for featVec in dataSet:                         #|
35         if featVec[axis] == value:                  #|
36             reducedFeatVec = featVec[:axis]         #|抽取出符合特徵的資料
37             reducedFeatVec.extend(featVec[axis+1:]) #|
38             retDataSet.append(reducedFeatVec)       #|
39     return retDataSet
40 
41 #選擇最好的資料集劃分方式
42 def chooseBestFeatureToSplit(dataSet):
43     numFeatures = len(dataSet[0]) - 1
44     basicEntropy = calcShannonEnt(dataSet)
45     bestInfoGain = 0.0; bestFeature = -1
46     for i in range(numFeatures):        #計算每一個特徵的熵增益
47         featlist = [example[i] for example in dataSet]
48         uniqueVals = set(featlist)
49         newEntropy = 0.0
50         for value in uniqueVals:        #計算每一個特徵的不同取值的熵增益
51             subDataSet = splitDataSet(dataSet, i, value)
52             prob = len(subDataSet)/float(len(dataSet))
53             newEntropy += prob * calcShannonEnt(subDataSet) #不同取值的熵增加起來就是整個特徵的熵增益
54         infoGain = basicEntropy - newEntropy
55         if (infoGain > bestInfoGain):   #選擇最高的熵增益作為劃分方式
56             bestInfoGain = infoGain
57             bestFeature = i
58     return bestFeature
59 #挑選出現次數最多的類別
60 def majorityCnt(classList):
61     classCount={}
62     for vote in classList:
63         if vote not in classCount.keys():
64             classCount[vote] = 0
65         classCount[vote] += 1
66     sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse=True)
67     return sortedClassCount[0][0]
68 
69 def createTree(dataSet, labels):
70     classList = [example[-1] for example in dataSet]
71     if classList.count(classList[0]) == len(classList): #停止條件一:判斷所有類別標籤是否相同,完全相同則停止繼續劃分
72         return classList[0]
73     if len(dataSet[0]) == 1:    #停止條件二:遍歷完所有特徵時返回出現次數最多的
74         return majorityCnt(classList)
75     bestFeat = chooseBestFeatureToSplit(dataSet)    #得到列表包含的所有屬性值
76     bestFeatLabel = labels[bestFeat]
77     myTree = {bestFeatLabel:{}}
78     del(labels[bestFeat])
79     featValues = [example[bestFeat] for example in dataSet]
80     uniqueVals = set(featValues)
81     for value in uniqueVals:
82         subLabels = labels[:]
83         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
84     return myTree
85 
86 # Simple unit test of func: createDataSet()
87 myDat, labels = createDataSet()
88 print (myDat)
89 #print (labels)
90 # Simple unit test of func: splitDataSet()
91 splitData = splitDataSet(myDat,0,1)
92 print (splitData)
93 # Simple unit test of func: chooseBestFeatureToSplit()
94 chooseResult = chooseBestFeatureToSplit(myDat)
95 print (chooseResult)
96 # Simple unit test of func: createTree(
97 myDat, labels = createDataSet()
98 myTree = createTree(myDat, labels)
99 print(myTree)

Output:

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
0
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

 

Reference:

《機器學習實戰》

相關文章