
說明:本文主要參考「機器學習實戰之決策樹」,該文有程式碼實現和詳細的註釋,是學習機器學習實戰的不錯資料,推薦閱讀。出於程式設計實踐和梳理機器學習演算法的目的,我按照自己的程式碼風格重寫了該演算法,實現的過程也很有助於自己思考。為方便下次複習時能快速理解,我通過截圖擷取了個人認為比較關鍵的內容,但仍推薦閱讀原連結。程式碼實現過程中我也留下了一些思考,歡迎交流學習。



下面兩張圖是我以前學習決策樹時整理的PPT中的兩頁,主要講的是ID3及相關概念;至於C4.5、CART、GBDT、RandomForest等內容就不貼上來了。ID3的原理很簡單,可以找到很多其他資料,這裡主要側重於程式設計實踐。



#  __author__ = 'czx'
# coding=utf-8
# ID3 algorithm for the fish classification task.
from numpy import *
from math import log

def createData():
    """Return the toy "fish" data set used to exercise the ID3 tree.

    Returns:
        data: list of samples; each sample is ``[feature0, feature1, class]``,
            where the last element is the class value ('yes' / 'no').
        labels: human-readable names of the feature columns, in column order.
    """
    data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [1, 0, 'no'], [0, 1, 'no']]
    labels = ['no surfing', 'flippers']
    return data, labels

def calcShannonEnt(data):
    """Return the Shannon entropy (in bits) of the class column of ``data``.

    Ent(D) = -sum_k p_k * log2(p_k), where p_k is the fraction of samples
    whose last element (the class value) equals class k.

    Args:
        data: list of samples; the class value is the last element of each.

    Returns:
        shannonent: entropy as a float. It is 0.0 for an empty data set or a
            data set whose samples all share one class.
    """
    num = len(data)                             # data set size
    labelcount = {}                             # class value -> number of samples
    for sample in data:
        templabel = sample[-1]                  # class value is the final column
        labelcount[templabel] = labelcount.get(templabel, 0) + 1
    shannonent = 0.0
    for count in labelcount.values():
        prob = float(count) / num               # p_k = N_k / N
        shannonent -= prob * log(prob, 2)       # Ent = -sum_k p_k * log2(p_k)
    return shannonent

def spiltDataSet(data, index, value):
    """Return the samples whose feature ``index`` equals ``value``, with that column removed.

    ID3 only handles discrete feature values, so samples are selected by
    equality. The input samples are not modified; each returned sample is a
    new list.

    Args:
        data: list of samples (each a list; the class value is last).
        index: column index of the feature to split on.
        value: feature value to keep.

    Returns:
        resData: the matching samples, each without column ``index``.
    """
    resData = []
    for sample in data:
        if sample[index] == value:
            # Drop column `index`: keep the features before it and everything after it.
            spiltSample = sample[:index]
            spiltSample.extend(sample[index + 1:])
            resData.append(spiltSample)         # was missing, so the result was always []
    return resData

def chooseBestFeatureToSpilt(data):
    """Return the index of the feature whose split gives the largest ID3 information gain.

    Gain(D, a) = Ent(D) - sum_v |D_v| / |D| * Ent(D_v), where D_v is the
    subset of samples with feature a == v (see spiltDataSet).

    Args:
        data: non-empty list of samples; the last column is the class value.

    Returns:
        bestFeature: column index of the best feature. Ties keep the lowest
            index, and 0 is returned when no feature has a strictly positive gain.
    """
    numFeatures = len(data[0]) - 1              # last column is the class, not a feature
    baseEntropy = calcShannonEnt(data)          # entropy before any split
    numSamples = float(len(data))
    bestInfoGain = 0.0
    bestFeature = 0
    for i in range(numFeatures):
        uniqueVals = set(sample[i] for sample in data)   # distinct values of feature i
        newEntropy = 0.0                        # weighted entropy after splitting on i
        for v in uniqueVals:
            subData = spiltDataSet(data, i, v)
            prob = len(subData) / numSamples    # |D_v| / |D|
            newEntropy += prob * calcShannonEnt(subData)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:             # strict '>' so ties keep the earlier feature
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):
    """Return the class value that occurs most often in ``classList`` (majority vote).

    Args:
        classList: non-empty list of class values (here 'yes' / 'no').

    Returns:
        The most frequent class value. On Python 3.7+ (insertion-ordered
        dicts), ties go to the class seen first in ``classList``.

    Raises:
        ValueError: if ``classList`` is empty.
    """
    if not classList:
        raise ValueError("majorityCnt() requires a non-empty class list")
    classCount = {}                             # class value -> occurrences
    for v in classList:
        classCount[v] = classCount.get(v, 0) + 1
    # `sorted` stays stable with reverse=True, so equal counts keep dict order.
    # A lambda key avoids depending on the (un-imported) `operator` module.
    sortedClassCount = sorted(classCount.items(), key=lambda item: item[1], reverse=True)
    return sortedClassCount[0][0]

