【手搓模型】親手實現 Vision Transformer

睡晚不猿序程發表於2023-03-17

?前言

  • ?部落格主頁:?睡晚不猿序程?
  • ⌚首發時間:2023.3.17,首發於部落格園
  • ⏰最近更新時間:2023.3.17
  • ?本文由 睡晚不猿序程 原創
  • ?作者是蒻蒟本蒟,如果文章裡有任何錯誤或者表述不清,請 tt 我,萬分感謝!orz

相關文章目錄 :無


目錄

1. 內容簡介

最近在準備使用 Transformer 系列作為 backbone 完成自己的任務,感覺自己打程式碼的次數也比較少,正好直接用別人寫的程式碼進行訓練的同時,自己看著 ViT 的論文以及別人實現的程式碼自己實現一下 ViT

感覺 ViT 相對來說實現還是比較簡單的,也算是對自己程式碼能力的一次練習吧,好的,我們接下來開始手撕 ViT


2. Vision Transformer 總覽

ViT

我這裡預設大家都理解了 Transformer 的構造了!如果有需要我可以再發一下 Transformer 相關的內容

ViT 的總體架構和 Transformer 一致,因為它的目標就是希望保證 Transformer 的總體架構不變,並將其應用到 CV 任務中,它可以分為以下幾個部分:

  1. 預處理

    包括以下幾個步驟:

    1. 劃分 patch
    2. 線性嵌入
    3. 新增 CLS Token
    4. 新增位置編碼
  2. 使用 Transformer Block 進行處理

  3. MLP 分類頭基於 CLS Token 進行分類

上面講述的是大框架,接下來我們深入 ViT 的Transformer Block 去看一下和原本的 Transformer 有什麼區別

Transformer Block

image-20230316223251693

和 Transformer 基本一致,但是使用的是 Pre-Norm,也就是先進行 LayerNorm 然後再做自注意力/MLP,而 Transformer 選擇的是 Pose-Norm,也就是先做自注意力/MLP 然後再做 LayerNorm

Pre-Norm 和 Pose-Norm 各有優劣:

  • Pre-Norm 可以不使用 warmup,訓練更簡單
  • Pose-Norm 必須使用 warmup 以及其他技術,訓練較難,但是完成預訓練後泛化能力更好

ViT 選擇了 Pre-Norm,所以訓練更為簡單

3. 手撕 Transformer

接下來我們一部分一部分的來構建 ViT,由一個個元件最後拼合成 ViT

3.1 預處理部分

image-20230316225110799

這一部分我們將會構建:

  1. 劃分 patch
  2. 線性嵌入
  3. 插入 CLS Token
  4. 嵌入位置編碼資訊

我們先把整個部分的程式碼放在這裡,之後我們再詳細講解

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class pre_proces(nn.Module):
    def __init__(self, image_size, patch_size, patch_dim, dim):
        super().__init__()
        self.patch_size = patch_size
        self.dim = dim
        self.patch_num = (image_size//patch_size)**2
        self.linear_embedding = nn.Linear(patch_dim, dim)
        self.position_embedding = nn.Parameter(torch.randn(1, self.patch_num+1, self.dim))  # 使用廣播
        self.CLS_token = nn.Parameter(torch.randn(1, 1, self.dim))  # 別忘了維度要和 (B,L,C) 對齊

    def forward(self, x):
        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)  # (B,L,C)
        x = self.linear_embedding(x)
        b, l, c = x.shape   # 獲取 token 的形狀 (B,L,c)
        CLS_token = repeat(self.CLS_token, '1 1 d -> b 1 d', b=b)  # 位置編碼複製 B 份
        x = torch.concat((CLS_token, x), dim=1)
        x = x+self.position_embedding
        return x

可以先大概瀏覽一下,也不是很難看懂啦!

3.1.1 patch 劃分

x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)  # (B,L,C)

我們直接使用 einops 庫中的 rearrange 函式來劃分 patch,我們輸入的 x 的陣列表示為 (B,C,H,W),我們要把它劃分成 (B,L,C),其中 \(L=\frac W{W_p}\times \frac H{H_p}\),也就是 patch 的個數,最後 \(C=W_p\times H_p\times channels\)

