如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

機器之心發表於2021-04-05

機器學習領域,普通的基於學習的模型可以透過大量的資料來訓練得到模型引數,並在某種特定任務上達到很不錯的效果。但是這種學習方法限制了模型在很多應用場景下的可行性:在具體的現實情況中,大量資料的獲取通常是有難度的,小樣本學習機器學習領域目前正在研究的問題之一;另外,模型在訓練過程中只接觸了某一特定任務相關的資料樣本,在面對新任務時,其適應能力和泛化能力較弱。


反觀人類的學習方法,不僅僅是學會了一樣任務,更重要的是具備學習能力,能夠利用以往學習到的知識來指導學習新的任務。如何設計能夠透過少量樣本的訓練來適應新任務的學習模型,是元學習解決的目標問題,實現的方式包括[1]:根據模型評估指標(如模型預測的精確度)學習一種對映關係函式(如排序),基於新任務的表示,找到對應的最優模型引數;學習任務層面的知識,而不僅僅是任務中的具體內容,如任務的分佈、不同任務的特徵表示;學習一個基模型,這個基模型的引數是基於以往多種任務的各個特定模型而得到的,等等。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 1:什麼是元學習(圖源:http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf)

下面從元學習的工程最佳化、解決區域性最優和過擬合問題、模型解釋性等方面詳細解讀和分析四篇論文。

一、"TaskNorm: Rethinking Batch Normalization for Meta-Learning"

核心:元模型訓練階段的工程最佳化

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


本文是發表於 ICML 2020 中的一篇文章[2],是劍橋大學、Invenia 實驗室和微軟研究院學者共同合作的研究成果,提出了一種適用於元學習在模型訓練時的資料批次標準化方法。

深度學習中網路模型的訓練通常基於梯度下降法,與模型學習效果相關的因素包括了學習步長(學習率)、網路初始化引數,並且當涉及深層網路訓練時,還需要考慮梯度消失的問題。標準化層(normalization layer,NL)的提出,使得增加了標準化層的網路在訓練時,能夠使用更高的學習率,並且能夠降低網路對於初始引數的敏感度,對於深層網路的訓練更加重要。NL 的一般表示為:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


其中,γ和β為學習的引數,μ和σ是標準化的統計量,a_n 和 a’_n 是輸入和標準化後的輸出。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 1.1:元學習的訓練集。這是圖片分類的例子,在不同 episode 中,由不同的子類構成不同的分類任務;在相同的 episode 中,支援集和查詢集包含了相同的子類。來自:https://www.jiqizhixin.com/articles/2019-07-01-8

元學習的訓練資料集包括了 context set Dτ(也稱為 support set,支援集)和 target set Tτ(也稱為 query set,查詢集),如圖 1.1 所示。利用這個資料集進行兩個階段的訓練:在內層(inner loop)階段,使用 context set 來更新引數θ,得到特定任務的引數ψ;在外層(outer loop)階段(fφ表示由θ生成ψ的一個過程,可能會引入額外的引數φ),對 target set 中的 input 進行預測,並得到目標損失函式

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


元學習中的分層框架(inner loop 和 outer loop 兩層更新,如圖 1.2 所示),可能會使得傳統的批標準化方式(batch normalization,BN)失效:BN 的使用具有一定的前提條件,獨立同分布 iid 條件,而元學習可能不滿足這個條件,如果直接使用 BN 方法在元學習的網路模型中引入標準化層,可能會導致不理想的元模型效果。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


作者提出了一種適用於元學習的標準化方式 --- 任務標準化(task normalization,TaskNorm),它能夠提升模型訓練的速率和穩定性,並且能夠保持理想的測試效果;另外,它適用於不同大小的 context set,並沒有受到很大的影響;而且這種標準化方式是非直推式的,因此在測試的時候能夠適用於更多的情景(即更多樣的影像分類任務)。在具體展開介紹 TaskNorm 之前,作者先對元學習的推理方式和幾種常見的標準化方法進行簡單介紹,並且說明了在元學習中對應不同的標準化方法的統計量μ和σ的計算和使用方式。

1.1 方法介紹

  • 直推學習(transductive meta-learning)和非直推學習(non-transductive meta-learning)


對於元學習,作者討論了兩種方式:直推學習和非直推學習。非直推學習的元測試(meta-test)階段,在對測試集(和訓練集類似,也包括了 context set 和 target set)中的單個樣本進行類別預測時,僅僅使用 context set 以及輸入的觀測值。直推學習的元測試階段,對單個樣本進行預測時,不僅需要 context set 和觀測值,還需要測試集中其他樣本的觀測值。作者認為,元學習中的標準化層需要是 * 非直推式 * 的,因為對於直推學習,作者認為它的兩個問題:

