理解「交叉驗證」(Cross Validation)

程式設計師在深圳發表於2019-04-27

交叉驗證是機器學習中常用的一種驗證模型的方法,使用這種方法,你可以

  1. 準確的調整模型的超引數(Hyperparameter),且這組引數對不同的資料,表現相對穩定
  2. 在某些分類場景,你可以同時使用邏輯迴歸、決策樹或聚類等多種演算法建模,當不確定哪種演算法效果更好時,可以使用交叉驗證

除去資料預處理之外,機器學習一般有兩大步驟:訓練(英文術語為 'estimate parameters' 或 'training the algorithm')和測試(英文術語為 'evaluating a method' 或 'testing the algorithm'),一般的,我們將樣本資料分為不同比例的兩部分,其中前 75% 作為訓練資料,剩下的 25% 作為測試資料,然後先用演算法對訓練資料進行擬合,再用測試資料驗證演算法的好壞。

但這樣選擇資料並不能避免偶然性,即在某些情況下,用這最後的 1/4 資料進行測試,剛好能得到比較好的效果,而如果我們改為用前 25% 的資料測試,後 75% 的資料訓練的話,效果卻會大打折扣,如下所示,情況 1 和情況 2 之間只存在資料切分的區別,但情況 1 的測試結果卻要比情況 2 好很多。

理解「交叉驗證」(Cross Validation)

為了降低測試資料產生的偶然性,更好的做法便是採用「交叉驗證」,還是以切分 4 份資料為例,交叉驗證的做法是,對於同一個演算法,同時訓練出 4 個模型,每個模型採用不同的測試資料(例如模型 1 選用第 1 份,模型 2 選用第 2 份,以此類推),在所有模型都完成測試後,再對這 4 個模型的評估結果求平均,便可以得到一個相對穩定且更有說服力的演算法。

舉個具體的例子,假設我們的模型採用決策樹演算法,該演算法有個超引數是樹的深度 height,我們可以將其設為 2,也可以設為 3,但不清楚設哪個數比較好,此時我們就可以使用「交叉驗證」來幫我們決策,首先還是將資料 4 等分,對每一個引數值,我們都訓練 4 次,輸出 4 種可能的測試結果,如下圖所示

理解「交叉驗證」(Cross Validation)

最後,我們根據每個引數下的測試結果,算出它們的平均值

超引數 評估結果的均值
height = 2 (0.68+0.62+0.58+0.72) / 4 = 0.65
height = 3 (0.82+0.60+0.59+0.76) / 4 = 0.69

於是,我們便可以得出,該演算法在 height=3 的情況下效果更好。這個例子說明了我們是如何利用交叉驗證來調超引數的,如文章開頭所說,對於不同演算法的比較,同樣也可以使用這樣的方法。

上文中,這種將資料分為 4 份來做交叉驗證的做法被稱為 4-fold cross validation,實踐中,我們通常使用 10-fold。另外,還有一種只選取 1 條樣本作為測試資料的極端情況,稱為 leave one out,可想而知,這種做法會消耗巨大的計算資源,在生產環境中要謹慎使用。

相關文章