def createTree(data, labels):
    """Recursively build an ID3 decision tree.

    Args:
        data: non-empty list of samples; the last column is the class value
            and the other columns are discrete feature values.
        labels: names of the feature columns of ``data``, in column order.
            This list is not modified.

    Returns:
        myTree: either a class value (leaf) or a nested dict
            ``{featureLabel: {featureValue: subtree, ...}}``.
    """
    classList = [sample[-1] for sample in data]           # class value of every sample
    if classList.count(classList[0]) == len(classList):   # case 1: all samples share one class
        return classList[0]
    if len(data[0]) == 1:                                  # case 2: no features left, only the class column
        return majorityCnt(classList)                      # was `majorityCnt[classList]` (TypeError)
    bestFeature = chooseBestFeatureToSpilt(data)
    bestFeatureLabel = labels[bestFeature]
    # spiltDataSet removes the chosen column, so the labels passed down must drop it
    # too, keeping label indices aligned with the sub-data columns. Build a new list
    # (instead of `del labels[...]`) so the caller's labels remain valid for classify().
    subLabels = labels[:bestFeature] + labels[bestFeature + 1:]
    myTree = {bestFeatureLabel: {}}
    for v in set(sample[bestFeature] for sample in data):  # one branch per distinct value
        myTree[bestFeatureLabel][v] = createTree(spiltDataSet(data, bestFeature, v), subLabels)
    return myTree

def classify(inputTree, featLabels, testSample):
    """Return the class value predicted by ``inputTree`` for ``testSample``.

    Args:
        inputTree: tree as built by createTree,
            ``{featureLabel: {featureValue: subtree_or_class}}``.
        featLabels: feature names, in the column order of ``testSample``.
        testSample: list of feature values (without a class column).

    Returns:
        label: the class value stored at the leaf reached by ``testSample``.

    Raises:
        KeyError: if ``testSample`` has a feature value that has no branch in the tree.
        ValueError: if a node's feature label is not in ``featLabels``.
    """
    firstStr = next(iter(inputTree))            # root feature label (works on Python 2 and 3)
    secondDict = inputTree[firstStr]            # feature value -> subtree or class value
    featIndex = featLabels.index(firstStr)      # column of this feature in testSample
    featureVal = secondDict[testSample[featIndex]]
    if isinstance(featureVal, dict):            # internal node: keep descending
        label = classify(featureVal, featLabels, testSample)
    else:                                       # leaf: this is the class value
        label = featureVal
    return label

# def storeTree(inputTree, filename):
#     import pickle
#     with open(filename, 'wb') as fw:
#         pickle.dump(inputTree, fw)
# def grabTree(filename):
#     import pickle
#     fr = open(filename,'rb')
#     return pickle.load(fr)

def test():
    """Build the tree for the toy data set, print it, then print the prediction for [1, 1]."""
    data, labels = createData()
    myTree = createTree(data, labels)
    # Single-argument print(...) prints the same output on Python 2 and Python 3.
    print(myTree)
    print(classify(myTree, labels, [1, 1]))

2:從簡單到複雜,要更好更快地解決實際問題,就需要不斷深入瞭解決策樹相關的許多演算法,如C4.5、CART、GBDT、XGBoost、RF等。



