快來試試 Lookahead 最優化方法啊,調參少、收斂好、速度還快,大牛用了都說好。
最優化方法一直主導著模型的學習過程,沒有最優化器模型也就沒了靈魂。好的最優化方法一直是 ML 社群在積極探索的,它幾乎對任何機器學習任務都會有極大的幫助。
從最開始的批量梯度下降,到後來的隨機梯度下降,然後到 Adam 等一大幫基於適應性學習率的方法,最優化器已經走過了很多年。儘管目前 Adam 差不多已經是預設的最優化器了,但從 17 年開始就有各種研究表示 Adam 還是有一些缺陷的,甚至它的收斂效果在某些環境下比 SGD 還差。
為此,我們期待更好的標準優化器已經很多年了...
最近,來自多倫多大學向量學院的研究者發表了一篇論文,提出了一種新的優化演算法——Lookahead。值得注意的是,該論文的最後作者 Jimmy Ba 也是原來 Adam 演算法的作者,Hinton 老爺子也作為三作參與了該論文,所以作者陣容還是很強大的。
論文地址:https://arxiv.org/abs/1907.08610v1
Lookahead 演算法與已有的方法完全不同,它迭代地更新兩組權重。直觀來說,Lookahead 演算法通過提前觀察另一個優化器生成的「fast weights」序列,來選擇搜尋方向。該研究發現,Lookahead 演算法能夠提升學習穩定性,不僅降低了調參需要的功夫,同時還能提升收斂速度與效果。
實驗證明,Lookahead 演算法的效能顯著優於 SGD 和 Adam,即使 Lookahead 使用的是在 ImageNet、CIFAR-10/100、神經機器翻譯和 Penn Treebank 任務上的預設超引數設定。
最優化器犯了什麼錯
SGD 演算法雖然簡潔,但其在神經網路訓練中的效能堪比高階二階優化方法。儘管 SGD 每一次用小批量算出來的更新方向可能並非那麼精確,但更新多了效果卻出乎意料地好。
一般而言,SGD 各種變體可以分成兩大類:1)自適應學習率機制,如 AdaGrad 和 Adam;2)加速機制,如 Polyak heavyball 和 Nesterov momentum 等。這兩種方法都利用之前累積的梯度資訊實現快速收斂,它們希望借鑑以往的更新方向。但是,要想實現神經網路效能提升,通常需要花銷高昂的超引數調整。
其實很多研究者都發現目前的最優化方法可能有些缺點,不論是 Adam 還是帶動量的 SGD,它們都有難以解決的問題。例如我們目前最常用的 Adam,我們拿它做實驗是沒啥問題的,但要是想追求收斂效能,那麼最好還是用 SGD+Momentum。但使用動量機制又會有新的問題,我們需要調整多個超引數以獲得比較好的效果,不能像 Adam 給個預設的學習率 0.0001 就差不多了。
在 ICLR 2018 的最佳論文 On the Convergence of Adam and Beyond 中,研究者明確指出了 Adam 收斂不好的原因。他們表明在利用歷史梯度的移動均值情況下,模型只能根據短期梯度資訊為每個引數設計學習率,因此也就導致了收斂性表現不太好。
那麼 Hinton 等研究者是怎樣解決這個問題的?他們提出的最優化方法能獲得高收斂效能的同時,還不需要調參嗎?
多倫多大學的「Look ahead」
Michael R. Zhang 和 Geoffrey Hinton 等研究者提出了一種新的最優化方法 Lookahead,該演算法與之前已有的方法都不相同。此外,因為 Lookahead 與其它最優化器是正交的,這意味著我們可以使用 Lookahead 加強已有最優化方法的效能。
如下所示為 Lookahead 的直觀過程,它會維護兩套權重。Lookahead 首先使用內部迴圈中的 SGD 等標準優化器,更新 k 次「Fast weights」,然後以最後一個 Fast weights 的方向更新「slow weights」。如下 Fast Weights 每更新 5 次,Slow weights 就會更新一次。
該研究表明這種更新機制能夠有效地降低方差。研究者發現 Lookahead 對次優超引數沒那麼敏感,因此它對大規模調參的需求沒有那麼強。此外,使用 Lookahead 及其內部優化器(如 SGD 或 Adam),還能實現更快的收斂速度,因此計算開銷也比較小。
研究者在多個實驗中評估 Lookahead 的效果。比如在 CIFAR 和 ImageNet 資料集上訓練分類器,並發現使用 Lookahead 後 ResNet-50 和 ResNet-152 架構都實現了更快速的收斂。
研究者還在 Penn Treebank 資料集上訓練 LSTM 語言模型,在 WMT 2014 English-to-German 資料集上訓練基於 Transformer 的神經機器翻譯模型。在所有任務中,使用 Lookahead 演算法能夠實現更快的收斂、更好的泛化效能,且模型對超引數改變的魯棒性更強。
這些實驗表明 Lookahead 對內部迴圈優化器、fast weight 更新次數以及 slow weights 學習率的改變具備魯棒性。
Lookahead Optimizer 怎麼做
Lookahead 迭代地更新兩組權重:slow weights φ 和 fast weights θ,前者在後者每更新 k 次後更新一次。Lookahead 將任意標準優化演算法 A 作為內部優化器來更新 fast weights。
使用優化器 A 經過 k 次內部優化器更新後,Lookahead 通過在權重空間 θ − φ 中執行線性插值的方式更新 slow weights,方向為最後一個 fast weights。
slow weights 每更新一次,fast weights 將被重置為目前的 slow weights 值。Lookahead 的虛擬碼見下圖 Algorithm 1。
其中最優化器 A 可能是 Adam 或 SGD 等最優化器,內部的 for 迴圈會用常規方法更新 fast weights θ,且每次更新的起始點都是從當前的 slow weights φ 開始。最終模型使用的引數也是慢更新那一套,因此快更新相當於做了一系列實驗,然後慢更新再根據實驗結果選一個比較好的方向,這有點類似 Nesterov Momentum 的思想。
看上去這只是一個小技巧?似乎它應該對實際的引數更新沒什麼重要作用?那麼繼續看看它到底為什麼能 Work。
Lookahead 為什麼能 Work
標準優化方法通常需要謹慎調整學習率,以防止振盪和收斂速度過慢,這在 SGD 設定中更加重要。而 Lookahead 能借助較大的內部迴圈學習率減輕這一問題。
當 Lookahead 向高曲率方向振盪時,fast weights 更新在低曲率方向上快速前進,slow weights 則通過引數插值使振盪平滑。fast weights 和 slow weights 的結合改進了高曲率方向上的學習,降低了方差,並且使得 Lookahead 在實踐中可以實現更快的收斂。
另一方面,Lookahead 還能提升收斂效果。當 fast weights 在極小值周圍慢慢探索時,slow weight 更新促使 Lookahead 激進地探索更優的新區域,從而使測試準確率得到提升。這樣的探索可能是 SGD 更新 20 次也未必能夠到達的水平,因此有效地提升了模型收斂效果。
如上為 ResNet-32 在 CIFAR-100 訓練 100 個 Epoch 後的視覺化結果。在從上圖可以看到模型已經接近最優解了,右上的 SGD 還會慢慢探索比較好的區域,因為當時的梯度已經非常小了。但是右下的 Lookahead 會根據 slow weights(紫色)探索到更好的區域。
當然這裡只是展示了 Lookahead 怎麼做,至於該演算法更新步長、內部學習率等引數怎麼算,讀者可以查閱原論文。此外,Hinton 等研究者還給出了詳細的收斂性分析,感興趣的讀者也可以細細閱讀,畢竟當年 ICLR 2018 最佳論文可是找出了 Adam 原論文收斂性分析的錯誤。
實驗分析
研究人員在一系列深度學習任務上使用 Lookahead 優化器和業內最強的基線方法進行了對比,其中包括在 CIFAR-10/CIFAR-100、ImageNet 上的影像分類任務。此外,研究人員在 Penn Treebank 資料集上訓練了 LSTM 語言模型,也探索了基於 Transformer 的神經機器翻譯模型在 WMT 2014 英語-德語資料集上的表現。對於所有實驗,每個演算法都使用相同數量的訓練資料。
圖 5:不同優化演算法的效能比較。(左)在 CIFAR-100 上的訓練損失。(右)使用不同優化器的 ResNet-18 在 CIFAR 資料集上的驗證準確率。研究者詳細研究了其它優化器的學習率和權重衰減(見論文附錄 C)。Lookahead 和 Polyak 超越了 SGD。
圖 6:ImageNet 的訓練損失。星號表示激進的學習率衰減機制,其中 LR 在迭代 30、48 和 58 次時衰減。右表展示了使用 Lookahead 和 SGD 的 ResNet-50 的驗證準確率。
圖 7:在 Penn Treebank 和 WMT-14 機器翻譯任務上的優化效能。
從這些實驗中,可以得到如下結論:
對於內部優化演算法、k 和 α 的魯棒性:研究人員在 CIFAR 資料集上的實驗表明,Lookahead 可以始終如一地在不同初始超引數設定中實現快速收斂。我們可以看到 Lookahead 可以在基礎優化器上使用更高的學習率進行訓練,且無需對 k 和 α 進行大量調整。
內迴圈和外迴圈評估:研究人員發現,在每個內迴圈中 fast weights 可能會導致任務效能顯著下降——這證實了研究者的分析:內迴圈更新的方差更高。