輕量級卷積神經網路的設計

CVer發表於2019-05-13

這篇文章將從一個證件檢測網路(Retinanet)的輕量化談起,簡潔地介紹,我在實操中使用到的設計原則和idea,並貼出相關的參考資料和成果供讀者參考。因此本文是一篇注重工程性、總結個人觀點的文章,存在不恰當的地方,請讀者在評論區指出,方便交流。

目前已有的輕量網路有:MobileNet V2和ShuffleNet v2為代表。在實際業務中,Retinanet僅需要檢測證件,不涉及過多的類別物體的定位和分類,因此,我認為僅僅更換上述兩個骨架網路來優化模型的效能是不夠的,需要針對證件檢測任務,專門設計一個更加輕量的卷積神經網路來提取、糅合特徵。

設計原則:

1. 更多的資料

輕量的淺層網路特徵提取能力不如深度網路,訓練也更需要技巧。假設保證有足夠多的訓練的資料,輕量網路訓練會更加容易。

Facebook研究院的一篇論文[1]提出了“資料蒸餾”的方法。實際上,標註資料相對未知資料較少,我使用已經訓練好、效果達標的base resnet50的retinanet來進行自動標註,得到一批10萬張機器標註的資料。這為後來的輕量網路設計奠定了資料基礎。我認為這是構建一個輕量網路必要的條件之一,網路結構的有效性驗證離不開大量的實驗結果來評估。

接下來,這一部分我將簡潔地介紹輕量CNN地設計的四個原則

2. 卷積層的輸入、輸出channels數目相同時,計算需要的MAC(memory access cost)最少

輕量級卷積神經網路的設計

3. 過多的分組卷積會增加MAC

對於1x1的分組卷積(例如:MobileNetv2的深度可分離卷積採用了分組卷積),其MAC和FLOPS的關係為:

輕量級卷積神經網路的設計

g代表分組卷積數量,很明顯g越大,MAC越大。詳細參考[2]

4. 網路結構的碎片化會減少可平行計算

這些碎片化更多是指網路中的多路徑連線,類似於short-cut,bottle neck等不同層特徵融合,還有如FPN。拖慢並行的一個很主要因素是,運算快的模組總是要等待運算慢的模組執行完畢。

輕量級卷積神經網路的設計


5. Element-wise操作會消耗較多的時間(也就是逐元素操作)

從表中第一行資料看出,當移除了ReLU和short-cut,大約提升了20%的速度。

輕量級卷積神經網路的設計

以上是從此篇論文[2]中轉譯過來的設計原則,在實操中,這四條原則需要靈活使用。

根據以上幾個原則進行網路的設計,可以將模型的引數量、訪存量降低很大一部分。

接下來介紹一些自己總結的經驗。

6. 網路的層數不宜過多

通常18層的網路屬於深層網路,在設計時,應選擇一個參考網路基線,我選擇的是resnet18。由於Retinanet使用了FPN特徵金字塔網路來融合各個不同尺度範圍的特徵,因此Retinanet仍然很“重”,需要儘可能壓縮骨架網路的冗餘,減少深度。

7. 首層卷積層用空洞卷積和深度可分離卷積替換

一個3x3,d=2的空洞卷積在感受野上,可以看作等效於5x5的卷積,提供比普通3x3的卷積更大的感受野,這在網路的淺層設計使用它有益。計算出網路各個層佔有的MAC和引數量,將引數量和計算量“重”的卷積層替換成深度可分離卷積層,可以降低模型的引數量。

這裡提供一個計算pytorch 模型的MAC和FLOPs的python packages[3]

if __name__ == "__main__":
    from ptflops import get_model_complexity_info

    net = SNet(num_classes=1)
    x = torch.Tensor(1, 3, 224, 224)

    net.eval()

    if torch.cuda.is_available():
        net = net.cuda()
        x = x.cuda()

    with torch.cuda.device(0):
        flops, params = get_model_complexity_info(net, (224, 224), print_per_layer_stat=True, as_strings=True, is_cuda=True)
        print("FLOPS:", flops)
        print("PARAMS:", params)

