如何應對訓練的神經網路不工作?

老司機的詩和遠方發表於2020-04-06

I. 資料集問題

1. 檢查你的輸入資料


檢查饋送到網路的輸入資料是否正確。例如,我不止一次混淆了影像的寬度和高度。有時,我錯誤地令輸入資料全部為零,或者一遍遍地使用同一批資料執行梯度下降。因此列印/顯示若干批量的輸入和目標輸出,並確保它們正確。


2. 嘗試隨機輸入


嘗試傳遞隨機數而不是真實資料,看看錯誤的產生方式是否相同。如果是,說明在某些時候你的網路把資料轉化為了垃圾。試著逐層除錯,並檢視出錯的地方。


3. 檢查資料載入器


你的資料也許很好,但是讀取輸入資料到網路的程式碼可能有問題,所以我們應該在所有操作之前列印第一層的輸入並進行檢查。


4. 確保輸入與輸出相關聯


檢查少許輸入樣本是否有正確的標籤,同樣也確保 shuffling 輸入樣本同樣對輸出標籤有效。


5. 輸入與輸出之間的關係是否太隨機?


相較於隨機的部分(可以認為股票價格也是這種情況),輸入與輸出之間的非隨機部分也許太小,即輸入與輸出的關聯度太低。沒有一個統一的方法來檢測它,因為這要看資料的性質。


6. 資料集中是否有太多的噪音?


我曾經遇到過這種情況,當我從一個食品網站抓取一個影像資料集時,錯誤標籤太多以至於網路無法學習。手動檢查一些輸入樣本並檢視標籤是否大致正確。


7. Shuffle 資料集


如果你的資料集沒有被 shuffle,並且有特定的序列(按標籤排序),這可能給學習帶來不利影響。你可以 shuffle 資料集來避免它,並確保輸入和標籤都被重新排列。


8. 減少類別失衡


一張類別 B 影像和 1000 張類別 A 影像?如果是這種情況,那麼你也許需要平衡你的損失函式或者嘗試其他解決類別失衡的方法。


9. 你有足夠的訓練例項嗎?


如果你在從頭開始訓練一個網路(即不是除錯),你很可能需要大量資料。對於影像分類,每個類別你需要 1000 張影像甚至更多。


10. 確保你採用的批量資料不是單一標籤


這可能發生在排序資料集中(即前 10000 個樣本屬於同一個分類)。可通過 shuffle 資料集輕鬆修復。


11. 縮減批量大小


