from math import log
import operator


def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of a data set.

    Each record's last element is treated as its class label.

    Args:
        dataSet: list of records (lists); the final item of each is the label.

    Returns:
        float: entropy in bits; 0.0 for a pure (single-class) data set.
    """
    numEntries = len(dataSet)
    # Tally how often each class label occurs.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / float(numEntries)
        shannonEnt -= prob * log(prob, 2)  # H = -sum(p * log2(p))
    return shannonEnt


def createDataSet():
    """Return the toy fish-classification data set and its feature labels."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Return the records whose feature `axis` equals `value`, with that
    feature column removed.

    The input data set is not modified; each selected record is copied.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # New list: the record minus the splitting column.
            reducedFeatVec = featVec[:axis] + featVec[axis + 1:]
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Returns -1 when no feature yields a positive gain.
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Weighted entropy of the partition induced by feature i.
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # gain = H(S) - H(S|feature i)
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent class label in classList (majority vote)."""
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: list of records; the last element of each is the class label.
        labels: feature names matching the feature columns of dataSet.
            Bug fix: the caller's list is no longer mutated (the original
            code deleted the chosen entry from it in place).

    Returns:
        A nested dict {featureLabel: {featureValue: subtree-or-label}},
        or a bare class label string for a leaf.
    """
    classList = [example[-1] for example in dataSet]
    # Stop condition 1: all records share one class -> pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop condition 2: only the label column remains -> majority-vote leaf.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Reduced label list for the recursion; the caller's `labels` is intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = set(example[bestFeat] for example in dataSet)
    for value in featValues:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


if __name__ == "__main__":
    # Smoke tests reproducing the book's expected output.
    myDat, labels = createDataSet()
    print(myDat)
    print(splitDataSet(myDat, 0, 1))
    print(chooseBestFeatureToSplit(myDat))
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    print(myTree)
Output:
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
0
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
Reference:
《機器學習實戰》 (Peter Harrington, "Machine Learning in Action", Chapter 3: decision trees)