引言
Random Forest Algorithm 是另一個非常流行的 Machine Learning 技術,主要應用於 Regression 和 Classication.
Random Forest 的名字其實是來自於上一節學習的 Decision Tree,它是基於一個資料集, 根據不同的規則劃分, 一級一級建立成一顆樹的形式.
這裡要學的 Random Forest 的內在實現就是把一個資料集拆分成 n 個 tree, n 是可以調節的引數, 用以建立更準確的模型。
正文
引入資料集
下面以手寫數字的分類為例,來學習 Random Forest
from sklearn.datasets import load_digits
digits = load_digits()
拆分測試資料和訓練資料
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.2)
用 RandomForest 訓練模型
// 引入 RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=40)
model.fit(X_train, y_train)
// 輸出
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
max_depth=None, max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=40, n_jobs=None,
oob_score=False, random_state=None, verbose=0,
warm_start=False)
檢視模型準確度
可以通過微調引數尋找更好的準確度, 引數 n_estimators 就是代表分多少棵樹來建立模型。
model.score(X_test, y_test) // 0.9805555555555555