大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

dicksonjyl560101發表於2019-06-11


Dropout 可以提高深度神經網路的泛化能力,因此被廣泛應用於各種 DNN 任務中。訓練時,dropout 會透過隨機忽略一部分神經元來防止過擬合。本文基於此提出了 multi-sample dropout,這種改進版的 dropout 既能加快訓練速度,又能提高泛化能力。

Dropout (Hinton et al.[2012]) 是提高深度神經網路(DNN)泛化能力的主要正則化技術之一。由於其簡單、高效的特點,傳統 dropout 及其他類似技術廣泛應用於當前的神經網路中。dropout 會在每輪訓練中隨機忽略(即 drop)50% 的神經元,以避免過擬合的發生。如此一來,神經元之間無法相互依賴,從而保證了神經網路的泛化能力。在推理過程中會用到所有的神經元,因此所有的資訊都被保留;但輸出值會乘 0.5,使平均值與訓練時間一致。這種推理網路可以看作是訓練過程中隨機生成的多個子網路的集合。Dropout 的成功推動了許多技術的發展,這些技術使用各種方法來選擇要忽略的資訊。例如,DropConnect (Wan et al. [2013]) 隨機忽略神經元之間的部分連線,而不是神經元。

本文闡述的也是一種 dropout 技術的變形——multi-sample dropout。傳統 dropout 在每輪訓練時會從輸入中隨機選擇一組樣本(稱之為 dropout 樣本),而 multi-sample dropout 會建立多個 dropout 樣本,然後平均所有樣本的損失,從而得到最終的損失。這種方法只要在 dropout 層後複製部分訓練網路,並在這些複製的全連線層之間共享權重就可以了,無需新運算子。

透過綜合 M 個 dropout 樣本的損失來更新網路引數,使得最終損失比任何一個 dropout 樣本的損失都低。這樣做的效果類似於對一個 minibatch 中的每個輸入重複訓練 M 次。因此,它大大減少了訓練迭代次數。

實驗結果表明,在基於 ImageNet、CIFAR-10、CIFAR-100 和 SVHN 資料集的影像分類任務中,使用 multi-sample dropout 可以大大減少訓練迭代次數,從而大幅加快訓練速度。因為大部分運算發生在 dropout 層之前的卷積層中,Multi-sample dropout 並不會重複這些計算,所以對每次迭代的計算成本影響不大。實驗表明,multi-sample dropout 還可以降低訓練集和驗證集的錯誤率和損失。

Multi-Sample Dropout

圖 1 是一個簡單的 multi-sample dropout 例項,這個例項使用了 2 個 dropout 樣本。該例項中只使用了現有的深度學習框架和常見的運算子。如圖所示,每個 dropout 樣本都複製了原網路中 dropout 層和 dropout 後的幾層,圖中例項複製了「dropout」、「fully connected」和「softmax + loss func」層。在 dropout 層中,每個 dropout 樣本使用不同的掩碼來使其神經元子集不同,但複製的全連線層之間會共享引數(即連線權重),然後利用相同的損失函式,如交叉熵,計算每個 dropout 樣本的損失,並對所有 dropout 樣本的損失值進行平均,就可以得到最終的損失值。該方法以最後的損失值作為最佳化訓練的目標函式,以最後一個全連線層輸出中的最大值的類標籤作為預測標籤。當 dropout 應用於網路尾段時,由於重複操作而增加的訓練時間並不多。值得注意的是,multi-sample dropout 中 dropout 樣本的數量可以是任意的,而圖 1 中展示了有兩個 dropout 樣本的例項。


大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

圖 1:傳統 dropout(左)與 multi-sample dropout(右)

神經元在推理過程中是不會被忽略的。只計算一個 dropout 樣本的損失是因為 dropout 樣本在推理時是一樣的,這樣做可以對網路進行修剪以消除冗餘計算。要注意的是,在推理時使用所有的 dropout 樣本並不會嚴重影響預測效能,只是稍微增加了推理時間的計算成本。

為什麼 Multi-Sample Dropout 可以加速訓練

