Machine Learning(13)- Random Forest

Rachel發表於2019-06-12

引言

Machine Learning(13)- Random Forest

Random Forest Algorithm 是另一個非常流行的 Machine Learning 技術,主要應用於 Regression 和 Classication.

Random Forest 的名字其實是來自於上一節學習的 Decision Tree,它是基於一個資料集, 根據不同的規則劃分, 一級一級建立成一顆樹的形式.
這裡要學的 Random Forest 的內在實現就是把一個資料集拆分成 n 個 tree, n 是可以調節的引數, 用以建立更準確的模型。

Machine Learning(13)- Random Forest

正文

引入資料集

下面以手寫數字的分類為例,來學習 Random Forest

from sklearn.datasets import load_digits
digits = load_digits()

拆分測試資料和訓練資料

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.2)

用 RandomForest 訓練模型

// 引入 RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=40)
model.fit(X_train, y_train)

// 輸出
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=40, n_jobs=None,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)   

檢視模型準確度

可以通過微調引數尋找更好的準確度, 引數 n_estimators 就是代表分多少棵樹來建立模型。

model.score(X_test, y_test) // 0.9805555555555555

相關文章