深度學習(五)之原型網路

段小輝發表於2022-05-07

在本文中,將介紹一些關於小樣本學習的相關知識,以及介紹如何使用pytorch構建一個原型網路(Prototypical Networks[1]),並應用於miniImageNet 資料集。

實驗環境:

pytorch:1.11.0
程式碼地址:https://github.com/xiaohuiduan/deeplearning-study/tree/main/小樣本學習

小樣本學習引入

在這一節將簡要的對小樣本學習(FSL)相關的知識進行介紹。由於我並不是專門研究小樣本的(我學習FSL也只是為了完成我的課程作業),因此,如果本文存在任何問題,歡迎進行批評指正?。

郵箱?:xiaohuiduan@hunnu.edu.cn

首先將模型看成一個黑盒子,不去關注它的內部結構,而是關注其inputoutput

在分類模型[2]中,input是一張貓or狗的圖片,output則為0/1(代表其為貓或者狗;實際上輸出的是兩者的預測概率)。

但是關於上面的模型,存在一個問題,訓練這樣的模型需要大量的資料,根據@petewarden[3]的說法,訓練一個分類圖片的網路,每個類別需要大約1000張。但是,很多場景我們並沒有足夠的資料進行訓練,也就是說資料集的樣本比較少,這時候針對小樣本資料集可以有兩種處理方式[4]

  • 資料增強:比如說對影像進行旋轉,裁剪等等。

  • 資料建模:如使用小樣本學習的方法進行對資料進行建模。

小樣本學習是什麼

以VGG模型預測貓狗分類舉例,模型對於一張新的圖片的預測可以形象的解釋為:

圖片中的動物因為有著尖尖的耳朵,有著較長的鬍鬚,鼻子那一個地方不是很突出,因此,我(VGG)判斷它是一隻貓o(=•ェ•=)m。

但是在小樣本學習中,不是這樣的,小樣本學習對於一張新的圖片的預測可以形象的解釋為:

我手中有2張圖片,圖片A和圖片B。對於新的圖片,我(模型)也不知道他是啥,但是我發現它跟圖片B長得很相似,因此我(模型)判斷這張新的圖片和圖片B是同一個類別。

在上述解釋中,圖片A和B稱之為support sets,而新的圖片稱之為query sets

小樣本的分類模型與傳統的深度學習分類模型(如VGG)有著不同,這裡引用論文[1:1]的一句話:

Few-shot classification is a task in which a classifier must be adapted to accommodate
new classes not seen in training, given only a few examples of each of these classes. A naive approach,
such as re-training the model on the new data, would severely overfit.

也就是說,小樣本分類並不是像VGG一樣,test的資料是以前訓練過的類別,對於小樣本分類來說,進行test的資料是一些新的類,並且這些類別的樣本很少,因此,沒法對其進行re-training,否則會造成過擬合。

小樣本學習方法

小樣本學習可以認為是一個N-way K-shot的分類問題(不確定是不是所有的小樣本分類任務都被認為是N-way K-shot分類問題)。

無論對於測試集還是訓練集,都需要進行如下的劃分,將資料集劃分為兩個部分:左邊DataSet代表資料集,右邊分別代表Support set,右邊代表Query Set。在train或者test資料集中:

  1. 首先對於所有類別,隨機選擇其中N(圖中N=3)個類別(圖中,選擇了類別2,類別3和類別5)。
  2. 在step 1中選擇的類別樣本中,隨機選擇K(圖中K=3)個樣本(綠色的部分),構成Support Set。也就是說,Support Set中擁有K*N個樣本。
  3. 然後在所選擇類別的剩餘樣本中,選擇X(這裡X=1)個樣本(紅色的部分),構成Query Set。也就是說,Query Set中擁有X*N個樣本。

在訓練集中,以上步驟構成的Support Set和Query Set會被input到Model中進行訓練,稱之為一個eposides(相當於mini-batch)。

對於模型來說,其目的則為判斷Query Set中的樣本與哪一個支援集最相似

原型網路(Prototype Network)

原理簡述

Prototype Network的原理很簡單,可以簡單的概括為:將support set中的圖片\(data_1,data_2,\cdots,data_n\)對映到某一個向量空間\(c_1,c_2,\cdots,c_n\);對於Query set中的某一張圖片\(query_i\)使用同一個對映函式,也對映到到向量空間\(x_i\),然後判斷\(x_i\)\(c_1,c_2,\cdots,c_n\)的距離(餘弦距離or歐氏距離),選擇距離最近向量所對應的類別作為\(query_i\)所屬的類別。

示意圖[5]如下所示:

如果瞭解NLP中word2vec的話,會發現,其與word2vec的Embedding思想是很相似的。

演算法流程