巨大的批量大小會降低模型的泛化能力(參閱:https://arxiv.org/abs/1609.04836)


II. 資料歸一化/增強


12. 歸一化特徵


你的輸入已經歸一化到零均值和單位方差了嗎?


13. 你是否應用了過量的資料增強?


資料增強有正則化效果(regularizing effect)。過量的資料增強,加上其它形式的正則化(權重 L2,中途退出效應等)可能會導致網路欠擬合(underfit)。


14. 檢查你的預訓練模型的預處理過程


如果你正在使用一個已經預訓練過的模型,確保你現在正在使用的歸一化和預處理與之前訓練模型時的情況相同。例如,一個影像畫素應該在 [0, 1],[-1, 1] 或 [0, 255] 的範圍內嗎?


15. 檢查訓練、驗證、測試集的預處理


CS231n 指出了一個常見的陷阱:「任何預處理資料(例如資料均值)必須只在訓練資料上進行計算,然後再應用到驗證、測試資料中。例如計算均值,然後在整個資料集的每個影像中都減去它,再把資料分發進訓練、驗證、測試集中,這是一個典型的錯誤。」此外,要在每一個樣本或批量(batch)中檢查不同的預處理。


III. 實現的問題


16. 試著解決某一問題的更簡易的版本。


這將會有助於找到問題的根源究竟在哪裡。例如,如果目標輸出是一個物體類別和座標,那就試著把預測結果僅限制在物體類別當中(嘗試去掉座標)。


17.「碰巧」尋找正確的損失


還是來源於 CS231n 的技巧:用小引數進行初始化,不使用正則化。例如,如果我們有 10 個類別,「碰巧」就意味著我們將會在 10% 的時間裡得到正確類別,Softmax 損失是正確類別的負 log 概率: -ln(0.1) = 2.302。然後,試著增加正則化的強度,這樣應該會增加損失。


18. 檢查你的損失函式


如果你執行的是你自己的損失函式,那麼就要檢查錯誤,並且新增單元測試。通常情況下,損失可能會有些不正確,並且損害網路的效能表現。


19. 核實損失輸入


如果你正在使用的是框架提供的損失函式,那麼要確保你傳遞給它的東西是它所期望的。例如,在 PyTorch 中,我會混淆 NLLLoss 和 CrossEntropyLoss,因為一個需要 softmax 輸入,而另一個不需要。


20. 調整損失權重


如果你的損失由幾個更小的損失函式組成,那麼確保它們每一個的相應幅值都是正確的。這可能會涉及到測試損失權重的不同組合。


21. 監控其它指標


有時損失並不是衡量你的網路是否被正確訓練的最佳預測器。如果可以的話,使用其它指標來幫助你,比如精度。


22. 測試任意的自定義層


你自己在網路中實現過任意層嗎?檢查並且複核以確保它們的執行符合預期。


23. 檢查「冷凍」層或變數


檢查你是否無意中阻止了一些層或變數的梯度更新,這些層或變數本來應該是可學的。


24. 擴大網路規模


可能你的網路的表現力不足以採集目標函式。試著加入更多的層,或在全連層中增加更多的隱藏單元。


25. 檢查隱維度誤差


如果你的輸入看上去像(k,H,W)= (64, 64, 64),那麼很容易錯過與錯誤維度相關的誤差。給輸入維度使用一些「奇怪」的數值(例如,每一個維度使用不同的質數),並且檢查它們是如何通過網路傳播的。


26. 探索梯度檢查(Gradient checking)


如果你手動實現梯度下降,梯度檢查會確保你的反向傳播(backpropagation)能像預期中一樣工作。


IV. 訓練問題

27. 一個真正小的資料集


過擬合資料的一個小子集,並確保其工作。例如,僅使用 1 或 2 個例項訓練,並檢視你的網路是否學習了區分它們。然後再訓練每個分類的更多例項。


28. 檢查權重初始化


如果不確定,請使用 Xavier 或 He 初始化。同樣,初始化也許會給你帶來壞的區域性最小值,因此嘗試不同的初始化,看看是否有效。


29. 改變你的超引數


或許你正在使用一個很糟糕的超引數集。如果可行,嘗試一下網格搜尋。


30. 減少正則化


太多的正則化可致使網路嚴重地欠擬合。減少正則化,比如 dropout、批規範、權重/偏差 L2 正則化等。在優秀課程《程式設計人員的深度學習實戰》(Practical Deep Learning For Coders-18 hours of lessons for free)中,Jeremy Howard 建議首先解決欠擬合。這意味著你充分地過擬合資料,並且只有在那時處理過擬合。


31. 給它一些時間


也許你的網路需要更多的時間來訓練,在它能做出有意義的預測之前。如果你的損失在穩步下降,那就再多訓練一會兒。


32. 從訓練模式轉換為測試模式


一些框架的層很像批規範、Dropout,而其他的層在訓練和測試時表現並不同。轉換到適當的模式有助於網路更好地預測。


33. 視覺化訓練


監督每一層的啟用值、權重和更新。確保它們的大小匹配。例如,引數更新的大小(權重和偏差)應該是 1-e3。

考慮視覺化庫,比如 Tensorboard 和 Crayon。緊要時你也可以列印權重/偏差/啟用值。

尋找平均值遠大於 0 的層啟用。嘗試批規範或者 ELUs。


Deeplearning4j 指出了權重和偏差柱狀圖中的期望值:對於權重,一些時間之後這些柱狀圖應該有一個近似高斯的(正常)分佈。對於偏差,這些柱狀圖通常會從 0 開始,並經常以近似高斯(這種情況的一個例外是 LSTM)結束。留意那些向 +/- 無限發散的引數。留意那些變得很大的偏差。這有時可能發生在分類的輸出層,如果類別的分佈不均勻。


檢查層更新,它們應該有一個高斯分佈。


34. 嘗試不同的優化器


優化器的選擇不應當妨礙網路的訓練,除非你選擇了一個特別糟糕的引數。但是,為任務選擇一個合適的優化器非常有助於在最短的時間內獲得最多的訓練。描述你正在使用的演算法的論文應當指定優化器;如果沒有,我傾向於選擇 Adam 或者帶有動量的樸素 SGD。


35. 梯度爆炸、梯度消失


檢查隱蔽層的最新情況,過大的值可能代表梯度爆炸。這時,梯度截斷(Gradient clipping)可能會有所幫助。

檢查隱蔽層的啟用值。Deeplearning4j 中有一個很好的指導方針:「一個好的啟用值標準差大約在 0.5 到 2.0 之間。明顯超過這一範圍可能就代表著啟用值消失或爆炸。」


36. 增加、減少學習速率


低學習速率將會導致你的模型收斂很慢;

高學習速率將會在開始階段減少你的損失,但是可能會導致你很難找到一個好的解決方案。

試著把你當前的學習速率乘以 0.1 或 10。


37. 克服 NaNs


據我所知,在訓練 RNNs 時得到 NaN(Non-a-Number)是一個很大的問題。一些解決它的方法:


減小學習速率,尤其是如果你在前 100 次迭代中就得到了 NaNs。

NaNs 的出現可能是由於用零作了除數,或用零或負數作了自然對數。

Russell Stewart 對如何處理 NaNs 很有心得(http://russellsstewart.com/notes/0.html)。

嘗試逐層評估你的網路,這樣就會看見 NaNs 到底出現在了哪裡。

轉載:https://zhuanlan.zhihu.com/p/28093629?utm_source=qq&utm_medium=social

相關文章