在本文中,將介紹一些關於小樣本學習的相關知識,以及介紹如何使用pytorch構建一個原型網路(Prototypical Networks[1]),並應用於miniImageNet 資料集。
實驗環境:
pytorch:1.11.0
程式碼地址:https://github.com/xiaohuiduan/deeplearning-study/tree/main/小樣本學習
小樣本學習引入
在這一節將簡要的對小樣本學習(FSL)相關的知識進行介紹。由於我並不是專門研究小樣本的(我學習FSL也只是為了完成我的課程作業),因此,如果本文存在任何問題,歡迎進行批評指正?。
郵箱?:xiaohuiduan@hunnu.edu.cn
首先將模型看成一個黑盒子,不去關注它的內部結構,而是關注其input和output。
在分類模型[2]中,input是一張貓or狗的圖片,output則為0/1(代表其為貓或者狗;實際上輸出的是兩者的預測概率)。
但是關於上面的模型,存在一個問題,訓練這樣的模型需要大量的資料,根據@petewarden[3]的說法,訓練一個分類圖片的網路,每個類別需要大約1000張。但是,很多場景我們並沒有足夠的資料進行訓練,也就是說資料集的樣本比較少,這時候針對小樣本資料集可以有兩種處理方式[4]:
-
資料增強:比如說對影像進行旋轉,裁剪等等。
-
資料建模:如使用小樣本學習的方法進行對資料進行建模。
小樣本學習是什麼
以VGG模型預測貓狗分類舉例,模型對於一張新的圖片的預測可以形象的解釋為:
圖片中的動物因為有著尖尖的耳朵,有著較長的鬍鬚,鼻子那一個地方不是很突出,因此,我(VGG)判斷它是一隻貓o(=•ェ•=)m。
但是在小樣本學習中,不是這樣的,小樣本學習對於一張新的圖片的預測可以形象的解釋為:
我手中有2張圖片,圖片A和圖片B。對於新的圖片,我(模型)也不知道他是啥,但是我發現它跟圖片B長得很相似,因此我(模型)判斷這張新的圖片和圖片B是同一個類別。
在上述解釋中,圖片A和B稱之為
support sets
,而新的圖片稱之為query sets
。
小樣本的分類模型與傳統的深度學習分類模型(如VGG)有著不同,這裡引用論文[1:1]的一句話:
Few-shot classification is a task in which a classifier must be adapted to accommodate
new classes not seen in training, given only a few examples of each of these classes. A naive approach,
such as re-training the model on the new data, would severely overfit.
也就是說,小樣本分類並不是像VGG一樣,test的資料是以前訓練過的類別,對於小樣本分類來說,進行test的資料是一些新的類,並且這些類別的樣本很少,因此,沒法對其進行re-training,否則會造成過擬合。
小樣本學習方法
小樣本學習可以認為是一個N-way K-shot的分類問題(不確定是不是所有的小樣本分類任務都被認為是N-way K-shot分類問題)。
無論對於測試集還是訓練集,都需要進行如下的劃分,將資料集劃分為兩個部分:左邊DataSet代表資料集,右邊分別代表Support set,右邊代表Query Set。在train或者test資料集中:
- 首先對於所有類別,隨機選擇其中N(圖中N=3)個類別(圖中,選擇了類別2,類別3和類別5)。
- 在step 1中選擇的類別樣本中,隨機選擇K(圖中K=3)個樣本(綠色的部分),構成Support Set。也就是說,Support Set中擁有K*N個樣本。
- 然後在所選擇類別的剩餘樣本中,選擇X(這裡X=1)個樣本(紅色的部分),構成Query Set。也就是說,Query Set中擁有X*N個樣本。
在訓練集中,以上步驟構成的Support Set和Query Set會被input到Model中進行訓練,稱之為一個eposides(相當於mini-batch)。
對於模型來說,其目的則為判斷Query Set中的樣本與哪一個支援集最相似。
原型網路(Prototype Network)
原理簡述
Prototype Network的原理很簡單,可以簡單的概括為:將support set中的圖片\(data_1,data_2,\cdots,data_n\)對映到某一個向量空間\(c_1,c_2,\cdots,c_n\);對於Query set中的某一張圖片\(query_i\)使用同一個對映函式,也對映到到向量空間\(x_i\),然後判斷\(x_i\)與\(c_1,c_2,\cdots,c_n\)的距離(餘弦距離or歐氏距離),選擇距離最近向量所對應的類別作為\(query_i\)所屬的類別。
示意圖[5]如下所示:
如果瞭解NLP中word2vec的話,會發現,其與word2vec的Embedding思想是很相似的。
演算法流程
演算法流程圖[1:2]如下所示,紅色框和綠色框中的過程已經在前文進行介紹,這裡主要是來介紹一下loss的計算方式。
實際上,loss的計算方式就是一個交叉熵損失函式,pytorch中CrossEntropyLoss的計算方法如下所示,class代表\(x\)實際所屬類別\(x[j]\)代表模型對於\(x\)所屬類別\(j\)的概率預測。
但是,在演算法流程圖中,大家會發現,其loss計算的正負號剛好與上面公式中的相反,解釋如下:
以歐式距離為例,距離越遠(\(d\)則越大),則代表兩者的相似度越低。如果不加負號的話,進行softmax計算,距離越遠的則predict概率越大,這明顯是錯誤的。因此,加了一個負號之後,距離越遠,進行softmax之後,輸出則越小,predict的概率也變小,這才是合理的。
以上,便是原型網路的演算法流程。
演算法實現
資料集處理
mini-Imagenet是一個專門用於訓練小樣本學習的訓練集,資料集中一共有100個類別,每個類別600張圖片,一共有60000張圖片。資料集可以從mini-ImageNet | Kaggles上面下載。在下載檔案中,一共有4個檔案,一個是資料集圖片的壓縮包,另外3個csv檔案分別代表了訓練、驗證和測試集相關的資訊。其中訓練集有64個類,驗證集16個類,測試集20個類。
csv的部分資料如下所示,filename代表了圖片的名字,label代表了圖片對應的標籤。
因此,可以構建一個label所對應filename的字典
def read_csv(csv_path):
dict = collections.defaultdict(list)
df = pd.read_csv(csv_path)
for index,row in df.iterrows():
dict[row["label"]].append(row["filename"])
return dict
train_dict = read_csv(train_csv_path)
val_dict = read_csv(val_csv_path)
test_dict = read_csv(test_csv_path)
同時,構建data與labels的對應關係:
from PIL import Image
import numpy as np
from torchvision import transforms
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
resize_transform = transforms.Resize(84) # 提前對圖片進行縮放,以節省記憶體空間,將最短的邊變成84
def build_data(data_dict):
datas = []
labels = []
label_index = 0
for label in data_dict.keys(): # 對圖片的標籤進行迭代
for path in data_dict[label]: # 對標籤對應的檔名進行迭代
img_path = os.path.join(img_root_dir,path)
img = Image.open(img_path) # 讀取檔案
img = resize_transform(img) # 進行縮放
datas.append(img)
labels.append(label_index)
label_index += 1
return {"datas":datas,"labels":labels}
隨機產生Support set 和 Query set
在下面程式碼中,CategoriesSampler
的作用是為了產生index,然後供給dataloader使用。
class CategoriesSampler():
"""
目的是為了隨機產生K_way*(N_support+N_query)個圖片對應的index
"""
def __init__(self, data, n_batch, K_way, N_per):
self.n_batch = n_batch
self.K_way = K_way
self.N_per = N_per
labels = np.array(data["labels"]) # [0,0,0,0,1,1,1,1,2,2,2,2……]
self.index = [] # 記錄label對應的索引位置
for i in range(max(labels)+1):
ind = np.argwhere(labels == i).reshape(-1)
self.index.append(torch.from_numpy(ind))
def __len__(self):
return self.n_batch
def __iter__(self):
for i_batch in range(self.n_batch):
batch = []
classes = torch.randperm(len(self.index))[:self.K_way] # 隨機選擇K個類別構成support set和query set
for c in classes:
l = self.index[c] # 類別c對應的圖片陣列的索引,如 l = [5,6,7,8,9]
pos = torch.randperm(len(l))[:self.N_per] # 如 pos = [4,1,0]
batch.append(l[pos]) # 如 l[pos] = [9,6,5]
batch = torch.stack(batch).reshape(-1)
yield batch
同時,定義Dataset類,如下:
class MiniImageNet(Dataset):
def __init__(self, data):
self.datas = data["datas"]
self.labels = data["labels"]
self.transform = transforms.Compose([
transforms.CenterCrop(84),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.datas)
def __getitem__(self, i):
img, label = self.datas[i], self.labels[i]
return self.transform(img), label
使用示例如下,在MiniImageNet的__getitem__
函式中,其引數i
由CategoriesSampler
的__iter__
函式所產生。
batch_sampler = CategoriesSampler(datas,eval_step,K_way,N_shot+N_query)
data_loader = DataLoader(dataset=data_set, batch_sampler=batch_sampler,
num_workers=16, pin_memory=True)
實際上,上面的程式碼就是為了實現如下圖所示的功能:
Embedding網路構建
前文說到,需要將圖片進行向量化表示,在下面的程式碼中,便可以將一張圖片(shape=\(3\times84\times84\))變成一個1600維的向量。(網路結構來自於論文)
class CNN_Net(nn.Module):
"""
用於特徵提取
"""
def __init__(self, input_dim):
super(CNN_Net, self).__init__()
self.input_dim = input_dim
def conv_block(in_channel,out_channel):
return nn.Sequential(
nn.Conv2d(in_channel, out_channel, 3,padding=1),
nn.BatchNorm2d(out_channel),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.encoder = nn.Sequential(
conv_block(input_dim,64),
conv_block(64,64),
conv_block(64,64),
conv_block(64,64),
)
def forward(self, x):
x = self.encoder(x)
x = x.view(x.size(0), -1)
return x
損失函式
關於計算損失的關鍵函式如下所示(參考了論文作者的原始碼[6]):
def cal_euc_distance(self, query_z, center,K_way, N_query):
"""
計算query_z與center的距離
query_z : (K_way*N_query,z_dim)
center : (K_way,z_dim)
"""
center = center.unsqueeze(0).expand(
K_way*N_query, K_way, self.z_dim) # (K_way*N_query,K_way,z_dim)
query_z = query_z.unsqueeze(1).expand(
K_way*N_query, K_way, self.z_dim) # (K_way*N_query,K_way,z_dim)
return torch.pow(query_z-center, 2).sum(2) # (K_way*N_query,K_way)
def loss_acc(self, query_z, center, K_way, N_query):
"""
計算loss和acc
query_z : (K_way*N_query,z_dim)
center : (K_way,z_dim)
"""
target_inds = torch.arange(0, K_way).view(K_way, 1).expand(
K_way, N_query).long().to(self.device) # shape=(K_way, N_query)
distance = self.cal_euc_distance(query_z, center,K_way, N_query) # (K_way*N_query,K_way)
predict_label = torch.argmin(distance, dim=1) # (K_way*N_query) 預測出來的label
acc = torch.eq(target_inds.contiguous().view(-1),
predict_label).float().mean() # 準確率
loss = F.log_softmax(-distance, dim=1).view(K_way,
N_query, K_way) # (K_way,N_query,K_way)
loss = - \
loss.gather(dim=2, index=target_inds.unsqueeze(2)).view(-1).mean()
return loss, acc
def set_forward_loss(self, K_way, N_shot, N_query,sample_datas):
"""
sample_datas: shape(K_way*(N_shot+N_query),3,84,84)
"""
z = self.cnn_net(sample_datas) # shape=(K_way*(N_shot+N_query),z_dim) ,將support set和query set都進行向量化表示
z = z.view(K_way,N_shot+N_query,-1) # shape = (K_way,N_shot+N_query,1600)
support_z = z[:,:N_shot] # support set的向量化表示 shape=(K_way,N_shot,1600)
query_z = z[:,N_shot:].contiguous().view(K_way*N_query,-1) # Query set的向量化表示 shape=(K_way*N_query,1600)
center = torch.mean(support_z, dim=1) # 計算support set的向量均值,shape=(K_way,1600)
return self.loss_acc(query_z, center,K_way,N_query)
關於實驗中具體的引數設計,可以參考論文[1:3]或者Github上面的原始碼。在原論文中,對於實驗的設計講得非常清楚。
實驗結果
下面的表格為測試集的acc(當驗證集acc為最大值時測試集所對應的acc):
N-shot=1 | N-shot=5 | |
---|---|---|
K_way=5 | 0.4313 | 0.6684 |
總結
總的來說,原型網路是一個容易理解的網路模型,思想簡單,易於實現。
References
[1703.05175] Prototypical Networks for Few-shot Learning (arxiv.org) ↩︎ ↩︎ ↩︎ ↩︎
How many images do you need to train a neural network? « Pete Warden's blog ↩︎
劉穎, 雷研博, 範九倫, 王富平, 公衍超, 田奇. 基於小樣本學習的影像分類技術綜述. 自動化學報, 2021, 47(2): 297−315 ↩︎
【Pytorch】prototypical network原型網路小樣本影像分類簡述及其實現_Jnchin的部落格-CSDN部落格_原型網路小樣本 ↩︎
jakesnell/prototypical-networks: Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning" (github.com) ↩︎