Vision Transformer (ViT) 位置編碼
1. 生成位置編碼
對於每個影像塊(patch),根據其位置生成一個對應的編碼向量。假設每個影像塊的嵌入向量維度為 D,則位置編碼的維度也是 D。
ViT 通常使用可學習的絕對位置編碼,這意味著這些位置編碼是在訓練過程中學到的,並且每個影像塊的位置編碼在訓練開始時是隨機初始化的。
2. 位置編碼矩陣
設有 N 個影像塊(即 N 個輸入向量),每個影像塊對應一個位置編碼向量。將這些編碼向量組織成一個位置編碼矩陣,維度為 N × D。
3. 向輸入新增位置編碼
每個影像塊的嵌入向量與其對應的位置資訊相加:
4. 輸入Transformer
這些新增了位置編碼的向量將作為輸入,傳遞給Transformer模型進行後續處理。
5. 位置編碼的作用
透過將位置編碼與影像塊嵌入向量相加,Transformer能夠區分不同影像塊的位置資訊,進而學習到輸入序列的順序依賴關係,這對於捕捉影像的空間結構資訊至關重要。
6. 程式碼示例(假設使用Python和PyTorch)
import torch
import torch.nn as nn
class VisionTransformer(nn.Module):
def __init__(self, num_patches, embed_dim):
super(VisionTransformer, self).__init__()
# 可學習的位置編碼
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
def forward(self, x):
# x 的維度為 (batch_size, num_patches, embed_dim)
# 新增位置編碼
x = x + self.position_embeddings
return x
在這個示例中,self.position_embeddings
是一個可學習的引數矩陣,其大小為 (1, num_patches, embed_dim)
。在前向傳播時,這個矩陣會與輸入的嵌入向量相加,得到包含位置資訊的輸入。