MobileViT-v1-所有patch內相對位置相同的token之間計算自注意力

iceeci發表於2024-11-15

paper

def my_self(x: torch.Tensor):
    '''
    透過這段程式碼  可以把每張圖片圖片中相對位置相同的若干個tokens放到最後兩個維度
    '''
    # [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w] 
    # n_h是高度方向上可以分多少個patch  p_h patch的高度  n_w 寬度方向上可以分多少個patch p_w patch的寬度
    x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
    # [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
    x = x.transpose(3, 4)
    # [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
    # num_patches 有多少個patch patch_area 每個patch的大小 patch的大小決定了要分多少組 假如說一個patch內有四個token,那麼面積就是4,所有patch中的第一個token之間互相計算自注意力,所有patch中的第二個token之間互相計算自注意力,第三第四同理
    x = x.reshape(batch_size, in_channels, num_patches, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    # 一共有多少張圖片,每張圖片中的token分幾個組,每個組內有多少token,每個token的維度有多少 至此,需要互相之間需要計算自注意力的token都已經固定在了最後兩個維度上
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]
    # 一張圖片有P組,B張圖片就是BP組
    x = x.reshape(batch_size * patch_area, num_patches, -1)

    return x

相關文章