實踐心得:從讀論文到復現到為開源貢獻程式碼

阿里云云棲社群發表於2018-05-18

摘要:
本文講述了從在fast.ai庫中讀論文,到根據論文複製實驗並做出改進,並將改進後的開原始碼放入fast.ai庫中。

eaa5115e520d0524d76d32294c5384188a811edf

介紹

去年我發現
MOOC網上有大量的Keras和TensorKow教學視訊,之後我從零開始學習及參加一些Kaggle比賽,並在二月底獲得了fast.ai國際獎學金。去年秋天,當我在全力學習PyTorch時,我在feed中發現了一條關於新論文的推文:“平均權重會產生更廣泛的區域性優化和更好的泛化。”具體來說,就是我看到一條如何將其新增到fast.ai庫的推文。現在我也參與到這項研究。

在一名軟體工程師的職業生涯中,我發現學習一門新技術最好的方法是將它應用到具體的專案。所以我認為這不僅可以練習提高我的PyTorch能力,還能更好的熟悉fast.ai庫,也能提高我閱讀和理解深度學習論文的能力。

作者發表了使用隨機加權平均(SWA)訓練VGG16和預啟用的Resnet-110模型時獲得的改進。對於VGG網路結構,SWA將錯誤率從6.58%降低到6.28%,相對提高了4.5%,而Resnet模型則更明顯,將誤差從4.47%減少到3.85%,相對提高了13.9%。

論文

背景

隨機加權平均(SWA)方法來自於整合。整合是用於提高機器學習模型效能的流行的技術。例如,ensemble演算法獲得了Nekix獎,因為Netkix過於複雜不適用於實際生產,而在像Kaggle這樣的競爭平臺上,整合最終效能表現結果可以遠超單個模型。

最簡單的方式為,整合可以對不同初始化的模型的若干副本進行訓練,並將對副本的預測平均以得到整體的預測。但是這種方法的缺點是必須承擔n個不同副本的成本。研究人員提出快照整合(Snapshot Ensembles)方法。改方法是對一個模型進行訓練,並將模型收斂到幾個區域性最優點,儲存每個最優點的權重。這樣一個單一的訓練就可以產生n個不同的模型,將這些預測平均就能預測出整體。

在發表SWA論文之前,作者曾發表過快速幾何整合(FGE)方法的論文,改方法改進了快照整合的結果,FGE方法為“區域性最優能通過近乎恆定損耗的簡單曲線連線起來”也就是說,通過FGE作者能夠發現損耗曲面中的曲線具有理想的特性,以及通過這些曲線整合模型。

在SWA論文中,作者提供了SWA接近FGE的證據。然而,SWA比FGE的好處是推理成本較低 。FGE需要產生n個模型的預測結果,而對於SWA而言,最終只需要一個模型,因此推斷可以更快。

演算法

SWA演算法的工作原理相對簡單。首先製作你正在訓練的模型的副本,以便用於跟蹤平均權重。在完成epoch訓練後,通過以下公式更新副本的權重:

158093d4ff0bc8f9ee265fb8eaf938faed5d7d15

其中n_models是已經包含在平均值中的模型數量,w_swa表示副本的權重,w表示正在訓練的模型的權重。這相當於在每個epoch訓練時期結束時儲存模型的執行平均值。這就是該演算法的精髓,但論文還介紹了一些細節,首頁作者制定了具體的學習率計劃,以確保SGD在開始平均模型時就能夠找到出最優點。其次,對網路進行預訓練以達到開始時就有一定數量的epochs,而不是一開始就追蹤平均值。另外,如果使用週期性學習率,那麼需要在每個週期結束時儲存平均值,而不是在每個epoch後。

尋找更廣泛的最優點

SWA的演算法的工作方式,作者提供了證據,證明與SGD相比,它能使模型達到更廣泛的區域性最優,從而能夠提高模型的泛化能力,因為訓練損失和測試資料可能不完全一致。因此,對訓練資料進行更廣泛的優化使得模型對測試資料進行優化。

9bbc400da5841adc1697838d6f28f90a75ce6a3c

圖三的一部分

由圖可得,訓練損失(左)和測試錯誤(右)相似但不完全相同。例如,最右邊的
X處於訓練損失表面的最佳點,但距離最優測試誤差有一定距離。正是這些差異能更容易的尋找更廣泛的最優點,這更可能成為訓練和測試損失的最佳點。

作者提出觀點:
SWA可以找到更廣泛的最優點。並在論文Optima Width章節中通過實驗給出了證據,將損失作為給定方向上的Optima距離的函式,來比較SGD和SWA能夠發現的最優點寬度。作者對10個不同的方向進行了取樣,並測量了用SGD和SWA對CIFAR-10進行訓練的Preactivation Resnet的損失,結果如下:

66920ea89012cd7067788ae39df5d7ddca00b260

圖4:“測試誤差...作為隨機射線上的點函式,起始於CIFAR-100上預啟用ResNet-110的SWA(藍色)和SGD(綠色)解決方案。”

