模型融合——stacking原理與實現

魚與魚發表於2022-02-14

一般提升模型效果從兩個大的方面入手

資料層面:資料增強、特徵工程等

模型層面:調參,模型融合

模型融合:通過融合多個不同的模型,可能提升機器學習的效能。這一方法在各種機器學習比賽中廣泛應用, 也是在比賽的攻堅時刻衝刺Top的關鍵。而融合模型往往又可以從模型結果,模型自身,樣本集等不同的角度進行融合。

模型融合是後期一個重要的環節,大體來說有如下的型別方式:

  • 加權融合(投票、平均)

    硬投票

    image-20220214103407094

    軟投票

    image-20220214103422542

  • boosting/bagging(整合學習)

  • stacking/blending

本文主要介紹stacking/blending方法的原理,及其實際應用

Stacking模型本質上是一種分層的結構,這裡簡單起見,只分析二級Stacking.假設我們有3個基模型M1、M2、M3。[1]

  1. 基模型M1,對訓練集train訓練,然後在訓練集和測試集預測,分別得到P1,T1。同理,得到P2,T2;P3,T3

    \[\begin{pmatrix} \vdots\\ P1\\ \vdots\\ \end{pmatrix} \begin{pmatrix} \vdots\\ T1\\ \vdots\\ \end{pmatrix}, \begin{pmatrix} \vdots\\ P2\\ \vdots\\ \end{pmatrix} \begin{pmatrix} \vdots\\ T2\\ \vdots\\ \end{pmatrix}, \begin{pmatrix} \vdots\\ P3\\ \vdots\\ \end{pmatrix} \begin{pmatrix} \vdots\\ T3\\ \vdots\\ \end{pmatrix} \]

  2. 分別把P1,P2,P3以及T1,T2,T3合併,得到一個新的訓練集和測試集train2,test2.

    image-20220214114151806

  3. 再用第二層的模型M4訓練train2,預測test2,得到最終的標籤列。

    image-20220214114243425

注意:

用整個訓練集訓練的模型反過來去預測訓練集的標籤,毫無疑問過擬合是非常非常嚴重的,因此現在的問題變成了如何在解決過擬合的前提下得到P1、P2、P3,這就變成了熟悉的節奏——K折交叉驗證。

image-20220214111331739

上圖的模型1-5其實是一個模型在不同折下訓練。

最終的程式碼是兩層迴圈,第一層迴圈控制基模型的數目,每一個基模型要這樣去得到P1,T1,第二層迴圈控制的是交叉驗證的次數K,對每一個基模型,會訓練K次最後拼接得到P1,取平均得到T1。

python實現[2]

### 6折stacking
n_folds = 6
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=1)
for i,clf in enumerate(clfs):
#     print("分類器:{}".format(clf))
    X_stack_test_n = np.zeros((X_test.shape[0], n_folds))
    for j,(train_index,test_index) in enumerate(skf.split(X_train,y_train)):
                tr_x = X_train[train_index]
                tr_y = y_train[train_index]
                clf.fit(tr_x, tr_y)
                #生成stacking訓練資料集
                X_train_stack [test_index, i] = clf.predict_proba(X_train[test_index])[:,1]
                X_stack_test_n[:,j] = clf.predict_proba(X_test)[:,1]
    #生成stacking測試資料集
    X_test_stack[:,i] = X_stack_test_n.mean(axis=1)

理論介紹推薦閱讀[1],實現部分可以閱讀[2]

references

【1】【機器學習】模型融合方法概述. https://zhuanlan.zhihu.com/p/25836678

【2】Kaggle提升模型效能的超強殺招Stacking——機器學習模型融合. https://zhuanlan.zhihu.com/p/107655409

相關文章