Swin Transformer
- paper: https://arxiv.org/abs/2103.14030 (ICCV 2021)
- code:https://github.com/microsoft/Swin-Transformer/blob/2622619f70760b60a42b996f5fcbe7c9d2e7ca57/models/swin_transformer.py#L458
- 學習連結:
- https://blog.csdn.net/qq_37541097/article/details/121119988
- https://zhuanlan.zhihu.com/p/626820422 (Multi-Head-Attention的作用到底是什麼?)
- Patch Partition
對圖片進行分塊,相鄰的4x4的畫素為一個Patch,然後在每個Patch中,把每個畫素在通道方向展平,堆疊到一起。特徵圖形狀從[H, W, 3]變成了[H/4, W/4, 48]。
- Linear Embedding
對每個畫素的通道資料進行線性變換。特徵圖形狀從[H/4, W/4, 48]變成了 [H/4, W/4, C]。
-
Swin Transformer Block
-
Windows Multi-head Self-Attention(W-MSA)
為了減少計算量,對特徵圖按照MXM大小劃分成一個個window,單獨對每個windo內部進行self-attention。
-
Shifted Windows Multi-Head Self-Attention(SW-MSA)
W-MSA無法在window與window之間進行資訊傳遞,為了解決這個問題,SW-MSA對偏移的windows內部在進行self-attention。這裡用到了masked MSA來防止不同windows中的資訊亂竄。
-
-
Patch Merging
對特徵圖進行下采樣,H和W都縮小2倍,C增加2倍。Patch Merging會將每個2x2的相鄰畫素劃分為一個patch,然後將每個patch中相同位置的畫素給拼在一起就得到了4個feature map。接著將這四個feature map在深度方向進行concat拼接,然後在透過一個LayerNorm層。最後透過一個全連線層在feature map的深度方向做線性變化,將feature map的深度由C變成C/2。
- Relative Position Bias
公式中的B就是就是Relative Position Bias,論文中的消融實驗驗證了其能帶來明顯的提升。
MSwin
- paper:https://arxiv.org/abs/2203.10638 (ECCV 2022)
- code:https://github.com/DerrickXuNu/v2x-vit/blob/main/v2xvit/models/sub_modules/mswin.py
- MSwin把Swin的序列結構改成了並行,最後用了一個Split-Attention融合了所有分支的特徵
- MSwin論文中指出不需要用SW-MSA,可達到更大的空間互動(猜測是因為並行的設計?)
Deformable Attention
-
paper:https://openaccess.thecvf.com/content/CVPR2022/html/Xia_Vision_Transformer_With_Deformable_Attention_CVPR_2022_paper.html (CVPR 2022)
-
code:https://github.com/LeapLabTHU/DAT
DAT和普通的attention的區別就是,DAT可以匯聚一個自適應的可變感受野資訊,一方面可以提高效率,防止無關資訊的干擾(相比ViT),另一方面可以使得注意模組更加靈活,有效應對多尺度物體的情況(相比Swin)。
-
輸入特徵圖(假設shape = 1, 256, 48, 176)經過一個卷積層生成查詢矩陣q。
-
q透過一個offset network生成偏移量offset(shape = 1, 2, 46, 174),重新排列維度(shape = 1, 46, 174, 2)。
-
生成reference points(shape = 1, 46, 174, 2)。
-
將reference points和offset相加,得到最終的偏移量pos。
-
透過bilinear interpolation,輸入pos,輸出x_sampled(shape = 1, 256, 46, 174)。
-
由x_sampled生成矩陣k和v。
input = torch.rand(1, 256, 48, 176)
dtype, device = input.dtype, input.device
q = self.proj_q(x) # b c h w
# 生成偏移量
offset = conv_offset(q) # torch.Size([1, 2, 46, 174])
offset_range = torch.tensor([1.0 / (46 - 1.0), 1.0 / (174 - 1.0)]).reshape(1, 2, 1, 1)
# 用 tanh 預定義縮放因子防止偏移量變得太大
offset = offset.tanh().mul(offset_range).mul(2) # torch.Size([1, 2, 46, 174])
offset = einops.rearrange(offset, 'b p h w -> b h w p') # torch.Size([1, 46, 174, 2])
# 生成參考點,最後歸一化到[-1,+1]的範圍
reference = _get_ref_points(46, 174, 1, dtype, device) # torch.Size([1, 46, 174, 2])
pos = offset + reference
# torch.Size([1, 256, 46, 174])
x_sampled = F.grid_sample(
input=input,
grid=pos[..., (1, 0)], # y, x -> x, y
mode='bilinear', align_corners=True) # B, C, Hg, Wg
MSwin + Deformable Attention
???