MAE自監督演算法介紹和基於EasyCV的復現

阿里雲大資料AI技術發表於2022-05-18

作者:賀弘、謙言、臨在

導言

自監督學習(Self-Supervised Learning)能利用大量無標註的資料進行表徵學習,然後在特定下游任務上對引數進行微調。透過這樣的方式,能夠在較少有標註資料上取得優於有監督學習方法的精度。近年來,自監督學習受到了越來越多的關注,如Yann Lecun也在 AAAI 上講 Self-Supervised Learning 是未來的大勢所趨。在CV領域湧現瞭如SwAV、MOCO、DINO、MoBY等一系列工作。MAE是kaiming繼MOCO之後在自監督學習領域的又一力作。首先,本文會對MAE進行解讀,然後基於EasyCV庫的精度復現過程及其中遇到的一些問題作出解答。

概述

MAE的做法很簡單:隨機mask掉圖片中的一些patch,然後透過模型去重建這些丟失的區域。包括兩個核心的設計:1.非對稱編碼-解碼結構 2.用較高的掩位元速率(75%)。透過這兩個設計MAE在預訓練過程中可以取得3倍以上的訓練速度和更高的精度,如ViT-Huge能夠透過ImageNet-1K資料上取得87.8%的準確率。

模型拆解

MAE屬於自編碼器(AutoEncoder)的一種,由編碼器和解碼器兩個部分組成。類似於常見的自編碼器,MAE會先透過編碼器將圖片patch對映到隱空間。然後,基於解碼器將隱空間上的特徵變數重構成圖片patch。和常見自編碼器的區別是非對稱的編碼解碼結構。這個非對稱性主要體現在以下兩點:

  1. 輕量化的解碼器結構

  2. 在編碼器階段,僅將未被mask掉的圖片patch作為輸入。在解碼器階段會將編碼器輸出的隱變數和mask token共同作為輸入去重建完成的圖片。

MAE自監督演算法介紹和基於EasyCV的復現

掩碼策略

首先,直接採用ViT的做法將圖片分成不重疊的patch(如vit-b會將圖片劃分成16x16的影像塊),然後透過均勻取樣策略對這些patch進行取樣,並丟棄未被選中的部分。MAE所採用的掩碼策略有如下兩個特點:

1.在演算法中,使用了75%的masking ratio來丟棄圖片patch。作者指出,透過high masking ratio可以有效減少輸入的冗餘程度,使重建任務不能夠透過簡單的參考鄰近patch來完成。文中,也透過實驗證明了這一觀點。


MAE自監督演算法介紹和基於EasyCV的復現

關於Masking ratio的實驗是MAE最精彩的一部分, 隨著mask ratio的增加,fine-tuning和linear proing的精度逐漸攀升,甚至到75%還沒有下降,這一點打破了BERT(15%)、BEiT(40%)的做法,進一步將mask 預訓練方式在NLP領域的成功在CV領域實現複製。


2.採用了均勻取樣策略可以有效的避免potential center bias(丟棄掉的patch都靠近圖片中心)。對mask策略的消去實驗如下表所示。


MAE自監督演算法介紹和基於EasyCV的復現

編碼器

MAE encoder採用的是ViT結構。在對影像patch進行取樣後,僅保留25%未被mask的影像patch作為輸入,透過linear Projection進行編碼後,加上positional embedding,然後輸入到一系列的Transformer blocks中。相比於Bert中用mask token來代替被mask區域的做法,MAE encoder直接捨棄掉了mask的部分,透過這種方式可以有效的減少預訓練過程中需要消耗的計算資源和訓練時間。

文中,作者對編碼器是否保留mask token進行了消融實驗,可以看出在編碼器階段捨棄mask token不會對預訓練模型的表徵能力造成影響,同時能夠顯著的加速訓練程式。

MAE自監督演算法介紹和基於EasyCV的復現

解碼器

MAE decoder由一連串的Transfomer block組成。和encoder不同的是,MAE decoder的輸入不僅包括未被mask的影像patch經過encoder編碼後的特徵,還包括了被mask掉的部分。對於mask掉部分的輸入,會用一個共享引數,且可學習的mask token代替作為輸入。除此之外,為了保證不同的mask token能夠區分在影像中的不同位置,在輸入到decoder之前,會對整體的輸入加上positional embedding。