1. 對 target set 的分佈敏感。在 outer loop 時,需要用到 target set 的其他樣本,即當前樣本的預測輸出還與其他樣本的輸入相關,因此這種方式相比於非直推學習的泛化性更弱。如果在元測試中使用的 target set 樣本的類別平衡情況和訓練時有差別,那麼模型在測試時的分類效果可能並不會很好。

2. 直推學習利用到了更多的資訊(相當於需要依賴的資訊更多),因此如果將兩種方法直接進行比較是不公平的。


  • 幾種基本的標準化方式以及在元學習中的應用


批標準化(batch normalization,BN)。BN 在訓練階段和測試階段的使用模式是不一樣的。在元訓練(meta-training)階段,均值和方差的計算如下所示:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


在 BN 中,輸入的通道數不變,對每個通道、使用整個 batch 進行變換,這種標準化的方式沒有涉及不同通道之間的資料交換。更直觀一點,資料集輸入的維度表示為 < B,C,W,H>,那麼標準化計算量μ和σ的維度表示為 <1,C,1,1>。使用所有 batch 計算統計量有一個前提,就是假設了 batch 中的資料服從獨立同分布。在測試階段,使用的均值和方差是訓練集所有資料的均值和方差。

元學習網路中直接使用批標準化(Conventional Usage of Batch Normalization,CBN),會有兩個重要的問題:(1)在元測試階段,使用的是根據元訓練階段資料集計算得到的μ和σ,可以認為這兩個統計量是和元模型等效的引數。然而,訓練時的資料集包括了所有不同的任務,獨立同分布的條件只是在相同任務的資料之間滿足、在不同任務之間不一定滿足。作者將 CBN 應用在 MAML 方法 [3] 中,實驗結果表明了該方法在預測任務上表現並不好。(2)當訓練過程中使用的 batch-size 較小,得到的統計量可能並不準確時,模型的效果也會受到影響。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 1.3:批標準化(BN),元學習訓練和測試過程中直接使用 BN 的方式。圖源:[2]

基於例項的標準化(Instance-based Normalization)。基於例項的標準化方式是非直推式的,統計量只根據當前例項(如單張圖片)來計算μ和σ,並且不依賴於 context set 資料集的大小

1. 例項標準化(instance normalization,IN)。針對單張圖片的 (H,W) 兩個維度計算統計量(即每一張圖只對 H 和 W 維度進行歸一化),每一張圖都有對應的統計量。該計算方式在元訓練階段(使用訓練集)和元測試階段(使用測試集)是一樣的。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 1.4:例項標準化(IN),元學習中 context set 和 target set 使用 IN 的方式。圖源[2]

2. 層標準化(layer normalization,LN)。LN 針對圖片單獨進行變換,並考慮到了多個通道的維度。該計算方式在元訓練階段(使用訓練集)和元測試階段(使用測試集)是一樣的。作者在後續提供的實驗結果中,指出 LN 相比於其他標準化方式,在訓練效率方面的表現較不足。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 1.5:層標準化(LN),元學習中 context set 和 target set 使用 LN 的方式。圖源:[2]

直推批標準化(transductive batch normalization,TBN)。相比於 CBN,TBN 的標準化方式在元測試階段,並不是使用元訓練階段資料集的統計量,而是使用測試資料集(包括 context-set 或者是 target-set)來計算μ和σ。另外,TBN 會根據不同的任務分別計算各自的統計量。

雖然這種方法能夠獲得更好的效果,但是在元測試時,對於 target-set 的標準化處理使用了 target-set 全域性的統計量,相當於測試的資料之間是存在某種資訊交流和利用的,給了更多的先驗資訊,提升測試的準確率。這種方式在資訊利用方面和非直推學習方式並不是對等的,因此不能直接比較 TBN 和其他的非直推方式。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 1.6:直推式批標準化(TBN)。圖源:[2]

1.2 任務標準化(Task Normalization, TASKNORM)

本質上,找到適用於元學習的標準化方法,關鍵在於找到合適的統計量μ和σ。根據標準化處理對於資料的獨立同分布條件要求,對於元學習來說, μ和σ應該是任務級別的統計量,在一定程度上是融入任務模型引數ψ中。ψ是元模型透過適應 context set 而得到的任務模型的引數,因此在任務模型的推理階段,用到的統計量μ和σ也應該能夠從 context set 計算得到。

