DMCP: Differentiable Markov Channel Pruning for Neural Networks 閱讀筆記

zhwangye發表於2020-11-23

DMCP: Differentiable Markov Channel Pruning for Neural Networks

神經網路的可微馬爾可夫通道剪枝
論文地址

0 前言

本文提出了DMCP剪枝演算法,主要思想是:

  1. 將通道剪枝作為馬爾可夫過程進行建模
  2. 使用網路原始損失函式和預算正則化(限制模型的FLOPs)對剪枝引數進行優化
  3. 使用DS或ES策略對馬爾可夫模型進行剪枝,得到剪枝模型
  4. 對剪枝後的模型進行重頭開始訓練,並評估子網路的效能

文中提出的可微體現在, 預算正則化無法被直接優化,提出了可微分的方法在訓練時對其進行優化

對照:
預算正則化: Budget Regularization

1 馬爾可夫建模

1.1 剪枝過程的定義

使用 M ( L ( 1 ) , L ( 2 ) , … , L ( N ) ) M\left(L^{(1)}, L^{(2)}, \ldots, L^{(N)}\right) M(L(1),L(2),,L(N))表示N層沒有剪枝的網路模型,公式裡的 L ( i ) L^{(i)} L(i)表示網路模型中的第 i i i層。網路中的第 i i i層中第 k k k個通道的輸出可以由公式(1)給出,
O k ( i ) = w k ( i ) ⊙ x , k = 1 , 2 , … , C o u t ( i )   ( 1 ) O_{k}^{(i)}=w_{k}^{(i)} \odot x, k=1,2, \ldots, C_{o u t}^{(i)} (1) Ok(i)=wk(i)x,k=1,2,,Cout(i) (1)
O k ( i ) O_{k}^{(i)} Ok(i)表示第 i i i層中的第 k k k個通道的輸出, w k ( i ) w_{k}^{(i)} wk(i)表示 L ( i ) L^{(i)} L(i)層中的第 k k k個Filter, ⊙ \odot 表示卷積操作. C o u t ( i ) C_{o u t}^{(i)} Cout(i)表示 L ( i ) L^{(i)} L(i)層中的第 k k k個Filter的個數. 為了方便起見, 如果沒有特別的說明, 後面的符號中都將省略上標.

說明:
" O k ( i ) O_{k}^{(i)} Ok(i)表示第 i i i層中的第 k k k個通道的輸出, w k ( i ) w_{k}^{(i)} wk(i)表示 L ( i ) L^{(i)} L(i)層中的第 k k k個Filter"的通道Filter說明, 如下圖所示, 網路的輸入為6x6x3大小特徵圖, Filter(卷積核)大小為3x3x3x2, 其中x2表示卷積核(Filter)的個數為2, 經過卷積操作後, 輸出為4x4x2大小的特徵圖, 這裡的2即為輸出特徵圖的通道. 綜上, Filter即為卷積核, 通道為輸出特徵圖的’深度’, Filter的個數等於輸出特徵圖的通道數.
在這裡插入圖片描述