在MAE中,解碼器僅會在預訓練階段用於圖片的重建工作。文中採用了輕量化的解碼器結構,對於每個token的計算量僅有相對於解碼器的10%以下。透過這種設計,就算在解碼階段用了完整數量的token作為輸入,對計算資源的消耗也不會顯著增加。

文中,作者對解碼器的depth和width兩個維度進行對比實驗,可以看出一個較輕量化的解碼器,就足以是模型學習到有效的表徵。

MAE自監督演算法介紹和基於EasyCV的復現

重建目標

MAE預訓練任務的目標是重建被mask掉的畫素值。MAE decoder輸出關於每個影像patch的表徵後,會經過一個linear projection層對映成與影像畫素數目相同維度的向量(PxPx3)。僅採用MSE作為損失函式,計算預測向量和被mask掉畫素值之前的MSE loss。

需要額外指出的是,作者使用了歸一化後的影像patch作為重建的目標。透過實驗證明,這種做法可以提升模型的表徵能力。

MAE自監督演算法介紹和基於EasyCV的復現

模型評價

文中除了從linear probing和Finetuning兩個角度對模型的表徵能力做出評價外,還採用了Partial Fine-tuning的方式進行評價,相比於linear probing這種之前普遍採用的評價方式,能夠更好的反映預訓練模型對非線性特徵的表徵能力。從下圖可以看出,MAE演算法僅僅對一個transformer block進行fintune精度就從73.5%提升到81%。同時與MOCOv3相比,MOCOv3雖然在linear probing的時候具有更高的精度,但是在partial fine-tuning時,MAE的精度都要高於MOCOv3。可以看出,MAE雖然對線性特徵的表徵能力要弱於MOCOv3,但是具有更好的非線性特徵表徵能力。

MAE自監督演算法介紹和基於EasyCV的復現

EasyCV介紹

EasyCV是阿里巴巴開源的基於Pytorch,以自監督學習和Transformer技術為核心的 all-in-one 視覺演算法建模工具。在資料層面,EasyCV提供了提供了不同資料來源(data_source)的抽象,支援多種開源資料集例如Cifar、ImageNet、CoCo等,並將各種資料預處理抽象成若干獨立的pipeline,可以透過配置檔案靈活的配置資料預處理流程。在API層面,提供了統一的訓練、評估、模型匯出、預測的API。因此,基於EasyCV,僅需要實現模型部分的程式碼,就可以很便捷的完成MAE的復現。

除此之外,EasyCV支援aliyun PAI產品中方便的進行部署(如PAI-DLC),無需多餘的修改即可在DLC上同時進行多機或者多組實驗,加快復現進度。

復現過程 & 踩坑總結

接下來我們介紹如何在EasyCV框架中進行MAE演算法的復現和踩坑總結,首先,說明一下預訓練的整體流程。

1.將輸入影像劃分成不同的patch,並將patch經過Linear Projection進行對映,再加上positional embedding得到image token

# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]

2.將image token按75%的比例進行隨機mask,透過隨機生成的張量noise進行argsort操作的方式來完成對image patch的隨機mask。其中,需要注意,該函式中額外傳回兩個引數mask和ids_restore。mask記錄了mask patch在原始圖片中的位置,用於後續損失函式的計算。ids_restore記錄了傳入encoder的image token在原始圖片中的位置,用於後續再decoder前進行unshuffle操作。

def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

3.將保留的image token輸入到encoder得到image embeding

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

4.將image embeding和mask token一起進行unshuffle操作,再加上positional embedding後,輸入到decoder中

# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
    x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
    x_,
    dim=1,
    index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1)
# add pos embed
x = x + self.decoder_pos_embed

