利用LSTM思想來做CNN剪枝,北大提出Gate Decorator
選自arXiv
作者:Zhonghui You等
機器之心編譯
參與:思源、一鳴
利用LSTM基本思想門控機制進行剪枝?讓模型自己決定哪些卷積核可以扔。
還記得在理解 LSTM 的時候,我們會發現,它用一種門控機制記住重要的資訊而遺忘不重要的資訊。在此之後,很多機器學習方法都受到了門控機制的影響,包括 Highway Network 和 GRU 等等。北大的研究者同樣也是,它們將門控機制加入到 CNN 剪枝中,讓模型自己決定哪些濾波器不太重要,那麼它們就可以刪除了。
其實對濾波器進行剪枝是一種最為有效的、用於加速和壓縮卷積神經網路的方法。在這篇論文中,來自北大的研究者提出了一種全域性濾波器剪枝的演算法,名為「門裝飾器(gate decorator)」。這一演算法可以透過將輸出和通道方向的尺度因子(門)相乘,進而改變標準的 CNN 模組。當這種尺度因子被設0的時候,就如同移除了對應的濾波器。
研究人員使用了泰勒展開,用於估計因設定了尺度因子為 0 時對損失函式造成的影響,並用這種估計值來給全域性濾波器的重要性進行打分排序。接著,研究者移除哪些不重要的濾波器。在剪枝後,研究人員將所有的尺度因子合併到原始的模組中,因此不需要引入特別的運算或架構。此外,為了提升剪枝的準確率,研究者還提出了一種迭代式的剪枝架構—— Tick-Tock。
圖 1:濾波器剪枝圖示。第 i 個層有4個濾波器(通道)。如果移除其中一個,對應的特徵對映就會消失,而輸入 i+1 層的通道也會變為3。
擴充套件實驗說明了研究者提出的方法的效果。例如,研究人員在 ResNet-56 上達到了剪枝比例最好的 SOTA,減少了 70% 的每秒浮點運算次數,但沒有帶來明顯的準確率降低。
在 ImageNet 上訓練的 ResNet-50 上,研究者減少了 40% 的每秒浮點運算次數,且在 top-1 準確率上超過了基線模型 0.31%。在研究中使用了多種資料,包括 CIFAR-10、CIFAR-100、CUB-200、ImageNet ILSVRC-12 和 PASCAL VOC 2011。
本文的主要貢獻包括兩個部分:第一部分是「門裝飾器」演算法,用於解決 GFIR 問題。第二部分是 Tick-Tock 剪枝框架,用於提升剪枝準確率。
具體而言,研究者展示瞭如何將門裝飾器用於批歸一化操作,並將這種方法命名為門批歸一化(GBN)。給定預訓練模型,研究者在剪枝前將歸一化模組轉換成門批歸一化。剪枝結束後,他們將門批歸一化還原為批歸一化。透過這樣的方法,不需要給模型引入特殊的運算或架構。
- 論文地址:
- 實現地址:
門控剪枝到底怎麼做
那麼到底怎樣使用門控機制解決全域性濾波器重要性排序呢?研究者表示他們會先將 Gate Decorator 應用到批歸一化機制中,然後使用一種名為 Tick-Tock 的迭代剪枝框架來獲得更好的剪枝準確率,最後再採用分組剪枝(Group Pruning)技術解決待條件的剪枝問題,例如剪枝帶殘差連線的網路。
上面簡要展示了敘述了門控剪枝三步走,後面會做一個簡單的介紹,當然更詳細的內容可查閱原論文。
門控批歸一化
研究者將 Gate Decorator應用到批歸一化中,並將該模組稱之為門控批歸一化(GBN),門控批歸一化如下方程7所示,它和標準批歸一化的不同之處在於 φ arrow的門控選擇。其中 φ arrow 是 φ 的一個向量,c 是 Z_in 的通道數。
如果 φ arrow 中的元素是零,那麼就表示它對應的通道被裁減了。此外,對於不使用BN 的網路,我們也可以直接將 Gate Decorator 應用到卷積運算中,從而達到門控剪枝的效果。
Tick-Tock 剪枝框架
研究者還引進了一種迭代式的剪枝框架,從而提升剪枝準確率,他們將該框架稱為Tick-Tok。其中 Tick 階段會在訓練資料的子集上執行,卷積核會被設定為不可更新狀態。而 Tock 階段使用全部訓練資料,並將稀疏約束 φ 新增到損失函式中。
圖2:Tick-Tock剪枝框架圖示。
其中 Tick 階段主要希望能實現以下三個目標:加速剪枝過程;計算每一個濾波器的重要性分數 Θ;降低前面剪枝引起的內部協變數遷移問題。
在 Tick 階段中,研究者會在訓練資料的子集中訓練一個 Epoch,我們僅允許門控 φ 和最終的線性層能更新,這樣能大大降低小資料集上的過擬合風險。透過訓練後,模型會根據重要性分數 Θ 排序所有的濾波器,並將不那麼重要的濾波器移除。
在 Tock 階段前,Tick 階段能重複 T 次。Tock 階段會微調網路以降低總體誤差,這些誤差可能是由於一處濾波器造成的。此外,Tock 階段和一般的微調過程有兩大不同:微調比 Tock 要訓練更多的 Epoch;微調並不會給損失函式加上稀疏性約束。
分組剪枝:解決帶約束的剪枝問題
ResNet 和其變體包含殘差連線,也就是在兩個殘差塊產生的特徵圖上執行元素級的加法。如果單獨修剪每個層的濾波器,可能會導致殘差連線中特徵圖對不齊。這可以視為一種帶約束的剪枝問題,我們希望剪枝是在對齊特徵圖的條件下完成的。
為了解決無法對齊的問題,作者們提出了分組剪枝:將透過純殘差方式連線的 GBN 分配給同一組。純殘差連線是指在側分支上沒有卷積層的一種方式,如圖3所示。
圖3:組剪枝展示。同樣顏色的GBN屬於同一組。
每一組可以視為一個 Virtual GBN,它的所有組成卷積共享了相同的剪枝模式。並且在分組中,濾波器的重要性分數就是成員卷積分數的和。
實驗設定和資料集
資料集
研究者使用了多種資料集,包括 CIFAR-10,CIFAR-100,CUB-200, ImageNet ILSVRC-12和 PASCAL VOC 2011。CIFAR-10 資料集包括了50K的訓練資料和10K的測試資料。CIFAR-100和CIFAR-10相同,但有100個類別,每個類別有600張圖片。CUB-200包括了將近6000張訓練圖片和5700張測試圖片,涵蓋了200種鳥類。ImageNet ILSVRC-12有128萬訓練影像和50K的測試影像,覆蓋1000個類別。研究者還使用了PASCAL VOC 2011分割資料集和其擴充套件資料集SBD,它有20個類別,共8498張訓練樣本圖片和2857張測試樣本圖片。
被剪枝的模型
研究者使用了三種網路架構進行剪枝:VGGNet、ResNet和FCN。所有的網路都使用SGD進行訓練,權重衰減和動量超引數分別設定為10-4和0.9。
研究者使用了多種訓練資料和不同的批大小對這些網路進行了訓練,同時加入了一些資料增強的方法。
在剪枝階段,研究者在每個Tick階段剪去ResNet0.2%的濾波器,在VGG和FCN上減去1%的濾波器。在每10個Tick操作後進行一次Tock操作。
剪枝效果
表1:在 ResNet-56上,使用CIFAR-10訓練的模型剪枝後的表現。基線準確率為93.1%。
表 2:在ResNet-50上,使用ImageNe訓練的模型剪枝後的表現。P.Top-1、P.Top-5 分別表示 top-1和 top-5剪枝後的模型在驗證集上的單中心裁剪準確率。[Top-1] ↓ 和 [Top-5] ↓分別表示剪枝後模型準確率和基線模型相比的下降情況。Global 表示這一剪枝方法是否是全域性濾波器剪枝演算法。
圖4:VGG-16-M在CUB-200資料集上的剪枝效果。
下圖5的基線模型是VGG-16-M,他在CIFAR-100上的測試準確率為73.19%。其中「shrunk」版表示將所有卷積層的通道數減半,因此將FLOPs降低到了基線模型的1/4,從頭訓練後它的測試準確率會降低1.98%。「pruned」版表示採用Tick-Tock框架進行剪枝的結果,它的測試準確率會降低1.3%。
如果我們從頭訓練「pruned」版模型,那麼它的準確率能達到71.02%,相當於降低了2.17%。不過重要的是,「pruned」版模型的引數量只有「shrunk」版模型的1/3。
圖5:兩種網路的效果和通道數對比,它們有相同的FLOPs。
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946223/viewspace-2658274/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 利用LSTM做語言情感分類
- 讓CNN跑得更快,騰訊優圖提出全域性和動態過濾器剪枝CNN過濾器
- Ross、何愷明等人提出:渲染思路做影像分割,提升Mask R-CNN效能ROSCNN
- 初次參與Golden -Gate POC,希望有機會做Golden -Gate for ERPGo
- 模型剪枝:剪枝粒度、剪枝標準、剪枝時機、剪枝頻率模型
- Python Decorator的來龍Python
- Ross、何愷明等人提出PointRend:渲染思路做影像分割,顯著提升Mask R-CNN效能ROSCNN
- 賈佳亞等提出Fast Point R-CNN,利用點雲快速高效檢測3D目標ASTCNN3D
- Taro下利用Decorator快速實現小程式分享
- 利用Decorator和SourceMap優化JavaScript錯誤堆疊優化JavaScript
- 如何利用jenkins來做android自動化JenkinsAndroid
- 原作者帶隊,LSTM捲土重來之Vision-LSTM出世
- 手把手教你開發CNN LSTM模型,並應用在Keras中(附程式碼)CNN模型Keras
- 如何基於TensorFlow使用LSTM和CNN實現時序分類任務CNN
- Dropout可能要換了,Hinton等研究者提出神似剪枝的Targeted Dropout
- 利用歸檔來做資料檔案的恢復
- 「完結」總結12大CNN主流模型架構設計思想CNN模型架構
- Golden Gate 初探Go
- Token化一切,甚至網路!北大&谷歌&馬普所提出TokenFormer,Transformer從來沒有這麼靈活過!谷歌ORM
- 北大博士生提出CAE,下游任務泛化能力優於何愷明MAE
- 論文翻譯:2018_LSTM剪枝_Learning intrinsic sparse structures within long short-term memoryStruct
- 清華、李飛飛團隊等提出強記憶力 E3D-LSTM 網路3D
- 利用RMAN做TSPITR
- 原作者帶隊,LSTM真殺回來了!
- 用介面的思想來理解GraphQL
- 如何用LSTMs做預測?(附程式碼)| 博士帶你學LSTM
- ICCV 2019 | 北大、華為聯合提出無需資料集的Student Networks
- 李飛飛「空間智慧」之後,上交、智源、北大等提出空間大模型SpatialBot大模型
- LSTM理解
- RNN、LSTMRNN
- 乾貨|如何利用CNN建立計算機視覺模型?CNN計算機視覺模型
- Alpha-Beta 剪枝
- 無所不能的Embedding5 - skip-thought的兄弟們[Trim/CNN-LSTM/quick-thought]CNNUI
- 智慧家居暴露隱私?港中文等利用LSTM攻克IoT安全設定
- 優於Mask R-CNN,港中文&騰訊優圖提出PANet例項分割框架CNN框架
- LSTM捲土重來!xLSTM:一舉超越Mamba、Transformer!ORM
- 北大獲中國首個WWW大會最佳論文獎,提出ELSA跨語言情感分析模型模型
- ORACLE golden gate 機制OracleGo