《機器學習實戰》程式清單3-4 建立樹的函式程式碼

王明輝發表於2018-02-05

有點亂,等我徹底想明白時再來整理清楚。

 

from math import log
import operator

def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of a data set.

    Args:
        dataSet: list of feature vectors; the last element of each
            vector is the class label.

    Returns:
        float: entropy in bits, ``-sum(p * log2(p))`` over the class
        labels. 0.0 when every sample shares the same label (and for
        an empty data set, matching the original behaviour).
    """
    numEntries = len(dataSet)

    # Count how many samples carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # class label sits in the last column
        # dict.get replaces the explicit "key in dict" check
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1

    # Accumulate -p * log2(p) for every observed label; more distinct,
    # evenly spread labels -> higher entropy (more "disorder").
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries  # probability of this class label
        shannonEnt -= prob * log(prob, 2)

    return shannonEnt

def createDataSet():
    """Build the toy fish-classification data set from the book.

    Returns:
        tuple: ``(dataSet, labels)`` where each dataSet row is
        ``[no surfacing, flippers, class]`` (two binary features plus a
        'yes'/'no' label) and ``labels`` names the two feature columns.
    """
    # Experimenting with the rows changes the entropy: adding rows that
    # conflict (same features, different labels) or adding many distinct
    # labels raises it, while giving every row the same label drives it
    # to 0. The five rows below are the book's original sample.
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]

    labels = ['no surfacing', 'flippers']

    return dataSet, labels

def testCalcShannonEnt():
    """Smoke test: build the sample data (entropy print left disabled)."""
    sampleData, featureNames = createDataSet()
    #print(calcShannonEnt(sampleData))

def splitDataSet(dataSet, axis, value):
    """Select the rows whose feature at index ``axis`` equals ``value``.

    Args:
        dataSet: list of feature vectors.
        axis: index of the feature to filter on.
        value: required value of that feature.

    Returns:
        list: matching rows with the ``axis`` column removed, so the
        split-on feature does not reappear in deeper splits.
    """
    # Slice concatenation drops exactly the axis column from each match.
    return [row[:axis] + row[axis + 1:]
            for row in dataSet
            if row[axis] == value]

def testSplitDataSet():
    """Smoke test: split the sample data on feature 0 == 0."""
    sampleData, featureNames = createDataSet()
    #print(sampleData)
    subset = splitDataSet(sampleData, 0, 0)
    #print(subset)


def chooseBestFeatureToSplit(dataSet):
    """Pick the feature whose split yields the highest information gain.

    Args:
        dataSet: list of feature vectors; the last column is the class
            label, all earlier columns are candidate features.

    Returns:
        int: index of the best feature to split on, or -1 when no
        split produces a positive information gain.
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the label, not a feature
    baseEntropy = calcShannonEnt(dataSet)

    bestInfoGain = 0.0
    bestFeature = -1

    for i in range(numFeatures):
        # A set is the quickest way to collect the distinct values taken
        # by feature i across the data set, e.g. [1, 0, 1, 1] -> {0, 1}.
        uniqueVals = {example[i] for example in dataSet}

        # Expected entropy after splitting on feature i: each subset's
        # entropy weighted by the fraction of samples it holds.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)

        # Information gain = drop in entropy; the smaller the weighted
        # post-split entropy, the larger the gain. Keep the maximum.
        infoGain = baseEntropy - newEntropy

        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature
            
    
def testChooseBestFeatureToSplit():
    """Smoke test: run the feature chooser on the sample data."""
    sampleData, featureNames = createDataSet()
    chooseBestFeatureToSplit(sampleData)