output:

(regressionModel): RegressionModel(
    0.045 GMac, 27.305% MACs,
    (conv1): Conv2d(0.009 GMac, 5.257% MACs, 128, 256, kernel_size=(1, 1), stride=(1, 1))
    (act1): ReLU(0.0 GMac, 0.041% MACs, )
    (conv2): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))
    (act2): ReLU(0.0 GMac, 0.041% MACs, )
    (conv3): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))
    (act3): ReLU(0.0 GMac, 0.041% MACs, )
    (output): Conv2d(0.002 GMac, 0.982% MACs, 256, 24, kernel_size=(1, 1), stride=(1, 1))
  )
  (classificationModel): ClassificationModel(
    0.044 GMac, 26.569% MACs,
    (conv1): Conv2d(0.009 GMac, 5.257% MACs, 128, 256, kernel_size=(1, 1), stride=(1, 1))
    (act1): ReLU(0.0 GMac, 0.041% MACs, )
    (conv2): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))
    (act2): ReLU(0.0 GMac, 0.041% MACs, )
    (conv3): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))
    (act3): ReLU(0.0 GMac, 0.041% MACs, )
    (output): Conv2d(0.0 GMac, 0.245% MACs, 256, 6, kernel_size=(1, 1), stride=(1, 1))
    (output_act): Sigmoid(0.0 GMac, 0.000% MACs, )
  )

8. Group Normalization 替換 Batch Normalization

BN在諸多論文中已經被證明了一些缺陷,而訓練目標檢測網路耗費視訊記憶體,開銷巨大,通常凍結BN來訓練,原因是小批次會讓BN失效,影響訓練的穩定性。建議一個BN的替代--GN,pytorch 0.4.1內建了GN的支援。

9. 減少不必要的shortcut連線和RELU層

網路不夠深,沒有必要使用shortcut連線,不必要的shortcut會增加計算量。RELU與shortcut一樣都會增加計算量。同樣RELU沒有必要每一個卷積後連線(需要實際訓練考慮刪減RELU)。

10. 善用1x1卷積

1x1卷積可以改變通道數,而不改變特徵圖的空間解析度,引數量低,計算效率也高。如使用kernel size=3,stride=1,padding=1,可以保證特徵圖的空間解析度不變,1x1的卷積設定stride=1,padding=0達到相同的目的,而且1x1卷積運算的效率目前有很多底層演算法支援,效率更高。[5x1] x [1x5] 兩個卷積可以替換5x5卷積,同樣可以減少模型引數。

11. 降低通道數

降低通道數可以減少特徵圖的輸出大小,視訊記憶體佔用量下降明顯。參考原則2

12. 設計一個新的骨架網路找對參考網路

一個好的骨架網路需要大量的實驗來支撐它的驗證,因此在工程上,參考一些實時網路結構設計自己的骨架網路,事半功倍。我在實踐中,參考了這篇[4]paper的骨架來設計自己的輕量網路。

總結

我根據以上的原則和經驗對Retinanet進行瘦身,不僅侷限於骨架的新設計,FPN支路瘦身,兩個子網路(迴歸網路和分類網路)均進行了修改,期望效能指標FPS提升到63,增幅180%。

FPS

輕量級卷積神經網路的設計

mAP

輕量級卷積神經網路的設計

Model size

輕量級卷積神經網路的設計

注:本文中部分觀點參考來源

1 https://towardsdatascience.com/types-of-convolutions-in-deep-learning-717013397f4d

2 The Receptive Field, the Effective RF, and how it's hurting your results

https://www.linkedin.com/pulse/receptive-field-effective-rf-how-its-hurting-your-rosenberg/

3 https://arxiv.org/abs/1807.11164

4 mp.weixin.qq.com/s?

參考

  1. Data Distillation Towards Omni-Supervised Learning

  2. ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design

  3. https://github.com/zhouyuangan/flops-counter.pytorch

  4. ThunderNet: TowardsReal-timeGenericObjectDetection

相關文章