深度學習與XGBoost在小資料集上的測評,你怎麼看?

思源發表於2017-06-26
近來,部分機器學習從業者對深度學習不能訓練小資料集這一觀點表示懷疑,他們普遍認為如果深度學習經過優良的調參,那麼就不會出現過擬合和過訓練情況,也就能較好地從小資料集學習不錯的模型。在本文中,Max Brggen 在多個小資料集對神經網路和 XGBoost 進行了對比,並表明 ANN 在小資料集可以得到和 XGBoost 相媲美的結果。

模型原始碼:https://gist.github.com/maxberggren/b3ae92b26fd7039ccf22d937d49b1dfd

Andrew Beam 曾展示目前的神經網路方法如果有很好的調參是能夠在小資料集上取得好結果的。如果你目前正在使用正則化方法,那麼人工神經網路完全有可能在小資料集上取代傳統的統計機器學習方法。下面讓我們在基準資料集上比較這些演算法。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

先從從 iris 資料集開始,因為我們可以很容易地使用 pandas read_csv 函式從網上讀取資料集。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

注意,上述程式碼塊的資料集讀取地址(顯示不全)為:

「https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/d546eaee765268bf2f487608c537c05e22e4b221/iris.csv」

該資料集只有三個類別共計 150 個資料點,它是一個很小的資料集。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

我們可以從 Pandas 資料框架中建立特徵矩陣 X 和目標向量 y。因為 ANN 的特徵矩陣需要歸一化,所以先要進行最小最大縮放。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

我們將資料集分割為訓練集和測試集。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

匯入一些 keras 庫的函式(如果沒有安裝 keras,可以鍵入 pip install keras)。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

設定神經網路的深度為 3 層,每一層的寬度為 128 個神經元。這並沒有什麼特別的,甚至都不一定能算做深度學習,但該網路在每層之間使用了一些 dropout 幫助減少過擬合現象。

Adam 最佳化方法的學習率可能在其他資料集還需要微調,但是在該資料集保留 0.001 效果就已經十分不錯了。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

EarlyStopping 函式在驗證集精度不再提高的時候可以幫助我們終止訓練,同樣這也會幫助我們避免過擬合。同時我們還需要在出現過擬合之前儲存模型,ModelCheckpoints 函式可以讓我們在驗證集精度出現下降前儲存最優模型。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

深度學習與XGBoost在小資料集上的測評,你怎麼看?

現在我們可以在測試集上評估效能,下面的混淆矩陣展示了測試集所有預測值和真實值的分佈。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

深度學習與XGBoost在小資料集上的測評,你怎麼看?

實際上該結果極其優秀。接下來我們透過 sklearn API 構建 xgboost(conda install xgboost) 模型。

尋找優良的超引數對貝葉斯方法來說是很好的任務,它能在沒有任何梯度的情況下以有效的方式評估替代方案。而像 GridSearch 那樣的方法需要大量的時間,因此我們反而給它一個引數空間和「預算」。所以該方法會在這些條件約束下最有效地評估 XGBoost 超引數。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

因此我們使用的是 skopt (pip install scikit-optimize)。我們給定 50 次迭代來挖掘超引數空間。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

Best accuracy score = 0.96
Best parameters = {'colsample_bytree': 1.0,
'learning_rate': 0.10000000000000001, 'min_child_weight': 5,
'n_estimators': 45, 'subsample': 1, 'max_depth': 5}

深度學習與XGBoost在小資料集上的測評,你怎麼看?

深度學習與XGBoost在小資料集上的測評,你怎麼看?

下面我們需要固定這些超引數並在測試集上評估模型,該測試集和 Keras 使用的測試集是一樣的。

深度學習與XGBoost在小資料集上的測評,你怎麼看?

深度學習與XGBoost在小資料集上的測評,你怎麼看?

在這個基準資料集中,並不太深的神經網路全部預測正確,而 XGBoost 預測錯了三個。當然如果我們改變種子並且再執行一次,XGBoost 演算法也可能會完全正確,所以這一結果並不能說明神經網路就要比提升方法好,我們也不需要進一步解讀。

下面我們將以上的程式碼進一步推廣到一般情況,因此我們能嵌入任何選定的資料集,並對比兩種方法的測試集精度和可能存在困難的任務。當我們在處理程式碼時,我們可以在精度統計值上新增一個 boostrap 以瞭解不確定性大小。

完整的程式碼可以在 Github 檢視:https://gist.github.com/maxberggren/b3ae92b26fd7039ccf22d937d49b1dfd

Telecom churn 資料集(n=2325)

深度學習與XGBoost在小資料集上的測評,你怎麼看?

資料集:https://community.watsonanalytics.com/wp-content/uploads/2015/03/WA_Fn-UseC_-Telco-Customer-Churn.csv?cm_mc_uid=06267660176214972094054&cm_mc_sid_50200000=1497209405&cm_mc_sid_52640000=1497209405

ANN

深度學習與XGBoost在小資料集上的測評,你怎麼看?

XGBoost

深度學習與XGBoost在小資料集上的測評,你怎麼看?

Churn 是一個更加困難的任務,但兩種方法都做得挺好。

三種紅酒資料集(n=59)

深度學習與XGBoost在小資料集上的測評,你怎麼看?

資料集:https://gist.githubusercontent.com/tijptjik/9408623/raw/b237fa5848349a14a14e5d4107dc7897c21951f5/wine.csv

ANN

深度學習與XGBoost在小資料集上的測評,你怎麼看?

XGBoost

深度學習與XGBoost在小資料集上的測評,你怎麼看?

這是一個非常簡單的資料集,這兩種方法都沒有出現異常,因為樣本空間實在是太小了,所以 boostrap 基本上沒起什麼作用。

德國人資信資料(n=1000)

深度學習與XGBoost在小資料集上的測評,你怎麼看?

資料集:https://onlinecourses.science.psu.edu/stat857/sites/onlinecourses.science.psu.edu.stat857/files/german_credit.csv

ANN

深度學習與XGBoost在小資料集上的測評,你怎麼看?

XGBoost

深度學習與XGBoost在小資料集上的測評,你怎麼看?

所以從上面來看,ANN 有時能得到最好的效能,而 XGBoost 有時也能得到最好的效能。所以我們可以認為只要 ANN 控制了過擬合和過訓練,它就能擁有優良的表現,至少是能和 XGBoost 相匹配的效能。

XGBoost 的調參確實需要很多時間,也很困難,但 ANN 基本不用花時間去做這些事情,所以讓我們拭目以待 ANN 到底是否會在小資料集上也會有大的發展。

相關文章