這個函式就把原先的 (B,C,H,W) 表示方式拆開了,很輕易的就能夠做到我們想要的 patch 劃分,注意 h 和 p1 和 p2 的順序不能亂

3.1.2 線性嵌入

首先我們要先定義一個全連線層

self.linear_embedding = nn.Linear(patch_dim, dim)

使用這個函式將 patch 對映到 Transformer 處理的維度

x = self.linear_embedding(x)

接著使用這個函式來執行線性嵌入,將其對映到維度 dim

3.1.3 插入 CLS Token

CLS Token 是最後分類頭處理的依據,這個思想好像是來源於 BERT,可以看作是一種 池化 方式,CLS Token 在 Transformer 中會和其他元素進行互動,最後的輸出時可以認為它擁有了所有 patch 資訊,如果不使用 CLS Token 也可以選擇平均池化等方式來進行分類

首先我們要定義 CLS Token,他是一個可學習的向量,所以需要註冊為 nn.Parameter ,其維度和 Transformer 處理維度一致,以便於後面進行級聯

self.CLS_token = nn.Parameter(torch.randn(1, 1, self.dim))  # 別忘了維度要和 (B,L,C) 對齊

我們得到了一個大小為 (1,1,dim) 的向量,但是我們的輸入的是一個 batch,所以我們要對他進行復制,我們可以使用 einops 庫中的 repeat 函式來進行復制,然後再進行級聯

CLS_token = repeat(self.CLS_token, '1 1 d -> b 1 d', b=b)  # 位置編碼複製 B 份
x = torch.concat((CLS_token, x), dim=1)

其中 b 是 batch 大小

可以發現 einops 庫可以很方便的進行矩陣的重排

3.1.4 嵌入位置資訊

ViT 使用可學習的位置編碼,而 Transformer 使用的是 sin/cos 函式進行編碼,使用可學習位置編碼顯然更為方便

self.position_embedding = nn.Parameter(torch.randn(1, self.patch_num+1, self.dim))  # 使用廣播

可學習的引數一定要註冊為 nn.Parameter

向量的個數為 patch 的個數+1,因為因為在頭部還加上了一個 CLS Token 呢,最後使用加法進行位置嵌入

x = x+self.position_embedding

好了每個模組都講解完成,我們將他拼合

class pre_proces(nn.Module):
    def __init__(self, image_size, patch_size, patch_dim, dim):
        super().__init__()
        self.patch_size = patch_size	# patch 的大小
        self.dim = dim	# Transformer 使用的維度,Transformer 的特性是輸入輸出大小不變
        self.patch_num = (image_size//patch_size)**2	# patch 的個數
        self.linear_embedding = nn.Linear(patch_dim, dim)	# 線性嵌入層
        self.position_embedding = nn.Parameter(torch.randn(1, self.patch_num+1, self.dim))  # 使用廣播
        self.CLS_token = nn.Parameter(torch.randn(1, 1, self.dim))  # 別忘了維度要和 (B,L,C) 對齊

    def forward(self, x):
        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)  # (B,L,C)
        x = self.linear_embedding(x)	# 線性嵌入
        b, l, c = x.shape   # 獲取 token 的形狀 (B,L,c)
        CLS_token = repeat(self.CLS_token, '1 1 d -> b 1 d', b=b)  # 位置編碼複製 B 份
        x = torch.concat((CLS_token, x), dim=1)	# 級聯 CLS Token
        x = x+self.position_embedding	# 位置嵌入
        return x

3.2 Transformer

image-20230316232801655

這一部分將會是我們的重點,建議大家手推一下自注意力計算,不然可能會有點難理解

3.2.1 多頭自注意力

首先來回憶一下自注意力公式:

\[Output=softmax(\frac{QK^T}{\sqrt{D_k}})V \]

輸入透過 \(W_q,W_k,W_v\) 對映為 QKV,然後經過上述計算得到輸出,多頭注意力就是使用多個對映權重進行對映,然後最後拼接成為一個大的矩陣,再使用一個對映矩陣對映為輸出函式