通道和Filter說明圖片
[圖片來源](https://indoml.com/2018/03/07/student-notes-convolutional-neural-networks-cnn-introduction)

i i i層的馬爾可夫模型如下, 狀態 S k ( 1 ≤ k ≤ C out ) S_{k}\left(1 \leq k \leq C_{\text {out}}\right) Sk(1kCout)表示剪枝中保留的第 k k k個通道, 從 S k S_k Sk S k + 1 S_{k+1} Sk+1的轉移概率 p k p_k pk表示在第 k k k個通道保留的前提下,保留地 k + 1 k+1 k+1個通道的概率. 當任何其他狀態轉移到終端狀態T, 即表明剪枝過程結束. 該過程有這樣的性質, 如果 L L L C c o u t C_{cout} Ccout中的 k k k個通道被保留了, 那麼他們必須是前 k k k個通道. 也就是說, 如果第 k k k個通道被保留了, 那麼前 k − 1 k-1 k1個通道也會被保留. 進一步, 可以總結出, 保留 k + 1 k+1 k+1個通道條件獨立於前 k − 1 k-1 k1個通道當第 k k k個通道保留的前提下.
在這裡插入圖片描述
上圖中 1 − p k 1-p_k 1pk表示達到終止狀態的概率.

1.2 通過馬爾可夫過程進行通道剪枝

使用一堆可學習的引數對馬爾可夫轉移狀態進行引數化,這堆可學習的引數被稱為結構引數. p ( w 1 , w 2 , … , w k − 1 ) p\left(w_{1}, w_{2}, \ldots, w_{k}-1\right) p(w1,w2,,wk1)用來表示前 k − 1 k-1 k1個通道的保留概率, 前 k k k個通道的保留概率可以如下表示,
p ( w 1 , … , w k ) = p ( w k ∣ w 1 , … , w k − 1 ) p ( w 1 , … , w k − 1 )   ( 2 ) p\left(w_{1}, \ldots, w_{k}\right)=p\left(w_{k} \mid w_{1}, \ldots, w_{k-1}\right) p\left(w_{1}, \ldots, w_{k-1}\right) (2) p(w1,,wk)=p(wkw1,,wk1)p(w1,,wk1) (2)
上面公式中的 p ( w k ∣ w 1 , … , w k − 1 ) p\left(w_{k} \mid w_{1}, \ldots, w_{k-1}\right) p(wkw1,,wk1)表示在保留前 k − 1 k-1 k1個通道的前提下, 保留的 k k k個通道的概率. 由於 w k w_k wk w k − 1 w_{k-1} wk1保留的前提下, 條件獨立於 { w 1 , w 2 , … w k − 2 } \left\{w_{1}, w_{2}, \ldots w_{k}-2\right\} {w1,w2,wk2}, 所以可以將公式2改寫成,
p k = p ( w k ∣ w 1 , w 2 , … , w k − 1 ) = p ( w k ∣ w k − 1 )   ( 3 ) p_{k}=p\left(w_{k} \mid w_{1}, w_{2}, \ldots, w_{k-1}\right)=p\left(w_{k} \mid w_{k-1}\right) (3) pk=p(wkw1,w2,,wk1)=p(wkwk1) (3)
p ( w k ∣ ¬ w k − 1 ) = 0   ( 4 ) p\left(w_{k} \mid \neg w_{k-1}\right)=0 (4) p(wk¬wk1)=0 (4)
¬ w k − 1 \neg w_{k-1} ¬wk1表示 k − 1 k-1 k1個通道被捨棄. 轉移概率 P = { p 1 , p 2 , . . p C o u t } P=\left\{p_{1}, p_{2}, . . p_{C_{o u t}}\right\} P={p1,p2,..pCout}可以通過公式3來定義. 並且使用結構引數 A = { α 1 , α 2 , … , α c out } A= \left\{\alpha_{1}, \alpha_{2}, \ldots, \alpha_{c_{\text {out}}}\right\} A={α1,α2,,αcout}來引數化轉移概率 P P P, p k p_k pk由公式5計算得出,
p k = { 1 k = 1 sigmoid ( α k ) = 1 1 + e − α k k = 2 , … , C out , α k ∈ A   ( 5 ) p_{k}=\left\{\begin{array}{ll} 1 & k=1 \\ \text {sigmoid}\left(\alpha_{k}\right)=\frac{1}{1+e^{-\alpha_{k}}} & k=2, \ldots, C_{\text {out}}, \alpha_{k} \in A \end{array}\right. (5) pk={1sigmoid(αk)=1+eαk1k=1k=2,,Cout,αkA (5)
我們設定每個卷積操作至少保留一個通道, 所以, p 1 = p ( w 1 ) = 1 p_{1}=p\left(w_{1}\right)=1 p1=p(w1)=1 .對取樣的通道 w k w_k wk的邊緣概率使用 p ( w k ) p(w_k) p(wk)來表示, 並使用公式6計算,
p ( w k ) = p ( w k ∣ w k − 1 ) p ( w k − 1 ) + p ( w k ∣ ¬ w k − 1 ) p ( ¬ w k − 1 ) = p ( w k ∣ w k − 1 ) p ( w k − 1 ) + 0 = p ( w 1 ) ∏ i = 2 k p ( w i ∣ w i − 1 ) = ∏ i = 1 k p i ( 6 ) \begin{aligned} p\left(w_{k}\right) &=p\left(w_{k} \mid w_{k-1}\right) p\left(w_{k-1}\right)+p\left(w_{k} \mid \neg w_{k-1}\right) p\left(\neg w_{k-1}\right) \\ &=p\left(w_{k} \mid w_{k-1}\right) p\left(w_{k-1}\right)+0 \\ &=p\left(w_{1}\right) \prod_{i=2}^{k} p\left(w_{i} \mid w_{i-1}\right)=\prod_{i=1}^{k} p_{i} \end{aligned}(6) p(wk)=p(wkwk1)p(wk1)+p(wk¬wk1)p(¬wk1)=p(wkwk1)p(wk1)+0=p(w1)i=2kp(wiwi1)=i=1kpi(6)
結構引數會被新增到剪枝前的網路中進行優化, 通過公式7實現,
O ^ k = O k × p ( w k )   ( 7 ) \hat{O}_{k}=O_{k} \times p\left(w_{k}\right)  (7) O^k=Ok×p(wk) (7)
O ^ k \hat{O}_{k} O^k表示第 k k k個通道的最終輸出, 如果的 w k w_k wk(第 k k k個Filter)進行剪枝, 則通過將其置0即可.

注意,由於BN會對第 i i i層的輸出進行縮放, 所以不能再卷積層之後直接實現公式7, 而是將其放在BN層之後.

1.3 短連線問題

MobileNetV2和ResNet都有含短連線的殘差塊, 由於element-wise操作, 最後一個卷積層的通道數必須和前一個塊的通道數相等.我們使用權重共享解決此問題.

1.4 Budget 正則化

這是對網路模型的總體FLOPs的約束, L L L層中所期望的通道 E ( c h a n n e l ) E(channel) E(channel)可以用公式8表示,
E ( channel ) = ∑ i = 1 C out p ( w i ) ( 8 ) E(\text {channel})=\sum_{i=1}^{C_{\text {out}}} p\left(w_{i}\right)(8) E(channel)=i=1Coutp(wi)(8)
L L L層中期望輸入通道 E ( i n ) E(in) E(in)和輸出通道 E ( o u t ) E(out) E(out)可以通過公式8計算, 期望FLOPs E ( L F L O P s ) E(L_{FLOPs}) E(LFLOPs)可以通過以下公式計算,
E ( L F L O P s ) = E ( out ) × E ( kernel_op ) ( 9 ) E\left(L_{F L O P s}\right)=\quad E(\text {out}) \times E(\text {kernel\_op}) (9) E(LFLOPs)=E(out)×E(kernel_op)(9)
E ( kernel_op ) = E ( in ) groups × #  channel_op  ( 10 ) E(\text {kernel\_op})=\frac{E(\text {in})}{\text {groups}} \times \# \text { channel\_op } (10) E(kernel_op)=groupsE(in)×# channel_op (10)
#  channel_op  = ( S I + S P − S K stride + 1 ) × S K × S K ( 11 ) \# \text { channel\_op }=\left(\frac{S_{I}+S_{P}-S_{K}}{\text {stride}}+1\right) \times S_{K} \times S_{K}(11) # channel_op =(strideSI+SPSK+1)×SK×SK(11)
這裡的groups為分組卷積的組數. S I S_I SI S K S_K SK分別表示輸入特徵圖和卷積核的寬(或者高). S P S_P SP表示 填充大小,stride表示卷積步長. 整個模型的期望FLOPs E ( N F L O P s ) E(N_{FLOPs}) E(NFLOPs)可由公式12計算,
E ( N F L O P s ) = ∑ l = 1 N E ( l ) ( L F L O P s ) ( 12 ) E\left(N_{F L O P s}\right)=\sum_{l=1}^{N} E^{(l)}\left(L_{F L O P s}\right) (12) E(NFLOPs)=l=1NE(l)(LFLOPs)(12)
公式12中的 N N N表示網路中卷積層的數量. 通過該公式, 可以通過梯度下降來優化FLOPs.

1.5 損失函式

設定要達到的FLOPs為 F L O P s t a r g e t FLOPs_{target} FLOPstarget, 使用公式12來公表示可微分的budget正則化的損失函式 l o s s r e g loss_{reg} lossreg,
loss ⁡ r e g = log ⁡ ( ∣ E ( N F L O P s ) − F L O P s target ∣ ) ( 13 ) \operatorname{loss}_{r e g}=\log \left(\left|E\left(N_{F L O P s}\right)-F L O P s_{\text {target}}\right|\right) (13) lossreg=log(E(NFLOPs)FLOPstarget)(13)
為了讓 E ( N F L O P s ) E(N_{FLOPs}) E(NFLOPs)嚴格小於 F L O P s t a r g e t FLOPs_{target} FLOPstarget但又不對目標太敏感, 所以給損失函式新增了單邊距, 當滿足 γ × F L O P s target ≤ \gamma \times F L O P s_{\text {target}} \leq γ×FLOPstarget E ( N F L O P s ) ≤ F L O P s target E\left(N_{F L O P s}\right) \leq F L O P s_{\text {target}} E(NFLOPs)FLOPstarget 時,損失函式會變成0. γ < 1 \gamma < 1 γ<1是公差比可由使用者給定.

在更新權重時, F L O P s FLOPs FLOPs損失函式對權重沒有影響, 權重損失為,
L o s s w e i g h t = l o s s c l s ( 14 ) Loss_{weight}= loss_{cls} (14) Lossweight=losscls(14)
l o s s c l s loss_{cls} losscls為分類任務的交叉熵損失函式.當更新結構引數時, 損失函式如下,
L o s s a r c h = l o s s c l s + λ reg loss reg Loss_{arch}=loss_{cls}+\lambda_{\text {reg}} \text {loss}_{\text {reg}} Lossarch=losscls+λreglossreg
λ reg \lambda_{\text {reg}} λreg是用於平衡兩個損失的超引數.

沒有在結構引數中新增權重衰減, 由於當保留通道的概率接近0或者1時, 可學習引數 α α α的範數將變得非常大,這會使得他們趨近於0並進行優化.

2 訓練

訓練分為未剪枝網路權重更新和結構引數更新兩個步驟, 而且這兩個步驟會在訓練過程中被迭代呼叫.

2.1 步驟1 未剪枝網路權重更新

由公式6的定義可知, 保留第 k k k個通道概率的也可以被認為是保留第前 k − 1 k-1 k1個通道的概率. 我們的方法可視為在網路前向傳播更新結構引數時對所有子網路結構的軟取樣.網路的通道的權重都相等, 這樣不好直觀地對通道選擇作為馬爾可夫過程進行建模. 我改進了’'三明治法則", 命名為"變體三明治法則", 將該方法加入訓練過程中, 使得在未剪枝模型中通道組比他之後的通道組更重要, 讓通道的權重變得不相等. 更加權重的排序, 就可以選擇出效能最好的子網路結構.

與原始的"三明治法則"有兩點不同: 1. 每一層隨機取樣的閾值不同; 2. 隨機取樣的閾值服從馬爾可夫過程中的結構引數分佈. 值得注意的是, 我們注重出現頻率高的子網路結構.
在這裡插入圖片描述

2.2 步驟2 結構引數更新

這個步驟中只更新結構引數, 由公式6可知, 結構引數結合在原始網路的每個卷積層的輸出中, 這使得反向傳播的梯度能夠對其進行優化. 通過下面的公式可以讓梯度回傳到 α \alpha α,
∂ L o s s ∂ α j ( i ) = ∑ k = 1 C o u t ∂ L o s s ∂ O k ( i ) ^ × ∂ O k ( i ) ^ α j ( i ) ( 16 ) \frac{\partial L o s s}{\partial \alpha_{j}^{(i)}}=\sum_{k=1}^{C_{o u t}} \frac{\partial L o s s}{\partial \hat{O_{k}^{(i)}}} \times \frac{\partial \hat{O_{k}^{(i)}}}{\alpha_{j}^{(i)}} (16) αj(i)Loss=k=1CoutOk(i)^Loss×αj(i)Ok(i)^(16)
∂ O k ( i ) ^ α j ( i ) = { 0 , k < j ∂ p k ∂ α j O k ( i ) ∏ r ∈ { r ∣ r ≠ j  and  r ≤ k } p r , k ≥ j ( 17 ) \frac{\partial \hat{O_{k}^{(i)}}}{\alpha_{j}^{(i)}}=\left\{\begin{array}{ll} 0 & , k<j \\ \frac{\partial p_{k}}{\partial \alpha_{j}} O_{k}^{(i)} \prod_{r \in\{r \mid r \neq j \text { and } r \leq k\}} p_{r} & , k \geq j \end{array}\right. (17) αj(i)Ok(i)^={0αjpkOk(i)r{rr=j and rk}pr,k<j,kj(17)
∂ p k ∂ α j = ( 1 − p k ) p k ( 18 ) \frac{\partial p_{k}}{\partial \alpha_{j}}=\left(1-p_{k}\right) p_{k} (18) αjpk=(1pk)pk(18)
為了進一步減小搜尋空間, 文中將通道均勻劃分成組(>10), 每個 α \alpha α引數代表一個組. 每層的組數是相同的.

2.3 網路預熱

在迭代呼叫步驟1和步驟2之前, 文中先執行幾個epoch步驟1來進行預熱, 這個過程中會通過隨機初始化的結構引數的馬爾可夫過程進行取樣得到子網路結構, 防止網路陷入區域性最優.

3 剪枝模型取樣

3.1 直接取樣DS

直接取樣時使用馬爾可夫過程中的轉移概率在網路的每一層進行獨立的取樣, 在這些取樣中只儲存滿足FLOPs約束的網路結構.

3.2 期望取樣ES

期望取樣是使用公式8計算出來的結果作為每一層需要取樣的通道數, 試驗中 l o s s r e g loss_{reg} lossreg會被優化到0, 所以FLOPs是滿足約束的.

DS得到的子網路結構的效能要略微優於ES, 但是由於該方法要測試多個子網路, 開銷比較大, 所以文中評估採用的ES方式.

相關文章