演算法流程圖[1:2]如下所示,紅色框和綠色框中的過程已經在前文進行介紹,這裡主要是來介紹一下loss的計算方式。

image-20220507144721234

實際上,loss的計算方式就是一個交叉熵損失函式,pytorch中CrossEntropyLoss的計算方法如下所示,class代表\(x\)實際所屬類別\(x[j]\)代表模型對於\(x\)所屬類別\(j\)的概率預測。

\[\operatorname{loss}(x, \text { class })=-\log \left(\frac{\exp (x[\text { class }])}{\sum_{j} \exp (x[j])}\right)=-x[\text { class }]+\log \left(\sum_{j} \exp (x[j])\right) \]

但是,在演算法流程圖中,大家會發現,其loss計算的正負號剛好與上面公式中的相反,解釋如下:

以歐式距離為例,距離越遠(\(d\)則越大),則代表兩者的相似度越低。如果不加負號的話,進行softmax計算,距離越遠的則predict概率越大,這明顯是錯誤的。因此,加了一個負號之後,距離越遠,進行softmax之後,輸出則越小,predict的概率也變小,這才是合理的。

以上,便是原型網路的演算法流程。

演算法實現

資料集處理

mini-Imagenet是一個專門用於訓練小樣本學習的訓練集,資料集中一共有100個類別,每個類別600張圖片,一共有60000張圖片。資料集可以從mini-ImageNet | Kaggles上面下載。在下載檔案中,一共有4個檔案,一個是資料集圖片的壓縮包,另外3個csv檔案分別代表了訓練、驗證和測試集相關的資訊。其中訓練集有64個類,驗證集16個類,測試集20個類。

csv的部分資料如下所示,filename代表了圖片的名字,label代表了圖片對應的標籤。

因此,可以構建一個label所對應filename的字典

def read_csv(csv_path):
    dict = collections.defaultdict(list)
    df = pd.read_csv(csv_path)
    for index,row in df.iterrows():
        dict[row["label"]].append(row["filename"])
    return dict
train_dict = read_csv(train_csv_path)
val_dict = read_csv(val_csv_path)
test_dict = read_csv(test_csv_path)

同時,構建data與labels的對應關係:

from PIL import Image
import numpy as np
from torchvision import transforms
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

resize_transform = transforms.Resize(84) # 提前對圖片進行縮放,以節省記憶體空間,將最短的邊變成84
def build_data(data_dict):
    datas = []
    labels = []
    label_index = 0
    for label in data_dict.keys(): # 對圖片的標籤進行迭代
        for path in data_dict[label]: # 對標籤對應的檔名進行迭代
            img_path = os.path.join(img_root_dir,path) 
            img = Image.open(img_path) # 讀取檔案
            img = resize_transform(img) # 進行縮放
            datas.append(img) 
            labels.append(label_index)

        label_index += 1
    return {"datas":datas,"labels":labels}

隨機產生Support set 和 Query set

在下面程式碼中,CategoriesSampler的作用是為了產生index,然後供給dataloader使用。

class CategoriesSampler():
    """
        目的是為了隨機產生K_way*(N_support+N_query)個圖片對應的index
    """
    def __init__(self, data, n_batch, K_way, N_per):
        self.n_batch = n_batch
        self.K_way = K_way
        self.N_per = N_per
        labels = np.array(data["labels"]) # [0,0,0,0,1,1,1,1,2,2,2,2……]
        self.index = [] # 記錄label對應的索引位置
        for i in range(max(labels)+1):
            ind = np.argwhere(labels == i).reshape(-1)
            self.index.append(torch.from_numpy(ind))   

    def __len__(self):
        return self.n_batch
    
    def __iter__(self):
        for i_batch in range(self.n_batch):  
            batch = []
            classes = torch.randperm(len(self.index))[:self.K_way] # 隨機選擇K個類別構成support set和query set
            for c in classes:
                l = self.index[c] # 類別c對應的圖片陣列的索引,如 l = [5,6,7,8,9]
                pos = torch.randperm(len(l))[:self.N_per] # 如 pos = [4,1,0]
                batch.append(l[pos]) # 如 l[pos] = [9,6,5]

            batch = torch.stack(batch).reshape(-1)
            yield batch

同時,定義Dataset類,如下:

class MiniImageNet(Dataset):

    def __init__(self, data):

        self.datas = data["datas"]
        self.labels = data["labels"]
        self.transform = transforms.Compose([
            transforms.CenterCrop(84),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, i):
        img, label = self.datas[i], self.labels[i]
        return self.transform(img), label

使用示例如下,在MiniImageNet的__getitem__函式中,其引數iCategoriesSampler__iter__函式所產生。

batch_sampler = CategoriesSampler(datas,eval_step,K_way,N_shot+N_query)
data_loader = DataLoader(dataset=data_set, batch_sampler=batch_sampler,
                                            num_workers=16, pin_memory=True)

實際上,上面的程式碼就是為了實現如下圖所示的功能:

Embedding網路構建

前文說到,需要將圖片進行向量化表示,在下面的程式碼中,便可以將一張圖片(shape=\(3\times84\times84\))變成一個1600維的向量。(網路結構來自於論文)

class CNN_Net(nn.Module):
    """
        用於特徵提取
    """

    def __init__(self, input_dim):
        super(CNN_Net, self).__init__()
        
        self.input_dim = input_dim
        def conv_block(in_channel,out_channel):
            return nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 3,padding=1),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
        self.encoder = nn.Sequential(
            conv_block(input_dim,64),
            conv_block(64,64),
            conv_block(64,64),
            conv_block(64,64),
        )
    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return x

損失函式

關於計算損失的關鍵函式如下所示(參考了論文作者的原始碼[6]):

def cal_euc_distance(self, query_z, center,K_way, N_query):
    """
        計算query_z與center的距離
        query_z : (K_way*N_query,z_dim)
        center : (K_way,z_dim)
    """
    center = center.unsqueeze(0).expand(
        K_way*N_query, K_way, self.z_dim)  # (K_way*N_query,K_way,z_dim)
    query_z = query_z.unsqueeze(1).expand(
        K_way*N_query, K_way, self.z_dim)  # (K_way*N_query,K_way,z_dim)

    return torch.pow(query_z-center, 2).sum(2)  # (K_way*N_query,K_way)

def loss_acc(self, query_z, center, K_way, N_query):
    """
        計算loss和acc
        query_z : (K_way*N_query,z_dim)
        center : (K_way,z_dim)
    """
    target_inds = torch.arange(0, K_way).view(K_way, 1).expand(
        K_way, N_query).long().to(self.device) # shape=(K_way, N_query)
    
    distance = self.cal_euc_distance(query_z, center,K_way, N_query)    # (K_way*N_query,K_way) 
    predict_label = torch.argmin(distance, dim=1)  # (K_way*N_query) 預測出來的label

    acc = torch.eq(target_inds.contiguous().view(-1),
                    predict_label).float().mean() # 準確率

    loss = F.log_softmax(-distance, dim=1).view(K_way,
                                                N_query, K_way)  # (K_way,N_query,K_way)
    loss = - \
        loss.gather(dim=2, index=target_inds.unsqueeze(2)).view(-1).mean()
    return loss, acc

def set_forward_loss(self, K_way, N_shot, N_query,sample_datas):
    """
        sample_datas: shape(K_way*(N_shot+N_query),3,84,84)
    """

    z = self.cnn_net(sample_datas) # shape=(K_way*(N_shot+N_query),z_dim) ,將support set和query set都進行向量化表示
    z = z.view(K_way,N_shot+N_query,-1) # shape = (K_way,N_shot+N_query,1600)
    
    support_z = z[:,:N_shot] # support set的向量化表示 shape=(K_way,N_shot,1600)
    query_z = z[:,N_shot:].contiguous().view(K_way*N_query,-1) # Query set的向量化表示 shape=(K_way*N_query,1600)
    
    center = torch.mean(support_z, dim=1) # 計算support set的向量均值,shape=(K_way,1600)
    return self.loss_acc(query_z, center,K_way,N_query)

關於實驗中具體的引數設計,可以參考論文[1:3]或者Github上面的原始碼。在原論文中,對於實驗的設計講得非常清楚。

實驗結果

下面的表格為測試集的acc(當驗證集acc為最大值時測試集所對應的acc):

N-shot=1 N-shot=5
K_way=5 0.4313 0.6684

總結

總的來說,原型網路是一個容易理解的網路模型,思想簡單,易於實現。

References


  1. [1703.05175] Prototypical Networks for Few-shot Learning (arxiv.org) ↩︎ ↩︎ ↩︎ ↩︎

  2. 深度學習(二)之貓狗分類 - 段小輝 - 部落格園 (cnblogs.com) ↩︎

  3. How many images do you need to train a neural network? « Pete Warden's blog ↩︎

  4. 劉穎, 雷研博, 範九倫, 王富平, 公衍超, 田奇. 基於小樣本學習的影像分類技術綜述. 自動化學報, 2021, 47(2): 297−315 ↩︎

  5. 【Pytorch】prototypical network原型網路小樣本影像分類簡述及其實現_Jnchin的部落格-CSDN部落格_原型網路小樣本 ↩︎

  6. jakesnell/prototypical-networks: Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning" (github.com) ↩︎

相關文章