Machine Learning in Action (3) Decision Tree ID3: Tree Construction and Simple Classification

Posted by StayFoolishAndHappy on 2018-05-17
Note: This post mainly follows 機器學習實戰之決策樹, which walks through the code with detailed comments and is a good resource for studying Machine Learning in Action; recommended. For the sake of programming practice and to organize the machine learning algorithms in my head, I rewrote the algorithm in my own coding style, which also helped me think the material through. To make it quick to review later, I captured what I consider the key content as screenshots; I still recommend reading the original link. I left some of my own thoughts in the code along the way, and discussion is welcome.

There is also a more detailed reference worth a look: ID3演算法實現 (an ID3 implementation), where every step is computed in detail.

Related Concepts

The two figures below are two pages from a PPT I put together when I first studied decision trees; they mainly cover ID3 and related concepts. I will not paste material on C4.5, CART, GBDT, RandomForest, and so on. The principle of ID3 is quite simple and plenty of other references cover it; the focus here is on programming practice.
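Since the slides themselves are not reproduced here, it helps to restate the two standard formulas the code below implements (these are the usual textbook definitions; the notation is mine). For a data set $D$ where class $k$ occurs with proportion $p_k$, and a discrete feature $a$ whose value $v$ selects the subset $D_v$:

$$H(D) = -\sum_{k} p_k \log_2 p_k$$

$$\operatorname{Gain}(D, a) = H(D) - \sum_{v \in \operatorname{Values}(a)} \frac{|D_v|}{|D|} H(D_v)$$

At each node, ID3 greedily splits on the feature with the largest information gain.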




ID3 Algorithm Implementation

The main reference link for the code below was given above; for a detailed analysis see the original link. This code is about the simplest possible example: no pruning, no handling of continuous (non-discrete) features, and no handling of incomplete data. It simply builds a small data set and implements ID3 tree construction and classification.

# coding=utf-8
# __author__ = 'czx'
"""
Description:
    ID3 algorithm for a toy fish classification task.
"""
import operator     # used by majorityCnt for itemgetter
from math import log

def createData():
    """
    :return:
        data: including feature values and class values
        labels: description of features
    """
    data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [1, 0, 'no'], [0, 1, 'no']]
    labels = ['no surfing', 'flippers']
    return data, labels

def calcShannonEnt( data ):
    """
    :param
        data: given data set
    :return:
        shannonent: shannon entropy of the given data set
    """
    num = len(data)                             # data set size
    labelcount = {}                             # count number of each class ['yes','no']
    for sample in data:
        templabel = sample[-1]                  # class label of sample in data ['yes','no']
        if templabel not in labelcount:         # add to dict
            labelcount[templabel] = 0           # initial value is 0
        labelcount[templabel] += 1              # count
    shannonent = 0.0                            # initial shannon entropy
    for key in labelcount:                      # for all classes
        prob = float(labelcount[key])/num       # Pi = Ni/N
        shannonent -= prob * log(prob, 2)       # Ent = -sum(Pi * log2(Pi)); two classes here: ['yes','no']
    return shannonent                           # shannon entropy of the given data
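# A quick sanity check (my own worked example on the toy data from createData,
# which has 2 'yes' and 3 'no' samples):
#   calcShannonEnt(data) = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971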

def splitDataSet(data, index, value):
    """
    :param
        data: the given data set
        index: index of the selected feature
        value: the selected feature value used to split the data set
    :return:
        resData: samples whose feature[index] equals value, with that feature column removed
    """
    resData = []
    for sample in data:                           # for all samples in data set
        # Note that the ID3 algorithm can only handle features with discrete values
        if sample[index] == value:                # the selected feature value of sample
            splitSample = sample[:index]          # features before feature[index]
            splitSample.extend(sample[index+1:])  # features after feature[index]
            resData.append(splitSample)
    return resData
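# A worked example (my expectation on the toy data, not output from the original
# post): splitting on feature 0 with value 1 keeps the four samples whose first
# feature is 1 and drops that column:
#   splitDataSet(data, 0, 1) -> [[1, 'yes'], [1, 'yes'], [0, 'no'], [0, 'no']]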

def chooseBestFeatureToSplit(data):
    """
    :param
        data: the given data set
    :return:
        bestFeature: the index of the feature with the largest information gain
    """
    num = len(data[0])-1                               # number of features [final column is the class value]
    baseEntropy = calcShannonEnt(data)                 # shannon entropy of the whole data set
    bestInfoGain = 0.0
    bestFeature = 0

    for i in range(num):                               # all feature indices [0,1,2,...,n-2]
        featList = [sample[i] for sample in data]      # all values of feature i in the data set
        uniqueVals = set(featList)                     # remove duplicates
        newEntropy = 0.0
        for v in uniqueVals:                           # all values of feature i
            subData = splitDataSet(data, i, v)         # split the data set on (index=i) and (value=v)
            prob = len(subData)/float(len(data))
            newEntropy += prob*calcShannonEnt(subData) # weighted entropy after the split
        infoGain = baseEntropy - newEntropy            # information gain of feature i
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
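# Worked through by hand on the toy data (base entropy ≈ 0.971):
#   feature 0 ('no surfing'): weighted entropy = 4/5*1.0 + 1/5*0.0 = 0.800, gain ≈ 0.171
#   feature 1 ('flippers'):   weighted entropy = 3/5*0.918 + 2/5*0.0 ≈ 0.551, gain ≈ 0.420
# so chooseBestFeatureToSplit(data) should return 1 here.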

def majorityCnt(classList):
    """
    :param
        classList: class labels, here 'yes' or 'no'
    :return:
        sortedClassCount[0][0]: the class label with the largest vote count
    """
    classCount = {}
    for v in classList:
        if v not in classCount:
            classCount[v] = 0
        classCount[v] += 1
    # sort and return outside the loop; items() replaces the Python-2-only iteritems()
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
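# e.g. majorityCnt(['yes', 'no', 'no']) should return 'no'.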

def createTree(data,labels):
    """
    :param
        data: the given data set
        labels:  description of features
    :return:
        myTree: final tree for making decision
    """
    classList = [sample[-1] for sample in data]           # class label ,here are ['yes','no']
    if classList.count(classList[0]) == len(classList):   # situation1: only one class
        return classList[0]
    if len(data[0]) == 1:                                 # situation2: no features left, only the class column remains
        return majorityCnt(classList)                     # majority vote decides the leaf label
    
    bestFeature = chooseBestFeatureToSplit(data)          # choose best feature by information gain
    bestFeatureLabel = labels[bestFeature]                # name of the best feature, eg 'flippers'
    
    myTree = {bestFeatureLabel:{}}                        # dict to save myTree
    featVals = [sample[bestFeature] for sample in data]   # get all feature values from the best feature
    uniqueVals = set(featVals)
    for v in uniqueVals:                                  # for each unique value of the best feature, here [0,1]
        # drop the used feature's label so sub-labels stay aligned with the reduced data columns
        subLabels = labels[:bestFeature] + labels[bestFeature+1:]
        myTree[bestFeatureLabel][v] = createTree(splitDataSet(data, bestFeature, v), subLabels)
    return myTree
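# Worked out by hand, the recursion on the toy data should yield the nested dict
#   {'flippers': {0: 'no', 1: {'no surfing': {0: 'no', 1: 'yes'}}}}
# where inner dicts are subtrees keyed by feature value and strings are leaf labels.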

def classify(inputTree, featLabels, testSample):
    """
    :param
        inputTree: the decision tree built by createTree
        featLabels: description of features, used to map feature names to sample indices
        testSample: feature values of the sample to classify
    :return:
        label: the predicted class, here 'yes' or 'no'
    """
    firstStr = list(inputTree.keys())[0]    # root node of the tree: a feature name (eg 'flippers'); list() keeps this Python 3 compatible
    secondDict = inputTree[firstStr]        # subtrees in one dict, keyed by feature value

    featIndex = featLabels.index(firstStr)  # index of the root feature in the feature descriptions

    key = testSample[featIndex]             # value of that feature in the test sample (the key in the dict)
    featureVal = secondDict[key]            # either a subtree (dict) or a class value
    if isinstance(featureVal, dict):        # still a subtree: cannot decide yet, so recurse
        label = classify(featureVal, featLabels, testSample)
    else:
        label = featureVal                  # leaf reached; label in ['yes','no']
    return label
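# e.g. classify(myTree, labels, [1, 1]) walks flippers=1, then 'no surfing'=1,
# and should return 'yes'.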

# def storeTree(inputTree, filename):
#     import pickle
#     with open(filename, 'wb') as fw:
#         pickle.dump(inputTree, fw)
# 
# def grabTree(filename):
#     import pickle
#     fr = open(filename,'rb')
#     return pickle.load(fr)
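# If the persistence helpers above are uncommented, a built tree can be saved and
# reloaded between runs, e.g. storeTree(myTree, 'tree.pkl') and then
# myTree = grabTree('tree.pkl').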

def test():
    data, labels = createData()
    myTree = createTree(data,labels)
    print(myTree)                             # the tree as a nested dict
    print(classify(myTree, labels, [1, 1]))   # predicted class for a new sample

if __name__ == '__main__':
    test()
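If the pieces above are assembled as shown, running the script should print the tree as a nested dict and then the prediction for the test sample [1, 1], which on this toy data should be 'yes'.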

Summary

1: Good programming habits only improve through more coding, more thinking, and more summarizing. The important thing bears saying three times: program, program, program.

2: Go from simple to complex. To solve real problems better and faster, you have to keep digging deeper, for example into the many decision-tree-related algorithms: C4.5, CART, GBDT, XGBoost, RF, and so on.

PyCharm Registration with a Student Email

At the end of last year I wanted to install PyCharm on a new machine, but none of the activation codes worked. Since it still ran on the old server I kept putting it off; recently the old server had to be shut down for a while because the machine room got too hot, so I had to install PyCharm on the new machine after all. It really is comfortable to use.

Students can register with a school email address and use PyCharm Professional free for a year, so I set up the free student version; I will use it for a year and see from there.

Awkwardly, I had forgotten my school email password, and account recovery turned out to be a hassle. Fortunately the WeChat campus service could receive the mail, and after a round of confirmations I got the free student license, haha.