vit中的生成分類識別符號介紹

海_纳百川發表於2024-08-10
Vision Transformer (ViT) 分類識別符號

Vision Transformer (ViT) 分類識別符號

1. 初始化分類識別符號

在ViT中,分類識別符號是一個可學習的向量,通常在模型初始化時隨機初始化。這個識別符號的維度與影像塊的嵌入向量維度相同,通常記作 zcls,其大小為 D(與每個影像塊的嵌入向量維度一致)。

2. 與影像塊嵌入一起作為輸入

將這個分類識別符號 zcls 附加在所有影像塊的嵌入向量之前,形成一個擴充套件後的輸入序列。

假設原始影像塊嵌入的序列表示為 [z1, z2, …, zN],其中 N 是影像塊的數量,那麼完整的輸入序列將是:

[zcls, z1, z2, …, zN]

這裡,輸入序列的維度為 (N+1) × D

3. 在Transformer中處理

這個包含分類識別符號的輸入序列會傳遞給Transformer的多層編碼器,經過多層自注意力機制和前饋神經網路的處理。分類識別符號在每一層都會被更新,並最終聚合整個影像的資訊。

4. 提取最終分類識別符號

當輸入序列經過所有Transformer層的處理後,提取出最終的分類識別符號 zclsfinal

這個分類識別符號是一個綜合了整個影像資訊的嵌入向量。

5. 傳遞給分類頭

最終的分類識別符號 zclsfinal 會被傳遞給一個分類頭(通常是一個全連線層)進行影像的分類任務。分類頭輸出的向量用於預測影像屬於哪個類別。

6. 程式碼示例(假設使用Python和PyTorch)

import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
def __init__(self, num_patches, embed_dim, num_classes):
super(VisionTransformer, self).__init__()
# 初始化分類識別符號 (CLS token)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(embed_dim, nhead=8),
num_layers=12
)
self.classifier = nn.Linear(embed_dim, num_classes)

def forward(self, x):
batch_size = x.size(0)
# 複製分類識別符號,使其適應批處理大小
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# 將分類識別符號新增到影像塊的嵌入向量之前
x = torch.cat((cls_tokens, x), dim=1)
# 新增位置編碼
x = x + self.position_embeddings
# 輸入Transformer
x = self.transformer(x)
# 提取最終的分類識別符號
cls_token_final = x[:, 0, :]
# 傳遞給分類頭進行分類
out = self.classifier(cls_token_final)
return out

相關文章