Facebook AI 研究院近日聯合 KU Leuven 提出了一種由著名神經科學定律——赫泊規則啟發的線上學習演算法。研究表明,這種方法可以使模型根據當前任務保留過去任務的重要特徵,靈活地適應新環境;並且可以無監督地應用於任何預訓練模型,而不受基於損失函式方法的限制。
論文:Memory Aware Synapses: Learning what (not) to forget
論文地址:https://arxiv.org/abs/1711.09601
人類可以持續不斷地學習,陳舊且不常用的知識會被新資訊覆蓋,但重要且常用的知識不會被隨意擦除。目前在人工學習系統中,終生學習(lifelong learning,LLL)主要關注在任務中積累知識和克服災難性忘卻問題(catastrophic forgetting)。在這篇論文中,我們指出,給定有限的模型容量和無限的將要學習的新資訊的時候,需要選擇對知識進行保留還是擦除。由突觸可塑性所啟發,我們提出了一種線上學習方法,基於網路對資料的啟用頻率,以無監督的方式計算神經網路引數的「重要性」。在學習了一個任務之後,每當有樣本饋送到網路中,就會基於預測輸出對引數變化的敏感度,測量網路的每個引數的重要性。當學習一個新任務的時候,會對重要引數的改變進行懲罰(即阻礙該變化)。我們證明了我們的方法的一個局域版本正好是赫泊規則(Hebb's rule)在識別神經元之間的重要連線的直接應用。我們在一系列的目標識別任務和持續學習向量的挑戰性問題上測試了我們的方法,取得了當前最佳的結果,展示了根據需求調整引數的重要性的能力。
圖 1. 研究人員提出的持續學習模式。
正如大多數終生學習論文所述,任務是按照序列學習的。在這裡我們假設,在任務學習之間,智慧體是被啟用且持續學習的。在這樣的過程中它會看到此前任務中未標記的樣本。這種資訊可以用來更新模型引數中一些重要的權重。頻繁出現的類有更大的貢獻。這樣,智慧體就可以明白哪些類別是重要的,不能被遺忘。作為結果,這些類知識在學習新任務時不會被抹去。
新研究的主要貢獻可以總結為:
- 首先,這是一種新的 LLL 方法——Memory Aware Synapses(MAS)。它基於函式逼近而不是損失函式最佳化,當學習重要性的權重的時候不需要使用標籤。從而該方法可以應用於無標籤資料,例如真實的測試環境。
- 其次,我們證明了我們的 LLL 方法和赫泊學習規律的聯絡,可以視其為我們方法的局域版本。
- 最後,我們在目標識別和事實學習(例如,<主, 謂, 賓>三元組,使用向量而不是 softmax 輸出)任務中都達到了當前最佳效能。
圖 2. 和基於損失函式最佳化的方法不同,我們的方法基於輸入-輸出的函式對引數的敏感度(梯度)。(a)在訓練第一個任務的同時,(基於損失的方法)測量損失函式對引數變化的敏感度以表示引數重要性。(b)相對的,我們在訓練完成之後,使用無標記資料計算輸出函式對引數變化的敏感度,測量引數的重要性。(c)當學習一個新任務的時候,對重要引數的改變進行懲罰。
目標識別
表 1. 目標識別的分類準確率(%)。重要性的權重Ω_ij 是在訓練資料上計算的。加粗的資料表示當前最佳。
表 2. 目標識別的分類準確率(%)。使用訓練資料和測試資料(無標籤)計算重要性的權重Ω_ij 的結果對比。
兩個任務的實驗
我們隨機地將事實分成兩部分以作為資料的兩個批次,B_1 和 B_2,並將任務設定為從 B_1 到 B_2 的遷移。
表 3. 在由 6DS 資料集隨機分成的兩個任務場景中進行事實學習的平均準確率。
表 4. 對測試條件的適應能力。分別在 B_11 和 B_12(由 B_1 分成的兩個子集)上學習重要性的權重。在由 6DS 資料集隨機分成的兩個任務場景中進行事實學習的平均準確率。
更長的任務序列
表 5. 在由 6DS 資料集分成的 4 個不相交任務場景中進行事實學習的平均準確率。
適應性測試
圖 4. 每完成 4 個任務序列中的一個之後,測試對 6DS 資料集的(關於體育運動的)子集的平均準確率。
其中 g-MAS(粉色線)學習到該子集是重要的,需要保留,並顯著地防止了對該子集的忘卻。聯合訓練方法(Joint Training,黑色虛線)作為參考,但實際上它違反了 LLL 的設定,因為它是同時訓練所有的資料。