從零搭建Pytorch模型教程(三)搭建Transformer網路

CV技術指南(公眾號)發表於2022-04-15


前言 本文介紹了Transformer的基本流程,分塊的兩種實現方式,Position Emebdding的幾種實現方式,Encoder的實現方式,最後分類的兩種方式,以及最重要的資料格式的介紹。

 

本文來自公眾號CV技術指南的技術總結系列

歡迎關注公眾號CV技術指南,專注於計算機視覺的技術總結、最新技術跟蹤、經典論文解讀、CV招聘資訊。

 

在講如何搭建之前,先回顧一下Transformer在計算機視覺中的結構是怎樣的。這裡以最典型的ViT為例。

從零搭建Pytorch模型教程(三)搭建Transformer網路

如圖所示,對於一張影像,先將其分割成NxN個patches,把patches進行Flatten,再通過一個全連線層對映成tokens,對每一個tokens加入位置編碼(position embedding),會隨機初始化一個tokens,concate到通過影像生成的tokens後,再經過transformer的Encoder模組,經過多層Encoder後,取出最後的tokens(即隨機初始化的tokens),再通過全連線層作為分類網路進行分類。

下面我們就根據這個流程來一步一步介紹如何搭建一個Transformer模型。、

 

分塊


目前有兩種方式實現分塊,一種是直接分割,一種是通過卷積核和步長都為patch大小的卷積來分割。

直接分割

直接分割即把影像直接分成多塊。在程式碼實現上需要使用einops這個庫,完成的操作是將(B,C,H,W)的shape調整為(B,(H/P *W/P),P*P*C)。

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

self.to_patch_embedding = nn.Sequential(
           Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
           nn.Linear(patch_dim, dim),
      )
從零搭建Pytorch模型教程(三)搭建Transformer網路

這裡簡單介紹一下Rearrange。

Rearrange用於對張量的維度進行重新變換排序,可用於替換pytorch中的reshape,view,transpose和permute等操作。舉幾個例子

#假設images的shape為[32,200,400,3]
#實現view和reshape的功能
Rearrange(images,'b h w c -> (b h) w c')#shape變為(32*200, 400, 3)
#實現permute的功能
Rearrange(images, 'b h w c -> b c h w')#shape變為(32, 3, 200, 400)
#實現這幾個都很難實現的功能
Rearrange(images, 'b h w c -> (b c w) h')#shape變為(32*3*400, 200)
從零搭建Pytorch模型教程(三)搭建Transformer網路

從這幾個例子看可以看出,Rearrange非常簡單好用,這裡的b, c, h, w都可以理解為表示符號,用來表示操作變化。通過這幾個例子似乎也能理解下面這行程式碼是如何將影像分割的。

Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
從零搭建Pytorch模型教程(三)搭建Transformer網路

這裡需要解釋的是,一個括號內的兩個變數相乘表示的是該維度的長度,因此不要把"h"和"w"理解成影像的寬和高。這裡實際上h = H/p1, w = W/p2,代表的是高度上有幾塊,寬度上有幾塊。h和w都不需要賦值,程式碼會自動根據這個表示式計算,b和c也會自動對應到輸入資料的B和C。

後面的"b (h w) (p1 p2 c)"表示了影像分塊後的shape: (B,(H/P *W/P),P*P*C)

這種方式在分塊後還需要通過一層全連線層將分塊的向量對映為tokens。

在ViT中使用的就是這種直接分塊方式。

 

卷積分割

卷積分割比較容易理解,使用卷積核和步長都為patch大小的卷積對影像卷積一次就可以了。

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
從零搭建Pytorch模型教程(三)搭建Transformer網路

在swin transformer中即使用的是這種卷積分塊方式。在swin transformer中卷積後沒有再加全連線層。

 

Position Embedding


Position Embedding可以分為absolute position embedding和relative position embedding。

在學習最初的transformer時,可能會注意到用的是正餘弦編碼的方式,但這隻適用於語音、文字等1維資料,影像是高度結構化的資料,用正餘弦不合適

在ViT和swin transformer中都是直接隨機初始化一組與tokens同shape的可學習引數,與tokens相加,即完成了absolute position embedding。

在ViT中實現方式:

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x += self.pos_embedding[:, :(n + 1)]
#之所以是n+1,是因為ViT中選擇隨機初始化一個class token,與分塊得到的tokens拼接。所以patches的數量為num_patches+1。
從零搭建Pytorch模型教程(三)搭建Transformer網路