結合上述元學習對於標準化統計量的要求,作者首先提出了一種元批次標準化方法( meta-batch normalization,MetaBN)。對於每個任務,在 context set 中計算各自的均值和方差,這個統計量共用於 context set 和 target set;在元訓練階段和元測試階段,是分別根據訓練集中的 context set 和測試集中的 target set 得到各階段的標準化統計量。但是,這種標準化方法仍然會受到 context set 大小的影響:當 context set 的 batch size 較小時,統計量的準確度不夠高,會影響模型的預測效果。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 1.7:MetaBN 方法和 TaskNorm 方法(包括 TaskNorm-L 和 TaskNorm-I)。圖源:[2]

進一步地,作者保留了 MetaBN 的優點,結合基於例項的標準化方法不依賴資料集大小的特點,提出了本文的核心內容:任務標準化(TASKNORM)。TASKNORM 方法是在 MetaBN 的基礎上,結合了 LN 或者是 IN,可以具體分為 TaskNorm-L 以及 TaskNorm-I 兩種標準化方法:元訓練(元測試)階段,使用訓練集(測試集)的 context set 得到統計量,context set 和 target set 都使用該統計量以及各自的 LN 或者 IN 的加權和,得到最終用於標準化的統計量,其中兩部分統計量的權重由超引數α控制。此時的μ和σ的計算由下式得到:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


其中,μ_{BN}和σ^2_{BN}是根據 context set 計算的統計量,μ+ 和σ+ 是根據層標準化(LN)或者是例項標準化(IN)得到的非直推式的統計量。這種結合方式的出發點是 * 解決使用少樣本學習時存在的樣本數量相關問題 *:當 context set 的樣本量很少時,僅根據該數量集得到的統計量可能會得到關於該任務的不準確的資料;當結合其他統計量時,有助於提升訓練效率以及模型的預測效果。

作者將權重α定義為一個引數化的變數,它和 context set 大小具有線性關係,表示為:α=sigmoid(scale|Dt| + offset)。其中 Dt 為 context set 元素個數,scale 和 offset 在元訓練階段是可學習的。α和 support set 大小之間存線上性關係式,表示為:α=sigmoid(scale|Dt| + offset)。其中 Dt 為 context set 的大小,scale 和 offset 是在元訓練時學習得到的。

1.3 實驗介紹

作者分別在小規模資料集和大規模資料集上進行少樣本(few-shot)分類任務,對比幾種標準化方法,驗證本文提出的幾個猜想:1)元學習對於標準化方式是比較敏感的;2)直推批標準化(TBN)比非直推批標準化的效果普遍要好;3)考慮了元學習資料集特性的方法如 TaskNorm,MetaBN 以及 RN 的效果,會比 CBN,BRN(batch renormalization),IN,LN 等沒有考慮元學習資料特性的方法要好。在實驗中,作者關注的指標包括模型預測的準確度和訓練效率。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

表 1.1:基於小資料集(mini imagenet 和 omniglot)的分類實驗,此時僅考慮固定大小的 context set 和 target set。來自:[2]

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

表 1.2:基於大資料集 meta-dataset(包含了 13 個影像分類的資料集)的分類實驗。來自:[2]

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 1.8:不同標準化方法得到的模型準確度和訓練過程的對比圖。圖源:[2]

1.4 小結

本文提出了一種適用於元學習的標準化方法 TASKNORM,基於傳統批標準化方法對統計量的計算進行改進。在計算用於資料標準化的統計量均值μ和方差σ^2 時,該方法考慮了任務內資料的獨立同分布、任務間的資料不滿足獨立同分布條件,context set 大小的影響,以及考慮非直推式的學習方式,從而使得元學習模型能夠應用在更多的場景。透過大量的對比實驗,驗證了使用 TASKNORM 方法能夠提升元學習模型的訓練效率和預測效果。

二、 "Meta-Learning with Warped Gradient Descent" (ICLR2020)

核心:解決基於梯度的元學習方法的引數區域性最優問題

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


本文是發表於 ICLR 2020 中的一篇滿分論文[4],由曼徹斯特大學、Alan 圖靈研究機構和 DeepMind 的研究員提出了元學習中的梯度預處理計算方法。