5.將輸出的vector與歸一化後的image patch計算mse loss,並反向傳播更新梯度。在計算loss時,有兩個需要注意的點。1、首先,需要對作為target的影像patch做歸一化。2、在計算損失函式時,只對mask patch的部分計算損失函式。

    def forward_loss(self, imgs, pred, mask):
        """compute loss
        Args:
            imgs: (N, 3, H, W)
            pred: (N, L, p*p*3)
            mask: (N, L), 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5
        loss = (pred - target)**2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

精度復現

參考,我們在單機八卡V100的配置下,對ViT-base和ViT-large的在ImageNet1K上fintune的精度進行了復現。結果如下表所示。

Algorithm ImageNet1K Top-1(%) config
vit-b 400 epoch 83.13 mae_vit_base_patch16_8xb64_100e_lrdecay075_fintune
vit-b 1600 epoch 83.55 mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune
vit-l 1600 epoch 85.70 mae_vit_large_patch16_8xb16_50e_lrdecay075_fintune

下面分享一下在復現過程中遇到的一些問題和調參,如有問題請指出。

  1. 在fintune時,MAE的實現使用了mixup+cutmix的資料增廣方式,若僅使用mixup精度會下降。
  2. 在fintune時,MAE中使用了所有token特徵求平均的方式作為分類head的輸入,而cls token作為輸入時精度會有下降。
  3. 在預訓練過程中,確保使用了足夠大的weight_decay(如官方設為0.05),否則在下游任務fintune時,很容易出現梯度爆炸的問題。而在下游分類任務fintune時,設定一個較小的weight,精度會有一些提升。(PS 在復現vit-l時,在pretrain時設定weight_decay 0.01,在fintune時會出現梯度爆炸)

下表展示了vit-b模型的復現過程上述過程的精度提升

parameter setting ImageNet1K Top-1(%)
vit-b 1600 epoch(mixup,cls token) 83.21
vit-b 1600 epoch(mixup+cutmix,cls token) 83.36
vit-b 1600 epoch(mixup+cutmix,global_pool) 83.55

我們在開源框架 EasyCV中復現了MAE演算法。詳細引數配置和實驗日誌參考github上的自監督modelzoo( )。

Tutorial

接下來,我們將透過一個實際的例子介紹如何基於EasyCV進行MAE演算法的預訓練和微調,也可以在該 連結檢視詳細步驟。

一、安裝依賴包

如果是在本地開發環境執行,可以參考該 連結安裝環境。若使用PAI-DSW進行實驗則無需安裝相關依賴,在PAI-DSW docker中已內建相關環境。

二、資料準備

自監督訓練只需要提供無標註圖片即可進行, 你可以下載 ImageNet 資料,或者使用你自己的圖片資料。需要提供一個包含若干圖片的資料夾路徑p,以及一個檔案列表,檔案列表中是每個圖片相對圖片目錄p的路徑。

圖片資料夾結構示例如下, 資料夾路徑為./images

images/
├── 0001.jpg
├── 0002.jpg
├── 0003.jpg
|...
└── 9999.jpg

檔案列表內容如下:

0001.jpg
0002.jpg
0003.jpg
...
9999.jpg

為了快速走通流程,我們也提供了一個小的示例資料集,執行如下命令下載解壓:

wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/imagenet_raw_demo/imagenet_raw_demo.tar.gz
tar -zxf imagenet_raw_demo.tar.gz
mv imagenet_raw_demo  imagenet_raw

三、模型預訓練

以vit-base為示例。在EasyCV中,使用配置檔案的形式來實現對模型引數、資料輸入及增廣方式、訓練策略的配置,僅透過修改配置檔案中的引數設定,就可以完成實驗配置進行訓練。可以直接下載示例配置檔案。

rm -rf mae_vit_base_patch16_8xb64_1600e.py
wget 

檢視easycv安裝位置

# 檢視easycv安裝位置
import easycv
print(easycv.__file__)

執行訓練命令

python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \
/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mae_vit_base_patch16_8xb64_1600e.py --work_dir work_dir/selfsup/jpg/mae --launcher pytorch

四、模型微調

1、對上一步得到的預訓練模型的欄位進行修改,以便用於fintune任務。

import torch 
weight_path = 'work_dir/selfsup/jpg/mae/epoch_5.pth'
state_dict = torch.load(weight_path)['state_dict']
state_dict_out = {}
for key in state_dict:
    state_dict_out[key.replace('encoder.','')] = state_dict[key]
torch.save(state_dict_out,weight_path)

2、下載分類任務示例配置檔案

rm -rf mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py
wget 

3、執行訓練命令

python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \
/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py --work_dir work_dir/selfsup/jpg/mae --launcher pytorch

END

後續EasyCV會就SOTA論文復現進行系列的工作介紹,歡迎大家關注和使用,歡迎大家各種維度的反饋和改進建議以及技術討論,同時我們十分歡迎和期待對開源社群建設感興趣的同行一起參與共建。

專案開源地址: github.com/alibaba/Easy

來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/70004426/viewspace-2894985/,如需轉載,請註明出處,否則將追究法律責任。

相關文章