Vision Transformer (ViT) 分類識別符號
1. 初始化分類識別符號
在ViT中,分類識別符號是一個可學習的向量,通常在模型初始化時隨機初始化。這個識別符號的維度與影像塊的嵌入向量維度相同,通常記作 zcls,其大小為 D(與每個影像塊的嵌入向量維度一致)。
2. 與影像塊嵌入一起作為輸入
將這個分類識別符號 zcls 附加在所有影像塊的嵌入向量之前,形成一個擴充套件後的輸入序列。
假設原始影像塊嵌入的序列表示為 [z1, z2, …, zN],其中 N 是影像塊的數量,那麼完整的輸入序列將是:
[zcls, z1, z2, …, zN]
這裡,輸入序列的維度為 (N+1) × D。
3. 在Transformer中處理
這個包含分類識別符號的輸入序列會傳遞給Transformer的多層編碼器,經過多層自注意力機制和前饋神經網路的處理。分類識別符號在每一層都會被更新,並最終聚合整個影像的資訊。
4. 提取最終分類識別符號
當輸入序列經過所有Transformer層的處理後,提取出最終的分類識別符號 zclsfinal。
這個分類識別符號是一個綜合了整個影像資訊的嵌入向量。
5. 傳遞給分類頭
最終的分類識別符號 zclsfinal 會被傳遞給一個分類頭(通常是一個全連線層)進行影像的分類任務。分類頭輸出的向量用於預測影像屬於哪個類別。
6. 程式碼示例(假設使用Python和PyTorch)
import torch
import torch.nn as nn
class VisionTransformer(nn.Module):
def __init__(self, num_patches, embed_dim, num_classes):
super(VisionTransformer, self).__init__()
# 初始化分類識別符號 (CLS token)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(embed_dim, nhead=8),
num_layers=12
)
self.classifier = nn.Linear(embed_dim, num_classes)
def forward(self, x):
batch_size = x.size(0)
# 複製分類識別符號,使其適應批處理大小
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# 將分類識別符號新增到影像塊的嵌入向量之前
x = torch.cat((cls_tokens, x), dim=1)
# 新增位置編碼
x = x + self.position_embeddings
# 輸入Transformer
x = self.transformer(x)
# 提取最終的分類識別符號
cls_token_final = x[:, 0, :]
# 傳遞給分類頭進行分類
out = self.classifier(cls_token_final)
return out