直觀來說,帶有 M 個 dropout 樣本的 multi-sample dropout 的效果類似於透過複製 minibatch 中每個樣本 M 次來將這個 minibatch 擴大 M 倍。例如,如果一個 minibatch 由兩個資料樣本(A, B)組成,使用有 2 個 dropout 樣本的 multi-sample dropout 就如同使用傳統 dropout 加一個由(A, A, B, B)組成的 minibatch 一樣。其中 dropout 對 minibatch 中的每個樣本應用不同的掩碼。透過複製樣本來增大 minibatch 使得計算時間增加了近 M 倍,這也使得這種方式並沒有多少實際意義。相比之下,multi-sample dropout 只重複了 dropout 後的操作,所以在不顯著增加計算成本的情況下也可以獲得相似的收益。由於啟用函式的非線性,傳統方法(增大版 minibatch 與傳統 dropout 的組合)和 multi-sample dropout 可能不會給出完全相同的結果。然而,如實驗結果所示,迭代次數的減少還是顯示出了 multi-sample dropout 的加速效果。

實驗

Multi-Sample Dropout 帶來的改進

圖 2 展示了三種情況下(傳統 dropout、multi-sample dropout 和不使用 dropout 進行訓練)的訓練損失和驗證集誤差隨訓練時間的變化趨勢。本例中 multi-sample dropout 使用了 8 個 dropout 樣本。從圖中可以看出,對於所有資料集來說,multi-sample dropout 比傳統 dropout 更快。

大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

圖 2:傳統 dropout 和 multi-sample dropout 的訓練集損失和驗證集誤差隨訓練時間的變化趨勢。multi-sample dropout 展現了更快的訓練速度和更低的錯誤率。

表 1 總結了最終的訓練集損失、訓練集錯誤率和驗證集錯誤率。


大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

表 1:傳統 dropout 和 multi-sample dropout 的訓練集損失、訓練集錯誤率和驗證集錯誤率。multi-sample dropou 與傳統 dropout 相比有更低的損失和錯誤率。

引數對效能的影響

圖 3 (a) 和圖 3 (b) 比較了不同數量 dropout 樣本和不同的 epoch 下在 CIFAR-100 上的訓練集損失和驗證集誤差。使用更多的 dropout 樣本加快了訓練的進度。當 dropout 樣本多達 64 個時,dropout 樣本的數量與訓練損失的加速之間顯現出明顯的關係。對於圖 3(b) 所示的驗證集誤差,dropout 樣本在大於 8 個時,再增加 dropout 樣本數量不再能帶來顯著的收益。

大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

圖 3:不同數量的 dropout 樣本在訓練過程中的訓練集損失和驗證集誤差。

大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

表 2:不同 dropout 樣本數量下與傳統 dropout 的迭代時間比較。增加 dropout 樣本的數量會增加迭代時間。由於記憶體不足,無法執行有 16 個 dropout 示例的 VGG16。

大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

圖 4:不同數量的 dropout 樣本訓練後的損失和錯誤率。

大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

圖 5:(a) 驗證錯誤率,(b) 不同 dropout 率下的 multi-sample dropout 和傳統 dropout 的訓練損失趨勢。其中 35% 的 dropout 率表示兩個 dropout 層分別使用 40% 和 30%。

大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

圖 6:有水平翻轉(增加 dropout 樣本多樣性)和沒有水平翻轉時訓練損失的比較。x 軸表示 epoch 數。


為什麼 multi-sample dropout 很高效

如前所述,dropout 樣本數為 M 的 multi-sample dropout 效能類似於透過複製 minibatch 中的每個樣本 M 次來將 minibatch 的大小擴大 M 倍。這也是 multi-sample dropout 可以加速訓練的主要原因。圖 7 可以說明這一點。

大幅減少訓練迭代次數,提高泛化能力:IBM提出「新版Dropout」

圖 7:傳統 dropout 加資料複製後的 minibatch 與 multi-sample dropout 的比較。x 軸表示 epoch 數。為了公平的比較,研究者在 multi-sample dropout 中沒有使用會增加樣本多樣性的橫向翻轉和零填充。


論文連結:


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/29829936/viewspace-2647267/,如需轉載,請註明出處,否則將追究法律責任。

相關文章