在swin transformer中的實現方式:

from timm.models.layers import trunc_normal_
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
從零搭建Pytorch模型教程(三)搭建Transformer網路

在TimeSformer中的實現方式:

self.pos_emb = torch.nn.Embedding(num_positions + 1, dim)
從零搭建Pytorch模型教程(三)搭建Transformer網路

以上就是簡單的使用方法,這種方法屬於absolute position embedding。

還有更復雜一點的方法,以後有機會單獨搞一篇文章來介紹。

感興趣的讀者可以先去看看這篇論文《ICCV2021 | Vision Transformer中相對位置編碼的反思與改進》。

 

Encoder


Encoder由Multi-head Self-attention和FeedForward組成。

Multi-head Self-attention

Multi-head Self-attention主要是先把tokens分成q、k、v,再計算q和k的點積,經過softmax後獲得加權值,給v加權,再經過全連線層。

用公式表示如下:

 

所謂Multi-head是指把q、k、v再dim維度上分成head份,公式裡的dk為每個head的維度。

具體程式碼如下:

class Attention(nn.Module):
   def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
       super().__init__()
       inner_dim = dim_head *  heads
       project_out = not (heads == 1 and dim_head == dim)

       self.heads = heads
       self.scale = dim_head ** -0.5
       self.attend = nn.Softmax(dim = -1)
       self.dropout = nn.Dropout(dropout)

       self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
       self.to_out = nn.Sequential(
           nn.Linear(inner_dim, dim),
           nn.Dropout(dropout)
      ) if project_out else nn.Identity()

   def forward(self, x):
       qkv = self.to_qkv(x).chunk(3, dim = -1)
       q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
       dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
       attn = self.attend(dots)
       attn = self.dropout(attn)

       out = torch.matmul(attn, v)
       out = rearrange(out, 'b h n d -> b n (h d)')
       return self.to_out(out)
從零搭建Pytorch模型教程(三)搭建Transformer網路

這裡沒有太多可以解釋的地方,介紹一下q、k、v的來源,由於這是self-attention,因此q=k=v(即tokens),若是普通attention,則k= v,而q是其它的東西,例如可以是另一個尺度的tokens,或視訊領域中的其它幀的tokens。

 

FeedForward

這裡不用多介紹。

class FeedForward(nn.Module):
   def __init__(self, dim, hidden_dim, dropout = 0.):
       super().__init__()
       self.net = nn.Sequential(
           nn.Linear(dim, hidden_dim),
           nn.GELU(),
           nn.Dropout(dropout),
           nn.Linear(hidden_dim, dim),
           nn.Dropout(dropout)
      )
   def forward(self, x):
       return self.net(x)
從零搭建Pytorch模型教程(三)搭建Transformer網路

把上面兩者組合起來就是Encoder了。

class Transformer(nn.Module):
   def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
       super().__init__()
       self.layers = nn.ModuleList([])
       for _ in range(depth):
           self.layers.append(nn.ModuleList([
               PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
               PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
          ]))
   def forward(self, x):
       for attn, ff in self.layers:
           x = attn(x) + x
           x = ff(x) + x
       return x
從零搭建Pytorch模型教程(三)搭建Transformer網路

depth指的是Encoder的數量。PreNorm指的是層歸一化。

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
從零搭建Pytorch模型教程(三)搭建Transformer網路

 

分類方法


資料通過Encoder後獲得最後的預測向量的方法有兩種典型。在ViT中是隨機初始化一個cls_token,concate到分塊後的token後,經過Encoder後取出cls_token,最後將cls_token通過全連線層對映到最後的預測維度。

#生成cls_token部分
from einops import repeat
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
################################
#分類部分
self.mlp_head = nn.Sequential(
           nn.LayerNorm(dim),
           nn.Linear(dim, num_classes)
      )
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

x = self.to_latent(x)
return self.mlp_head(x)
從零搭建Pytorch模型教程(三)搭建Transformer網路

在swin transformer中,沒有選擇cls_token。而是直接在經過Encoder後將所有資料取了個平均池化,再通過全連線層。

self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

x = self.avgpool(x.transpose(1, 2))  # B C 1
x = torch.flatten(x, 1)
x = self.head(x)
從零搭建Pytorch模型教程(三)搭建Transformer網路

