面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

AIBigbull2050發表於2020-01-14
2020-01-14 12:15:31
面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

作者 | 十、年

編輯 |  Camel

人腦顯然是人工智慧追求的最高標準。

畢竟人腦使得人類擁有了連續學習的能力以及情境依賴學習的能力。

這種可以在新的環境中不斷吸收新的知識和根據不同的環境靈活調整自己的行為的能力,也正是深度學習系統與人腦相差甚遠的重要原因。

想讓傳統深度學習系統獲得連續學習能力,最重要的是克服人工神經網路會出現的“災難性遺忘”問題,即一旦使用新的資料集去訓練已有的模型,該模型將會失去對原資料集識別的能力。

換句話說就是:讓神經網路在學習新知識的同時保留舊知識。

前段時間,來自蘇黎世聯邦理工學院以及蘇黎世大學的研究團隊發表了一篇名為《超網路的連續學習》(Continual learning with hypernetworks)的研究。提出了 任務條件化的超網路(基於任務屬性生成目標模型權重的網路)。該方法能夠有效克服災難性的遺忘問題。

面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

具體來說,該方法能夠幫助在針對多個任務訓練網路時,有效處理災難性的遺忘問題。除了在標準持續學習基準測試中獲得最先進的效能外,長期的附加實驗任務序列顯示,任務條件超網路(task-conditioned hypernetworks )表現出非常大的保留先前記憶的能力。

hypernetworks

在蘇黎世聯邦理工學院以及蘇黎世大學的這項工作中,最重要的是對超網路(hypernetworks)的應用,在介紹超網路的連續學習之前,我們先對超網路做一下介紹。

hyperNetwork是一個非常有名的網路, 簡單說就是用一個網路來生成另外一個網路的引數 工作原理是:用一個hypernetwork輸入訓練集資料,然後輸出對應模型的引數,最好的輸出是這些引數能夠使得在測試資料集上取得好的效果。簡單來說hypernetwork其實就是一個meta network。

傳統的做法是用訓練集直接訓練這個模型,但是如果使用hypernetwork則不用訓練,拋棄反向傳播與梯度下降,直接輸出引數,這等價於hypernetwork學會了如何學習影像識別。

面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

論文下載見文末

在《hypernetwork》這篇論文中,作者使用 hyperNetwork 生成 RNN 的權重,發現能為 LSTM 生成非共享權重,並在字元級語言建模、手寫字元生成和神經機器翻譯等序列建模任務上實現最先進的結果。超網路採用一組包含有關權重結構的資訊的輸入,並生成該層的權重,如下圖所示。

面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

超網路生成前饋網路的權重:黑色連線和引數與主網路相關聯,而橙色連線和引數與超網路相關聯。

超網路的連續學習模型

在整個工作中,首先假設輸入的資料{X (1),......X (T)}是可以被儲存的,並能夠使用輸入的資料計算Θ (T −1)。另外,可以將未使用的資料和已經使用過資料進行混合來避免遺忘。假設F(X,Θ)是模型,那麼混合後的資料集為{(X(1),Yˆ (1)),。。。,(X (T−1),Yˆ (T−1)),(X (T),Yˆ (T))},其中其中Yˆ(T)是由模型f(.,Θ (t−1))生成的一組合成目標。

然而儲存資料顯然違背了連續學習的原則,所以在在論文中,作者提出了一種新的元模型fh(e (t)h)做為解決方案,新的解決方案能夠將 關注點從單個的資料輸入輸出轉向引數集{Θ (T)},並實現非儲存的要求。這個元模型稱為任務條件超網路,主要思想是建立任務e (t)和權重Θ的對映關係,能夠降維處理資料集的儲存, 大大節省記憶體。

在《超網路的連續學習》這篇論文中,模型部分主要有3個部分,第一部分是任務條件超網路。首先,超網路會將目標模型引數化,即不是直接學習特定模型的引數,而是學習元模型的引數,從而元模型會輸出超網路的權重,也就是說超網路只是權重生成器。

面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

圖a:正則化後的超網路生成目標網路權重引數;圖b:迭代地使用較小的組塊超網路產生目標網路權重。

然後利用帶有超網路的連續學習輸出正則化。在論文中,作者使用兩步最佳化過程來引入記憶保持型超網路輸出約束。首先,計算∆Θh(∆Θh的計算原則基於最佳化器的選擇,本文中作者使用Adam),即找到能夠最小化損失函式的引數。損失函式表示式如下圖所示:

面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

注:Θ h是模型學習之前的超網路的引數;∆Θ h為外生變數;βoutput是用來控制正則化強度的引數。

然後考慮模型的e (t),它就像Θ h一樣。在演算法的每一個學習步驟中,需要及時更新,並使損失函式最小化。在學習任務之後,儲存最終e (t)並將其新增到集合{e (T)}。

模型的第二部分是用分塊的超網路進行模型壓縮。超網路產生目標神經網路的整個權重集。然而,超網路可以迭代呼叫,在每一步只需分塊填充目標模型中的一部分。這表明允許應用較小的可重複使用的超網路。有趣的是,利用分塊超網路可以在壓縮狀態下解決任務,其中學習引數(超網路的那些)的數量實際上小於目標網路引數的數量。

為了避免在目標網路的各個分割槽之間引入權重共享,作者引入塊嵌入的集合{C} 作為超網路的附加輸入。因此,目標網路引數的全集Θ_trgt=[f h(e,c 1),,,f h(e,C Nc)]是透過在C上迭代而產生的,在這過程中保持e不變。這樣,超網路可以每個塊上產生截然不同的權重。另外,為了簡化訓練過程,作者對所有任務使用一組共享的塊嵌入。