元學習領域有一個重要的問題,是學會一種更新規則,能夠快速適應新的任務。處理這個問題的方式通常有兩種:訓練網路來產生更新(學習更新方式);或者是學習一個比較好的初始化模型或者是比例因子,應用於基於梯度更新的學習方法(學習和梯度更新相關的因素)。前者容易導致不收斂的效果,後者在少樣本(few-shot 任務中的適應效果可能不太好。

作者結合前面說的兩種方式,提出一種彎曲梯度下降(warped gradient descent)的方法,它主要學習一個引數化預處理矩陣,該矩陣是透過在 task-learner 網路模型的各層之間交叉放置非線性啟用層(即彎曲層,warped layers)而產生。在網路訓練時,這些 warp 層提供了一種更新方式,而它的引數是 meta-learned,在模型訓練過程中是不經過梯度回傳的。

為了驗證這種梯度更新方式的有效性,作者還將這種彎曲梯度方法應用在少樣本學習,標準的有監督學習,持續學習和強化學習等多種設定下進行實驗。

2.1 方法介紹

在基於梯度更新的元學習中,task-learner 元引數的更新規則表示為 U(θ; ξ):= θ-α∇L(θ),初始引數θ_0 的元學習過程可表示為:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


這類方法由於依賴於梯度更新的軌跡,會存在一些問題:梯度的計算會涉及到較大的計算量;容易受到梯度爆炸或者是梯度消失情況的影響;置信度分配問題。將損失函式 L 抽象成一個曲面,該曲面的情況會影響引數調整的效果,並且此時的引數空間不一定是合理的、不一定適用於不同任務的空間。

針對這幾個問題,作者首先了介紹了一種結合預處理的梯度更新通用規則,表示為:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


其中,P 表示一個用於預處理梯度的曲面。為了更好地拆分預處理模組的引數和 task-learner 的引數,作者使用了一種更為靈活的結構:在多層網路模型中插入全域性引數化的 warp 層。最為簡單的一種插入方式表示為:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


h 是網路的隱藏層,w 是插入的 warp 層。在梯度回傳時,對於 warp 層使用的是 Jacobian 矩陣(Dx 和 Dθ)來計算:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


  • warp-layers 的具體原理和計算流程


如圖 2.1 所示,是 warp 層在 task-learner 中的使用和計算流程。對於 task learner f(x),隱藏層之間(h1 和 h2)嵌入 warp 層(ω1 和ω2):在前向計算時,warp 層相當於啟用層;在任務適應階段(task adaptation)的後向回傳中,warp 層透過 Dω來提供梯度。這就是本文提出的用於網路引數更新的 WarpGrad 方法。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 2.1:warp 層及 WarpGrad 計算的示意圖。圖源:[4]

透過曲面的圖示來更形象地展示 WarpGrad 起到的作用,如圖 2.2 所示。在理想的 W 空間曲面,能夠產生梯度上的預處理,找出梯度下降的最大方向。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 2.2:上一行表示 WarpGrad 學習到的元幾何(meta-geometry)P 曲面;下一行表示不同任務的損失函式 W 曲面,其中黑線是普通梯度下降的方向,紫色是利用元幾何得到的梯度下降的方向。圖源:[4]

考慮到 warp 層具有幾何曲面的表示意義,作者提出 warp 層實際上是近似一個矩陣 G,該矩陣是一個正定的矩陣向量,用於度量流形的曲率。

Ω表示 warp-layers 起到的作用,它相當於透過重引數化(ω)來近似於最快的梯度下降方向:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


在 P - 空間和 W - 空間上的梯度表示為:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


其中,γ=Ω(θ; φ)表示從 P 空間對映到 W 空間的對映引數,並且

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


P 空間的引數梯度和 W 空間的引數梯度之間的轉換關係如圖 2.3 所示:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 2.3:P 空間的θ引數梯度等價於 W 空間的γ引數梯度。圖源:[4]

Warp 層引數控制了理想曲面的生成,本質上控制了 task learner 的收斂目標。因此,為了積累所有任務的資訊幫助提升任務適應的過程,warp 層引數是透過元學習來訓練得到的,目標函式表示為:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


  • Warp-layer 引數的學習方式


作者定義了一個高層的任務τ=(h, L_{meta}, L_{task}),L_{meta}作為元訓練的目標損失函式,用於 warp 引數的適應學習;L_{task}作為任務適應的目標函式,用於θ引數的適應學習。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


上式對於φ的學習,依賴於 L-task,會涉及到二階梯度的計算。作者進一步做梯度截斷(stop gradient),使得φ的更新只涉及一階梯度。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 2.4:warpgrad 應用於線上元學習和離線元學習的演算法流程。圖源:[4]

2.2 實驗介紹

在實驗部分,作者在元學習方法 MAML[3]和 Leap[5]方法中引入 WarpGrad 的更新方式,在兩個資料集(miniImageNet 和 tieredImageNet)上做少樣本(few-shot)學習和多樣本(multi-shot)學習,使用了 WarpGrad 方法的元學習模型能夠超過普通元學習模型在分類任務上的準確率

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 2.5:使用 warpgrad 方法進行少樣本學習和多樣本學習的對比實驗。圖源:[4]

作者還驗證了 WarpGrad 方法對模型在不同任務上的泛化能力的作用。如圖 2.6 所示,在不同任務數量的實驗中,Warp-Leap 模型的測試準確率明顯高於其他幾種基準方法。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 2.6:對比不同方法在不同任務數量實驗中的準確率。圖源:[4]

2.3 小結

本文提出了一種更為泛化的基於梯度的元學習方法 WarpGrad,在網路中引入 warp 層用於預處理原始梯度,該方法的特點包括:(1)WarpGrad 方法本質上是一種基於梯度的更新方式,它的創新之處在於對梯度進行了預處理,所以它也具有梯度下降法的特性,能夠保證訓練模型的收斂;(2)warp 層構造了梯度預處理的分佈,而這個分佈所具有的幾何曲面能夠從任務學習者中分離出來;(3)warp 層的引數是透過任務和對應軌跡來元學習得到的,根據區域性的資訊來獲得任務分佈相關的屬性;(4)相比於用預處理矩陣來直接對梯度進行處理,warp 層在網路模型中同時參與了前向計算和後向梯度回傳,是一種更為有效的學習方法。

三、"Meta-Learning without Memorization"

核心:解決任務層面的過擬合問題

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


本文是由 Google brain 團隊和 UT Austin 學者發表於 ICLR 2020 中的一篇文章[6],它探討了元學習模型的記憶問題並提出解決方法。

在分類任務中,當圖片和類別標籤並不是互斥的(mutually-exclusive)時(如在分類任務 1 中,狗的類別標籤是 2;在分類任務 3 中,狗的類別標籤仍然是 2),分類模型做的事情其實是直接將類別標籤和圖片中的資料特徵對應起來。此時,訓練得到元模型可能 * 無法 * 很好地應用在新的分類任務上:在訓練階段,模型不需要適應訓練資料集、就可以在測試資料集上達到較好地效果;而在推理階段,適應能力較弱的模型,則無法適應新任務的訓練資料集,很難在新任務的測試資料集上達到理想效果。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 3.1:Meta-learning 的圖模型表示。圖源:[6]

結合元學習的圖模型來進一步理解這個問題的定義。M 是元訓練資料集,包括了在元訓練階段的訓練資料集 D(support set)和測試資料集 D*(query set),θ是元模型引數,φ是特定任務模型引數(task-specific parameters)。q(θ|M)表示基於元訓練資料的元引數分佈,q(φ|D, theta)表示基於任務訓練(per-task training)的任務引數分佈,q(y*|x*, φ, θ)表示預測的分佈:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


那什麼是記憶問題?就是 y * 的計算,可以獨立於φ和 Di,完全依賴於θ和 x*,即 q(y*|x*, φ, θ)=q(y*|x*, θ)。此時,在測試資料集上的預測結果可以直接根據元模型引數θ來得到,而不需要經過透過適應 D 而得到最佳化後的引數φ來進行預測的過程。

3.1 方法介紹

在本文中,作者給出了記憶問題的數學形式,引入互資訊(mutual information)這個概念:在元學習中的完全記憶,指的是模型在預測 y 時忽略任務訓練資料集 D 的資訊,即 y 和 D 之間的互資訊為 0,表示為 I(y;D|x,θ)=0。為了同時達到低誤差,以及 y * 和 (x*,θ) 之間的低互資訊,需要利用任務訓練資料 D 來做預測,即增大 I(y*;D|x*, θ),從而減少記憶問題。

在本文中,作者提出元正則項(meta-regularizer, MR),基於資訊理論來提供一個通用的、不需要在任務分佈上設定限制條件的方法,解決元學習的記憶問題。更具體地,分別是:啟用項上的元正則化(meta regularization on activations),權重上的元正則化(meta regularization on weights)。

啟用項上的元正則化在上圖中,當給定 theta 時,y * 和 x * 之間的資訊流,包括了 y * 和 x * 之間的直接依賴,以及經過資料集 D 的間接依賴。作者提出,透過引入了一箇中間變數 z*,有 q(ˆy* |x* , φ, θ) = ∫ q(ˆy* |z* , φ, θ)q(z* |x* , θ) dz*,控制 \ hat{y}* 和 x * 之間的資訊流來解決記憶問題。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 3.2:引入中間變數 z 的元學習的圖模型,。圖源:[6]

此時,為了引導模型有效地利用任務訓練資料 D,增大的互資訊目標變為 I(y*;D|z*, θ),透過如下的推導,等價於增大互資訊 I(x*;y*|θ)和減小 KL 散度 E[D_{KL}(q(z*|x*,θ) || r(z*))]:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


對於上式左項的互資訊,假如 I(x*;y*|θ)=0,並且存在記憶問題(I(y*;D|x*,θ)=0)時,那麼有 q(y*|x*, θ, D)=q(y*|x*, θ)=q(y*|θ),即預測結果 y * 並不依賴於觀測值 x*,顯然這樣的模型並不會得到理想的預測準確度。因此,最小化損失函式(如式 (1))有助於引導互資訊 I(y*;D|x*,θ) 或者是 I(x*;y*|θ)的最大化,所以在引入中間變數 z * 後,需要做的就是最小化 KL 散度,最終的損失函式表示為:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


但是,作者在實驗過程中發現這種方法在一些情況下並不能避免記憶問題,並進一步提出了另一種元正則化方法。

權重上的元正則化作者提出,透過懲罰元模型引數,減少元引數所帶有的任務資訊,從而降低模型對於任務的記憶能力、解決記憶問題。對於元引數θ中包含的訓練任務資訊,可以表示為 I(y*1:N,D1:N; θ|x*1:N ),它的上確界有:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


引數的懲罰項即為最後的 KL 散度,該懲罰項實際上是限制模型引數的複雜度:如果模型需要去記住所有任務的資訊,那麼模型非常複雜;所以限制模型的複雜度,在一定程度上能夠減少元引數包含的任務資訊。但是,作者並沒有完全限制模型引數的複雜度,在實際應用中,仍允許部分模型引數對任務訓練資料進行處理,因此只是在部分引數θ上執行該懲罰項(模型的其他引數則表示為θ~),最後損失函式可以表示為:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


3.2 實驗介紹

本文分別在分類任務和迴歸任務上進行對比實驗,在這些任務中圖片標籤和圖片資料本身是非互斥的,用於驗證元正則化方法在記憶問題上的有效性。如表 3.1 和 3.2 所示,使用了元正則化(MR)的方法,相比於其他的元學習基準方法,在分類任務和迴歸任務上都能明顯獲得更好的效果。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

表 3.1:圖片標籤非互斥的迴歸任務(均方差),A 表示使用了啟用項上的元正則化,W 表示使用了權重上的元正則化。來自:[6]

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

表 3.2:圖片標籤非互斥的分類任務(準確率)。來自:[6]

3.3 小結

本文從資訊理論的角度,提出了一種適用於不同的元學習方法的元正則化(MR)方法。該方法可以用在標籤沒有打亂(或者是很難打亂)的任務中,能夠提升元學習方法在更多場景中的適用性和可行性,在一定程度上解決元學習的記憶問題。

四、"Unraveling Meta-Learning: Understanding Feature Representations for Few-Shot Tasks"

核心:探討元模型特徵表示模組的作用(元學習方法的可解釋性)

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


本文是由馬里蘭大學的學者發表於 ICML 2020 中的一篇文章[7]。在少樣本分類(few-shot classification)任務的場景中,元學習方法能夠提供一個快速適應新任務(new tasks)或者是新域(new domains)的基礎模型。然而,很少有工作去探討模型達到不錯效果的深層原因,如元學習方法中特徵提取模組(feature extractor)提取得到的特徵表示的不同之處是什麼。

本文提出,相比於普通學習得到的特徵表示,元學習得到的特徵表示(meta-learned representations)是有區別的、更有助於少樣本學習。使用元學習的特徵表示能夠提升少樣本學習的效果,本文作者歸為兩種不同的機制:(1)固定特徵提取模組引數,只更新(微調)最後的分類層(classification layer)引數。在這種機制下,類別資料點在特徵空間中會更加聚集,那麼在微調時,分類邊界對於提供的樣本會沒那麼敏感。(2)在模型引數空間尋找最優點作為基礎模型,該最優點接近大部分特定任務(task-specific)模型引數的最優點,那麼在面對新的特定任務時,能夠透過幾步的梯度計算,將基礎模型更新為適用於新任務的特定模型。

進一步地,作者分別探討上述兩種機制的作用,定義了幾種正則項,並結合正則項提出了幾種帶正則化的模型訓練方法,透過實驗驗證了相關猜想以及正則化訓練方法的有效性。

4.1 基於特徵聚集的正則化方法

  • 4.1.1 在特徵空間的類別特徵點聚集


作者先討論第一種機制,即微調時固定特徵提取模組、只更新分類層,使用這類機制的元學習方法包括 ProtoNet[8],R2-D2[9]和 MetaOptNet[10]。這類方法能夠達到好的分類效果,猜想是特徵提取模組已經能夠做到很好的特徵區分、從而對於新的分類任務也能夠實現少樣本學習。

特徵點聚集對於少樣本學習的重要性。如下圖所示,當類別的特徵點是分散的、類間相隔較近時,選取少量樣本來訓練分割平面容易導致較大的分割誤差;而當類別的特徵點是聚集的、類間相隔較遠,訓練得到的分割平面準確度較高,分割平面對於樣本選取的依賴較弱。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 4.1:特徵點聚集對於少樣本訓練分割平面準確度的重要性。圖源:[7]

然後,作者透過對比元學習的 ProtoNet 和傳統訓練的網路模型的特徵提取效果,驗證了元學習方法在特徵點聚集上做得更好,雖然沒有直接證明特徵點聚集對於少樣本學習的必要性,但是為接下來提出的基於特徵點聚集的正則項提供了重要的思路和啟發。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 4.2:ProtoNet 和經典分類網路在 mini-ImageNet 資料集上提取的特徵進行視覺化(使用 LDA 處理元學習和經典分類器提取的特徵,視覺化對映到二維空間的特徵)。圖源:[7]

本文考慮特徵聚集的評估指標(feature clustering, FC),定義為類內方差和類間方差的佔比。根據 FC 的定義,本文給出了特徵聚集的正則項(feature clustering regularizer, R_fc)定義:

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


其中,f_{θ}(x_i,j)是特徵提取模組 f_{θ}對樣本 x 給出的特徵表示,μ_i 是第 i 類的特徵向量均值,μ是所有資料的特徵向量均值。作者基於 R2-D2 和 MetaOptNet 的網路結構,結合交叉熵損失函式和該正則項,作為傳統的訓練方法的損失函式,在 mini-ImageNet 資料集和 CIFAR-FS 資料集上進行 1-shot 和 5-shot 的實驗,對比使用元訓練的方法和不使用該正則項的傳統訓練方法。

如表 4.1 所示,相比於沒有用 R_fc 訓練的網路效果,使用 R_fc 來訓練網路,能夠和元學習網路達到類似的高分。這進一步說明了使用 R_fc 可以得到類似於元學習網路得到的特徵表示,那麼元學習方法實際上也有做特徵聚集的工作。

更進一步地,作者探討 特徵點聚集分割平面對資料樣本不變性兩者之間的聯絡,提出了超平面方差的正則項(hyperplane variation regularizer):

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


對於兩個類別的特徵點(A 類的 x1 和 x2,B 類的 y1 和 y2),該正則項衡量了不同類別資料點之間的距離向量的差異。當超平面對於資料樣本有較強不變性時,該正則項的值越小。同樣地,作者使用該正則項進行對比實驗,效果和 Rfc 類似,比沒有使用 Rhv 的傳統訓練方法的到的模型的分類效果要好。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

表 4.1:使用 Rfc 或者是 Rhv 的對比實驗結果。來自:[7]

前面的實驗中,考慮的元學習訓練方式是第一種機制,那對於微調時不會固定特徵提取模組的元學習訓練方式(比如 MAML 方法),情況又是怎樣的呢?作者將 MAML 方法和遷移學習方法對比,發現 MAML 模型的效果並沒有比傳統訓練模型的 feature seperation 效果更優,說明了特徵聚集的提升作用,並不是元學習訓練中會有的普遍現象,而是特定地存在於使用第一種機制的元訓練模型中。於是接下來,作者對於元學習第二種機制的有效性進行了探討和分析。

4.2 權重聚集的正則化方法(weight-clustering regularization)

  • 4.1.2 在引數空間的任務損失函式的最優點聚集


接下來討論沒有固定特徵提取模組的元模型,這類模型的引數能夠很好地適應新任務。對於 Reptile[10],作者提出了一種假設:該方法尋找的模型引數,是接近於很多工的最優點,所以能夠在微調之後在這些任務上達到較好的效果。為了驗證這個猜想,本文將 Reptile 方法表示為類似於一致性最最佳化方法的形式(consensus optimization,使用一項懲罰來促進各個特定任務的模型收斂到共同的引數),最小化的目標函式為:


如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


θ~ 是 task-specific 引數,θ是一致值(實際上是元引數),左項是針對任務 p 的損失函式,右項是距離懲罰項,引導模型引數收斂到一個一致值的附近。雖然 Reptile 實際上並沒有很明顯地使用第二項來得到最優的 task-specific 引數,但是它使用了θ作為 task-specific 模型的初始化引數,隱式地促使θ~ 是在θ附近。

為了驗證引數聚集的作用,作者在原始 reptile 演算法中內部迴圈(inner loop)的損失函式加上如下一項,進而提出權重聚集(Weight Clustering)方法

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程


該項給出了針對某個任務 i 的模型引數θ^~_i 與當前訓練批次所有任務的模型引數θ^~_p 的均值之間的距離。透過將 Reptile 方法結合該正則項,能夠更顯式地促使訓練模型的引數聚集,在 1-shot 和 5-shot 實驗中都能獲得更優於傳統訓練方法、一階 MAML 方法(FOMAML)和原始 Reptile 方法的效果。

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

圖 4.4:使用了引數聚集正則化的 Reptile 演算法(紅色橢圓即為引數聚集相關的正則項)。圖源:[7]

如何基於元學習方法進行有效的模型訓練?四篇論文詳細剖析元模型的學習原理和過程

表 4.2:透過在 mini-ImageNet 上的對比實驗,驗證了增加懲罰項 Ri(即表中 W-Clustering 所示)對於模型效果的提升作用。來自:[7]

4.3 小結

本文對於元學習訓練方法在少樣本學習場景中的有效性進行了深入探討,並提出了元學習得到的資料特徵表示是不同於普通訓練方法得到的資料特徵表示的猜想。本文根據這個猜想設計了具有特徵聚集特性權重聚集特性兩種正則項,並分別應用到遷移學習方法和原始元學習方法中,驗證了正則項對於提升模型效果的作用。

參考文獻

[1] Vanschoren J. "Meta-Learning: A Survey". Arxiv:1810.03548, 2018.
[2] Bronskill, John, Jonathan Gordon, James Requeima, Sebastian Nowozin and R. Turner. "TaskNorm: Rethinking Batch Normalization for Meta-Learning". Proceedings of the 37th International Conference on Machine Learning (ICML), 2020.
[3] Chelsea Finn, Pieter Abbeel, and Sergey Levine. "Model-agnostic meta-learning for fast adaptation of deep networks". Proceedings of the 34th International Conference on Machine Learning (ICML), 2017.
[4] Flennerhag, Sebastian, Andrei A. Rusu, Razvan Pascanu, H. Yin and Raia Hadsell. "Meta-Learning with Warped Gradient Descent". ArXiv:1909.00025, 2020.
[5]Flennerhag, Sebastian, Moreno, Pablo G., Lawrence, Neil D., and Damianou, Andreas. Transferring knowledge across learning processes. In International Conference on Learning Representations, 2019.
[5] Yin, Mingzhang, G. Tucker, M. Zhou, S. Levine and Chelsea Finn. "Meta-Learning without Memorization". ArXiv: 1912.03820, 2020.
[6] Goldblum, Micah, S. Reich, Liam Fowl, Renkun Ni, V. Cherepanova and T. Goldstein. "Unraveling Meta-Learning: Understanding Feature Representations for Few-Shot Tasks" Proceedings of the 37th International Conference on Machine Learning (ICML), 2020.
[7] Snell, J., Swersky, K., and Zemel, R. Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems, pp. 4077–4087, 2017.
[8] Bertinetto, L., Henriques, J. F., Torr, P. H., and Vedaldi, A. Meta-learning with differentiable closed-form solvers. arXiv preprint arXiv:1805.08136, 2018.
[9] Lee, K., Maji, S., Ravichandran, A., and Soatto, S. Metalearning with differentiable convex optimization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 10657–10665, 2019.
[10] Nichol, A. and Schulman, J. Reptile: a scalable metalearning algorithm. arXiv preprint arXiv:1803.02999, 2:2, 2018.

分析師介紹:

楊旭韻,工程碩士,主要研究方向是強化學習模仿學習以及元學習。現從事工業機器人相關的技術研究工作,主要負責機器學習演算法落地應用的工作。

關於機器之心全球分析師網路 Synced Global Analyst Network

機器之心全球分析師網路是由機器之心發起的全球性人工智慧專業知識共享網路。在過去的四年裡,已有數百名來自全球各地的 AI 領域專業學生學者、工程專家、業務專家,利用自己的學業工作之餘的閒暇時間,透過線上分享、專欄解讀、知識庫構建、報告發布、評測及專案諮詢等形式與全球 AI 社群共享自己的研究思路、工程經驗及行業洞察等專業知識,並從中獲得了自身的能力成長、經驗積累及職業發展。

相關文章