組合以上這些就成了一個完整的模型

class ViT(nn.Module):
   def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
       super().__init__()
       image_height, image_width = pair(image_size)
       patch_height, patch_width = pair(patch_size)

       num_patches = (image_height // patch_height) * (image_width // patch_width)
       patch_dim = channels * patch_height * patch_width
       assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

       self.to_patch_embedding = nn.Sequential(
           Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
           nn.Linear(patch_dim, dim),
      )

       self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
       self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
       self.dropout = nn.Dropout(emb_dropout)
       self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

       self.pool = pool
       self.to_latent = nn.Identity()
       self.mlp_head = nn.Sequential(
           nn.LayerNorm(dim),
           nn.Linear(dim, num_classes)
      )

   def forward(self, img):
       x = self.to_patch_embedding(img)
       b, n, _ = x.shape

       cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
       x = torch.cat((cls_tokens, x), dim=1)
       x += self.pos_embedding[:, :(n + 1)]
       x = self.dropout(x)
       x = self.transformer(x)
       x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

       x = self.to_latent(x)
       return self.mlp_head(x)
從零搭建Pytorch模型教程(三)搭建Transformer網路

 

資料的變換


以上的程式碼都是比較簡單的,整體上最麻煩的地方在於理解資料的變換。

首先輸入的資料為(B, C, H, W),在經過分塊後,變成了(B, n, d)。

在CNN模型中,很好理解(H,W)就是feature map,C是指feature map的數量,那這裡的n,d哪個是通道,哪個是影像特徵?

回顧一下分塊的部分

Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
從零搭建Pytorch模型教程(三)搭建Transformer網路

根據這個可以知道n為分塊的數量,d為每一塊的內容。因此,這裡的n相當於CNN模型中的C,而d相當於features。

一般情況下,在Encoder中,我們都是以(B, n, d)的形式。

在swin transformer中這種以卷積的形式分塊,獲得的形式為(B, C, L),然後做了一個transpose得到(B, L, C),這與ViT通過直接分塊方式獲得的形式實際上完全一樣,在Swin transformer中的L即為ViT中的n,而C為ViT中的d。

因此,要注意的是在Multi-head self-attention中,資料的形式是(Batchsize, Channel, Features),分成多個head的是Features。

前面提到,在ViT中會concate一個隨機生成的cls_token,該cls_token的維度即為(B, 1, d)。可以理解為通道數多了個1。

 

以上就是Transformer的模型搭建細節了,整體上比較簡單,大家看完這篇文章後可以找幾篇Transformer的程式碼來理解理解。如ViT, swin transformer, TimeSformer等。

ViT:https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
swin: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
TimeSformer:https://github.com/lucidrains/TimeSformer-pytorch/blob/main/timesformer_pytorch/timesformer_pytorch.py
從零搭建Pytorch模型教程(三)搭建Transformer網路

下一篇我們將介紹如何寫train函式,以及包括設定優化方式,設定學習率,不同層設定不同學習率,解析引數等。

 

歡迎關注公眾號CV技術指南,專注於計算機視覺的技術總結、最新技術跟蹤、經典論文解讀、CV招聘資訊。

CV技術指南建立了一個交流氛圍很不錯的群,除了太偏僻的問題,幾乎有問必答。關注公眾號新增編輯的微訊號可邀請加交流群。

從零搭建Pytorch模型教程(三)搭建Transformer網路​​

 

其它文章

從零搭建Pytorch模型教程(二)搭建網路

從零搭建Pytorch模型教程(一)資料讀取

YOLO系列梳理與複習(二)YOLOv4

YOLO系列梳理(一)YOLOv1-YOLOv3

StyleGAN大彙總 | 全面瞭解SOTA方法、架構新進展

一份熱力圖視覺化程式碼使用教程

一份視覺化特徵圖的程式碼

工業影像異常檢測研究總結(2019-2020)

小樣本學習研究綜述(中科院計算所)

目標檢測中正負樣本區分策略和平衡策略總結

目標檢測中的框位置優化總結

目標檢測、例項分割、多目標跟蹤的Anchor-free應用方法總結

Soft Sampling:探索更有效的取樣策略

如何解決工業缺陷檢測小樣本問題

關於快速學習一項新技術或新領域的一些個人思維習慣與思想總結

相關文章