還是一樣,我們先把整個程式碼放上來,我們接著在逐行講解

class Multihead_self_attention(nn.Module):
    def __init__(self, heads, head_dim, dim):
        super().__init__()
        self.head_dim = head_dim    # 每一個注意力頭的維度
        self.heads = heads  # 注意力頭個數
        self.inner_dim = self.heads*self.head_dim  # 多頭自注意力最後的輸出維度
        self.scale = self.head_dim**-0.5   # 正則化係數
        self.to_qkv = nn.Linear(dim, self.inner_dim*3)  # 生成 qkv,每一個矩陣的維度和由自注意力頭的維度以及頭的個數決定
        self.to_output = nn.Linear(self.inner_dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.norm(x)    # PreNorm
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 劃分 QKV,返回一個列表,其中就包含了 QKV
        Q, K, V = map(lambda t: rearrange(t, 'b l (h dim) -> b h l dim', dim=self.head_dim), qkv)
        K_T = K.transpose(-1, -2)
        att_score = Q@K_T*self.scale
        att = self.softmax(att_score)
        out = att@V   # (B,H,L,dim)
        out = rearrange(out, 'b h l dim -> b l (h dim)')  # 拼接
        output = self.to_output(out)
        return output

我們先用圖來表示一下多頭自注意力,也就是用多個不同的權重來對映,然後再計算自注意力,這樣就得到了多組的輸出,最後再進行拼接,使用一個大的矩陣來把多頭自注意力輸出對映回輸入大小

image-20230317203621599

我們如何構造這多個權重矩陣來進行矩陣運算更快呢?答案是——寫成一個線性對映,然後再透過矩陣重排來得到多組 QKV,然後計算自注意力,我們來看圖:

image-20230317220427599

首先輸入是一個 (N,dim) 的張量,我們可以把多頭的對映橫著排列變成一個大矩陣,這樣使用一次矩陣運算就可以得到多個輸出

我這裡假設了四個頭,並且每一個頭的維度是 2

經過對映,我們得到了一個 \((N,heads\times head\_dim)\) 大小的張量,這時候我們對其重新排列,形成 \((heads,N,head\_dim)\) 大小的張量,這樣就把每一個頭給分離出來了

接著就是做自注意力,我們現在的張量當作 Q,K 就需要進行轉置,其張量大小是 \((heads,head\_dim,N)\) ,二者進行相乘,得到的輸出為 \((heads,N,N)\),這就是我們的注意力得分,經過 softmax 就可以和 V 相乘了

這裡省略了 softmax,重點看矩陣的維度變化

image-20230317221013326

計算自注意力輸出,就是和 V 相乘,V 的張量大小為 \((heads,N,head\_dim)\) ,最後得到輸出大小為 \((heads,N,head\_dim)\)

image-20230317221405554

我們把上一步的張量 \((heads,N,head\_dim)\) 重排為 \((N,heads\times head\_dim)\),然後使用一個大小為 \((heads\times\_dim,dim)\) 的矩陣對映回和輸入相同的大小,這樣多頭自注意力就計算完成了

大家可以像我一樣把過程給寫出來,可以清晰非常多,接下來我們再看一下程式碼實現:

首先定義我們需要的對映矩陣以及 softmax 函式以及 layernorm 函式

self.head_dim = head_dim    # 每一個注意力頭的維度
self.heads = heads  # 注意力頭個數
self.inner_dim = self.heads*self.head_dim  # 多頭自注意力輸出級聯後的輸出維度
self.scale = self.head_dim**-0.5   # 正則化係數
self.to_qkv = nn.Linear(dim, self.inner_dim*3)  # 生成 qkv,每一個矩陣的維度由自注意力頭的維度以及頭的個數決定
self.to_output = nn.Linear(self.inner_dim, dim)	# 輸出對映矩陣
self.norm = nn.LayerNorm(dim)	# layerNorm
self.softmax = nn.Softmax(dim=-1)	# softmax

有了這些,我們可以開始 MHSA 的計算

    def forward(self, x):
        x = self.norm(x)    # PreNorm
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 按照最後一個維度均分為三分,也就是劃分 QKV,返回一個列表,其中就包含了 QKV
        Q, K, V = map(lambda t: rearrange(t, 'b l (h dim) -> b h l dim', dim=self.head_dim), qkv)	# 對 QKV 的多頭對映進行拆分,得到(B,head,L,head_dim)
        K_T = K.transpose(-1, -2)	# K 進行轉置,用於計算自注意力
        att_score = Q@K_T*self.scale	# 計算自注意力得分
        att = self.softmax(att_score)	# softmax
        out = att@V   # (B,H,L,dim); 自注意力輸出
        out = rearrange(out, 'b h l dim -> b l (h dim)')  # 拼接
        output = self.to_output(out)	#輸出對映
        return output

上面的部分進行組合

class Multihead_self_attention(nn.Module):
    def __init__(self, heads, head_dim, dim):
        super().__init__()
        self.head_dim = head_dim    # 每一個注意力頭的維度
        self.heads = heads  # 注意力頭個數
        self.inner_dim = self.heads*self.head_dim  # 多頭自注意力最後的輸出維度
        self.scale = self.head_dim**-0.5   # 正則化係數
        self.to_qkv = nn.Linear(dim, self.inner_dim*3)  # 生成 qkv,每一個矩陣的維度和由自注意力頭的維度以及頭的個數決定
        self.to_output = nn.Linear(self.inner_dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.norm(x)    # PreNorm
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 劃分 QKV,返回一個列表,其中就包含了 QKV
        Q, K, V = map(lambda t: rearrange(t, 'b l (h dim) -> b h l dim', dim=self.head_dim), qkv)
        K_T = K.transpose(-1, -2)
        att_score = Q@K_T*self.scale
        att = self.softmax(att_score)
        out = att@V   # (B,H,L,dim)
        out = rearrange(out, 'b h l dim -> b l (h dim)')  # 拼接
        output = self.to_output(out)
        return output

3.2.2 FeedForward

構建後面的 FeedForward 模組,這個模組就是一個 MLP,中間夾著非線性啟用,所以我們直接看程式碼吧

class FeedForward(nn.Module):
    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return x

3.2.3 Transformer Block

有了 MHSA 以及 FeedForward,我們可以來構建 Transformer Block,這是 Transformer 的基本單元,只需要把我們構建的模組進行組裝,然後新增殘差連線即可,不會很難

class Transformer_block(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim):
        super().__init__()
        self.MHA = Multihead_self_attention(heads=heads, head_dim=head_dim, dim=dim)
        self.FeedForward = FeedForward(dim=dim, mlp_dim=mlp_dim)

    def forward(self, x):
        x = self.MHA(x)+x
        x = self.FeedForward(x)+x
        return x

新增了一個引數 depth ,用來定義 Transformer 的層數

ViT

祝賀大家,走到最後一步啦!我們把上面的東西組裝起來,構建 ViT 吧

class ViT(nn.Module):
    def __init__(self, image_size, channels, patch_size, dim, heads, head_dim, mlp_dim, depth, num_class):
        super().__init__()
        self.to_patch_embedding = pre_proces(image_size=image_size, patch_size=patch_size, patch_dim=channels*patch_size**2, dim=dim)
        self.transformer = Transformer(dim=dim, heads=heads, head_dim=head_dim, mlp_dim=mlp_dim, depth=depth)
        self.MLP_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_class)
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        token = self.to_patch_embedding(x)
        output = self.transformer(token)
        CLS_token = output[:, 0, :]	# 提取出 CLS Token
        out = self.softmax(self.MLP_head(CLS_token))
        return out

總結

這裡我們手動實現了 ViT 的構建,不知道大家有沒有對 Transformer 的架構有更深入的理解呢?我也是動手實現了才理解其各種細節,剛開始覺得自己不可能實現,但是最後還是成功的,感覺好開心:D

參考

[1] lucidrains/vit-pytorch

[2] 全網最強ViT (Vision Transformer)原理及程式碼解析

[3] Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020.

相關文章