python3.4之決策樹

勿在浮沙築高臺LS 發表於 2016-12-22
#!/usr/bin/env python
# coding=utf-8

import numpy as np
from sklearn import tree
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.cross_validation import train_test_split
import pydot
from sklearn.externals.six import StringIO

def loadDataSet(path='D:python/fat.txt'):
    """Load a whitespace-separated sample file into feature and label arrays.

    Every non-blank line is expected to contain one or more numeric feature
    columns followed by a textual class label in the last column, e.g.
    ``1.5 50 thin``. Columns may be separated by any run of whitespace and
    blank lines are ignored.

    Args:
        path: Location of the data file. The default is the path the script
            was originally written for.
            NOTE(review): on Windows ``D:python/...`` is relative to the
            current directory of drive D, not ``D:/python`` - confirm.

    Returns:
        A tuple ``(x, y)``: ``x`` is a float array with one row per sample,
        ``y`` is a float array holding 1 where the label is ``'fat'`` and 0
        for every other label.

    Raises:
        ValueError: If a feature column cannot be parsed as a float.
        OSError: If the file cannot be opened.
    """
    data = []
    label = []
    with open(path) as fh:
        for line in fh:
            # split() with no argument collapses repeated spaces/tabs, so a
            # doubled separator no longer produces an empty token.
            tokens = line.split()
            if not tokens:
                # Blank (or whitespace-only) line: nothing to parse.
                continue
            data.append([float(tk) for tk in tokens[:-1]])
            label.append(tokens[-1])
    x = np.array(data)
    print('x:')
    print(x)
    label = np.array(label)
    # Encode the textual label as a number: 'fat' -> 1, anything else -> 0.
    y = np.zeros(label.shape)
    y[label == 'fat'] = 1
    print('y:')
    print(y)
    return x, y

def decisionTreeClf():
    """Train an entropy-based decision tree on the sample data and report on it.

    Steps, in order:
      1. Load the data with ``loadDataSet()``.
      2. Split it 80/20 into training and test parts and print all four parts.
      3. Fit a ``DecisionTreeClassifier(criterion='entropy')`` on the
         training part.
      4. Export the fitted tree as DOT text to ``iris.dot`` and as a PDF to
         ``ex.pdf`` (written to the current working directory).
      5. Print the feature importances, the predictions and accuracy on the
         training part, and a classification report over the full data set.

    Returns:
        None. All results are printed or written to disk.
    """
    x, y = loadDataSet()

    # Hold out 20% of the samples; the split is random on every run.
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    print('x_train:')
    print(x_train)
    print('x_test:')
    print(x_test)
    print('y_train:')
    print(y_train)
    print('y_test:')
    print(y_test)

    # Use information entropy as the split criterion.
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    print(clf)
    clf.fit(x_train, y_train)

    # Write the tree as Graphviz DOT text. The file handle is not rebound to
    # the export function's return value while the file is still open.
    with open("iris.dot", 'w') as f:
        tree.export_graphviz(clf, out_file=f)

    # Render the same DOT text to a PDF through pydot.
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    # NOTE(review): newer pydot releases return a list of graphs while older
    # ones return a single graph - accept both; confirm against the installed
    # version.
    if isinstance(graph, (list, tuple)):
        graph = graph[0]
    graph.write_pdf("ex.pdf")

    # Show how much each feature contributed to the classification.
    print(clf.feature_importances_)

    # Print the predictions for the training samples.
    # NOTE: this is accuracy on the data the tree was fitted on, not on the
    # held-out x_test/y_test.
    answer = clf.predict(x_train)
    print('x_train:')
    print(x_train)
    print('answer:')
    print(answer)
    print('y_train:')
    print(y_train)
    print('計算正確率:')
    print(np.mean(answer == y_train))

    # Precision and recall. The curve values are computed from the training
    # predictions but are not used further below.
    precision, recall, thresholds = precision_recall_curve(y_train, answer)

    # classification_report needs predicted class labels, not probabilities:
    # predict_proba()[:, 1] only looks like labels while every leaf is pure,
    # and fails outright otherwise.
    answer = clf.predict(x)
    print(classification_report(y, answer, target_names=['thin', 'fat']))

# Run the whole demo (load, train, export, report) as soon as the file is executed.
decisionTreeClf()
# print('ll')

資料集fat.txt檔案內容如下:

1.5 50 thin
1.5 60 fat
1.6 40 thin
1.6 60 fat
1.7 60 thin
1.7 80 fat
1.8 60 thin
1.8 90 fat
1.9 70 thin
1.9 80 fat

所需要的python包有:
pygraphviz (1.3.1)
pyparsing (2.1.10)
scikit-learn (0.18.1)
pygraphviz (1.3.1)包是視覺化包。此外,上面的程式碼還 import 了 numpy 與 pydot,這兩個包也需要一併安裝。
下載視覺化工具:
graphviz-2.38.msi
可透過百度搜尋下載並安裝該視覺化工具;安裝後需將 Graphviz 的 bin 目錄加入系統 PATH,pydot 才能呼叫 dot 產生 ex.pdf。

相關文章