def majorityCnt(classList):
    """Majority vote: return the most frequent value in ``classList``.

    Args:
        classList: list of class labels (any hashable values).

    Returns:
        The label that occurs most often; on a tie, whichever tied
        label sorts first after the descending count sort.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1

    # itemgetter(1) keys the sort on the count (the dict value) rather
    # than the label; reverse=True puts the largest count first.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]
def testMajorityCnt():
    """Exercise majorityCnt on a list where 'c' appears most often."""
    votes = ['a', 'b', 'a', 'a', 'b', 'c', 'd', 'd', 'd', 'e',
             'a', 'a', 'a', 'a', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c']
    print(majorityCnt(votes))

# Recursion-depth counter used by createTree's trace output.
# (A `global` statement at module scope is a no-op, so a plain
# assignment is all that is needed here.)
n = 0

def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree (book listing 3-4).

    Args:
        dataSet: list of feature vectors; the last column is the class label.
        labels: feature names aligned with the feature columns.
            NOTE(review): this list is mutated (``del labels[bestFeat]``
            below), so a caller that reuses it afterwards sees it shortened.

    Returns:
        A class label (leaf node) or a nested dict of the form
        ``{featureName: {featureValue: subtree, ...}}``.

    Side effects: prints a verbose trace of every recursive call and
    increments the module-level counter ``n``.
    """

    global n
    print("=================createTree"+str(n)+" begin=============")
    n += 1
    print(n)

    # Column of class labels for the current subset of samples.
    classList = [example[-1] for example in dataSet]

    print("" + str(n) + "次classList:" + str(classList))
    print("此時列表中的第1個元素為" + str(classList[0]) + ",數量為:" + str(classList.count(classList[0])) + ",列表總長度為:" + str(len(classList)))
    
    print("列表中"+str(classList[0])+"的數量:",classList.count(classList[0]))
    print("列表的長度:", len(classList))

    # Trace-only check mirroring the stopping condition below.
    if classList.count(classList[0])== len(classList):
        print("判斷結果為:所有類別相同,停止本組劃分")
    else:
        print("判斷結果為:類別不相同")

     # Base case 1: every sample carries the same label -> return it as a leaf.
    if classList.count(classList[0]) == len(classList):
         return classList[0]

    print("dataSet[0]:" + str(dataSet[0]))

    # Base case 2: only the label column remains (all features used up)
    # -> fall back to a majority vote among the remaining labels.
    if len(dataSet[0]) == 1:
        print("啟動多數表決")  # not triggered by the book's sample data set
        return majorityCnt(classList)

    # Split on the feature with the highest information gain.
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    print("bestFeat:" +str(bestFeat))
    print("bestFeatLabel:" + str(bestFeatLabel))

    myTree = {bestFeatLabel:{}}
    print("當前樹狀態:" + str(myTree))

    print("當前標籤集合:" + str(labels))
    print("準備刪除" + labels[bestFeat])
    # Remove the consumed feature's name so it lines up with the
    # shrunken rows produced by splitDataSet (mutates the caller's list).
    del(labels[bestFeat])
    print("已刪除")
    print("刪除元素後的標籤集合:" + str(labels))

    # All values the chosen feature takes in this subset.
    featValues = [example[bestFeat] for example in dataSet]
    print("featValues:",featValues)

    uniqueVals = set(featValues)
    print("uniqueVals:", uniqueVals)  # e.g. {0, 1} for a binary feature

    k = 0
    print("********開始迴圈******")
    for value in uniqueVals:

        k += 1
        print("",k,"次迴圈")
        # Copy so deletions inside deeper recursion can't clobber this level.
        subLabels = labels[:]
        print("傳入引數:")
        print("        --待劃分的資料集:",dataSet)
        print("        --劃分資料集的特徵:", bestFeat)
        print("        --需要返回的符合特徵值:", value)
        splited = splitDataSet(dataSet, bestFeat, value)
        print("splited:", str(splited))
        myTree[bestFeatLabel][value] = createTree(splited, subLabels)  # recursive call
    print("*******結束迴圈*****")

    print("=================createTree"+str(n)+" end=============")
    return myTree
     
def testCreateTree():
    """Build a decision tree from the sample data and print it."""
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    print("============testCreateTree=============")
    print(myTree)

if __name__ == '__main__':
    # Print the entropy of the sample data set
    #testCalcShannonEnt()

    # Test splitting the data set on a feature value
    #testSplitDataSet()

    # Choose the best feature to split on
    #testChooseBestFeatureToSplit()
 
    #testMajorityCnt()

    testCreateTree()


    

 

相關文章