模型的第三部分:上下文無關推理:未知任務標識(context-free inference: unknown task identity)。從輸入資料的角度確定要解決的任務。超網路需要任務嵌入輸入來生成目標模型權重。在某些連續學習的應用中,由於任務標識是明確的,或者可以容易地從上下文線索中推斷,因此可以立即選擇合適的嵌入。在其他情況下,選擇合適的嵌入則不是那麼容易。

作者在論文中討論了連續學習中利用任務條件超網路的兩種不同策略。

策略一:依賴於任務的預測不確定性。神經網路模型在處理分佈外的資料方面越來越可靠。對於分類目標分佈,理想情況下為不可見資料產生平坦的高熵輸出,反之,為分佈內資料產生峰值的低熵響應。這提出了第一種簡單的任務推理方法(HNET+ENT),即給定任務標識未知的輸入模式,選擇預測不確定性最小的任務嵌入,並用輸出分佈熵量化。

策略二:當生成模型可用時,可以透過將當前任務資料與過去合成的資料混合來規避災難性遺忘。除了保護生成模型本身,合成資料還可以保護另一模型。這種策略實際上往往是連續學習中最優的解決方案。受這些成功經驗的啟發,作者探索用回放網路(replay network)來增強深度學習系統。

合成回放(Synthetic replay)是一種強大但並不完美的連續學習機制,因為生成模式容易漂移,錯誤往往會隨著時間的推移而積累和放大。作者在一系列關鍵觀察的基礎上決定:就像目標網路一樣,重放模型可以由超網路指定,並允許使用輸出正則化公式。而不是使用模型自己的回放資料。因此,在這種結合的方法中,合成重放和任務條件元建模同時起作用,避免災難性遺忘。

基準測試

作者使用MNIST、CIFAR10和CIFAR-100公共資料集對論文中的方法進行了評估。評估主要在兩個方面: (1)研究任務條件超網路在三種連續學習環境下的記憶保持能力,(2)研究順序學習任務之間的資訊傳遞。

具體的在評估實驗中,作者根據任務標識是否明確出了三種連續學習場景:CL1,任務標識明確;CL2,任務標識不明確,並不需明確推斷;CL3,任務標識可以明確推斷出來。另外作者在MNIST資料集上構建了一個全連通的網路,其中超參的設定參考了van de Ven & Tolias (2019)論文中的方法。在CIFAR實驗中選擇了ResNet-32作為目標神經網路。

van de Ven & Tolias (2019):

Gido M. van de Ven and Andreas S. Tolias. Three scenarios for continual learning. arXiv preprint arXiv:1904.07 734, 2019.

為了進一步說明論文中的方法,作者考慮了四個連續學習分類問題中的基準測試:非線性迴歸,PermutedMNIST,Split-MNIST,Split CIFAR-10/100。

非線性迴歸的結果如下:

面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

注:圖a:有輸出正則化的任務條件超網路可以很容易地對遞增次數的多項式序列建模,同時能夠達到連續學習的效果。圖b:和多工直接訓練的目標網路找到的解決方案類似。圖c:循序漸進地學習會導致遺忘。

在PermutedMNIST中,作者並對輸入的影像資料的畫素進行隨機排列。發現在CL1中,任務條件超網路在長度為T=10的任務序列中表現最佳。在PermutedMNIST上任務條件超網路的表現非常好,對比來看突觸智慧(Synaptic Intelligence) ,online EWC,以及深度生成回放( deep generative replay)方法有差別,具體來說突觸智慧和DGR+distill會發生退化,online EWC不會達到非常高的精度,如下圖a所示。綜合考慮壓縮比率與任務平均測試集準確性,超網路允許的壓縮模型,即使目標網路的引數數量超過超網路模型的引數數量,精度依然保持恆定,如下圖b所示。

面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

Split-MNIST作為另一個比較流行的連續學習的基準測試,在Split-MNIST中將各個數字有序配對,並形成五個二進位制分類任務,結果發現任務條件超網路整體效能表現最好。另外在split MNIST問題上任務重疊,能夠跨任務傳遞資訊,並發現該演算法收斂到可以產生同時解決舊任務和新任務的目標模型引數的超網路配置。如下圖所示

面向超網路的連續學習:新演算法讓人工智慧不再“災難性遺忘”

圖a:即使在低維度空間下仍然有著高分類效能,同時沒有發生遺忘。圖b:即使最後一個任務佔據著高效能區域,並在遠離嵌入向量的情況下退化情況仍然可接受,其效能仍然較高。

在CIFAR實驗中,作者選擇了ResNet-32作為目標神經網路,在實驗過程中,作者發現運用任務條件超網路基本完全消除了遺忘,另外還會發生前向資訊反饋,這也就是說與從初始條件單獨學習每個任務相比,來自以前任務的知識可以讓網路表現更好。

綜上,在論文中作者提出了一種新的連續學習的神經網路應用模型--任務條件超網路,該方法具有可靈活性和通用性,作為獨立的連續學習方法可以和生成式回放結合使用。該方法能夠實現較長的記憶壽命,並能將資訊傳輸到未來的任務,能夠滿足連續學習的兩個基本特性。

參考文獻:

HYPERNETWORKS:

CONTINUAL LEARNING WITH HYPERNETWORKS

https://mp.weixin.qq.com/s/hZcVRraZUe9xA63CaV54Yg





來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946223/viewspace-2673274/,如需轉載,請註明出處,否則將追究法律責任。

相關文章