圖中資料提供了證據,表明SWA發現的optima比SGD所發現的更廣泛,因為它與SWA最優的距離比增加同樣數量的測試錯誤的距離更大。例如,要達到50%的測試誤差,你必須從距離SGD的最佳距離為30,而SGD為50。

實驗

作者進行了大量的實驗來驗證SWA方法在不同的資料集和模型架構上的有效性。首先,我將詳細描述為了實現該演算法做的實驗設定,然後講解一些關鍵結果。

使用VGG16和預啟用的Resnet-110體系結構在CIFAR-10上進行了複製實驗。每個體系結構都有一定的預算,以表示僅使用SGD +動量來訓練模型收斂所需的時間數。VGG預算為200,而Resnet則為150。然後,為了測試SWA,模型用SGD +動力培訓約75%的預算,然後用SWA進行額外的epochs訓練,達到原始預算的1、1.25和1.5倍。對每個測試訓練了三個模型,並報告平均值和標準偏差。

除了對CIFAR-10的實驗外,作者還對CIFAR-100進行了類似的實驗。他們還在ImageNet上測試了預訓練模型,使用SWA執行了10個epochs,並發現在預訓練的ResNet-50、ResNet152和DenseNet-161的精度提高了。最後,作者通過使用固定學習速率的SWA,成功地從scratch中訓練了一個寬的ResNet-28-10。

實現

閱讀並理解該論文後,我嘗試在fast.ai庫中找出哪個位置新增程式碼能夠使SWA正常工作。該位置已經找到了,因為fast.ai庫提供了新增自定義回撥的功能。如果我用每個epoch結束時呼叫的hook來寫回撥,那麼就能在適當的時間更新權重的執行平均值。這是結束的程式碼:

82df6b30728983878ba9f2eedc2dd25b1c9bef81


回撥採用三個引數:model、swa_model和swa_start。前兩個是我們正在訓練的模型,以及我們將用來儲存加權平均的模型副本。swa_start引數是平均開始的時間,因為在論文中,模型總是在開始跟蹤平均權重之前,用SGD+動量對一定數量的epochs進行訓練。

從這裡你可以看到SWA回撥如何將演算法從檔案轉換成PyTorch程式碼。在SWA開始的epoch中,我們將更新引數的執行平均值,並增加平均值中包含的模型數量。

在SWA模型進行推斷前,我們還需要用包含程式碼修復batchnorm的運算平均值。batchnorm層通常在訓練期間計算這些執行統計資料,但由於模型的權重是作為其他模型的平均值計算的,所以這些執行統計資料對於SWA模型是錯誤的,因此需要再次單次傳遞資料讓batchnorm層計算正確的執行統計資料。修復程式碼如下:

9c0b5e08c639eeb8f2611ceb37d40626547cd923

測試

測試非常重要,但是在機器學習程式碼中應用單元測試是很困難的,因為有一些不確定的因素或者測試的狀態需要較長時間。為了確保所做工作實際上是有效的,我做了兩個測試,一個是
“功能”測試,它們是較小的程式碼塊,通常執行在比較簡單的模型上,旨在回答:“這個功能是否按照我的想法實現了?”例如,一項功能測試表明,在經過幾個階段的訓練後,SWA模型實際上等於所有SGD模型引數的平均值:

e58912790f2d92ea264e5997ba9bd76f572637d3

這些測試通常在30秒內就能執行完成,所以在編寫實現程式碼遇到問題時能快速提醒我。由於fast.ai庫的開發速度非常快,這些測試還能在試圖解決master分支合併問題時快速識別問題。

第二個測試為“實驗”測試。它的目的是回答:“如果我用自己的實現和fast.ai庫重新進行論文中的實驗,我是否能觀察到與論文相同的結果?”每次我實現一個功能就會執行這個測試,以確定SWA是否對庫做出有用的貢獻。實驗測試要比功能測試花費的時間長,但能確保一切都按預期執行。

最後我可以複述論文的結果
-隨機權重平均確實在CIFAR-10上產生了比一般SGD更高的準確性,並且隨著訓練時期的增加,這種改善通常會增加。正如下表所示,我所有的結果都比原始論文結果更準確。其中一個因素可能是資料增強的方式——對於CIFAR-10,通過將每個影象填充4個畫素並隨機裁剪進行增強,並且我發現fast.ai預設使用不同型別的填充(rekection填充)。然而,可以清楚地看到SWA改善超過SGD +momentum的模式。

ee0edd049299c9e4adee6c2a2999919aa0a6ba4b

原始論文的結果

e566667698766ed46607c2b90de74bca4ce443cc

我的結果獲取測試程式碼請點選程式碼

結論

我對這個專案的最終結果非常滿意,因為我從最前沿的研究論文中複製了一個實驗,併為機器學習開原始碼做出了自己的第一個貢獻。我想鼓勵大家下載
fast.ai庫,並嘗試一下SWA吧!

本文由阿里云云棲社群組織翻譯。

原文連結

本文為雲棲社群原創內容,未經允許不得轉載。


相關文章