ICML 2024 | 揭示非線形Transformer在上下文學習中學習和泛化的機制

机器之心發表於2024-06-28
圖片
AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

本文作者李宏康,美國倫斯勒理工大學電氣、計算機與系統工程系在讀博士生,本科畢業於中國科學技術大學。研究方向包括深度學習理論,大語言模型理論,統計機器學習等等。目前已在 ICLR/ICML/Neurips 等 AI 頂會發表多篇論文。

上下文學習 (in-context learning, 簡寫為 ICL) 已經在很多 LLM 有關的應用中展現了強大的能力,但是對其理論的分析仍然比較有限。人們依然試圖理解為什麼基於 Transformer 架構的 LLM 可以展現出 ICL 的能力。

近期,一個來自美國倫斯勒理工大學和 IBM 研究院的團隊從最佳化和泛化理論的角度分析了帶有非線性注意力模組 (attention) 和多層感知機 (MLP) 的 Transformer 的 ICL 能力。他們特別從理論端證明了單層 Transformer 首先在 attention 層根據 query 選擇一些上下文示例,然後在 MLP 層根據標籤嵌入進行預測的 ICL 機制。該文章已收錄在 ICML 2024。

圖片

  • 論文題目:How Do Nonlinear Transformers Learn and Generalize in In-Context Learning?

  • 論文地址:https://arxiv.org/pdf/2402.15607

背景介紹

上下文學習 in context learning (ICL)

上下文學習 (ICL) 是一種新的學習正規化,在大語言模型 (LLM) 中非常流行。它具體是指在測試查詢 (testing query)圖片前新增 N 個測試樣本 testing examples (上下文),即測試輸入圖片和測試輸出圖片的組合,從而構成一個 testing prompt:圖片,作為模型的輸入以引導模型作出正確的推斷。這種方式不同於經典的對預訓練模型進行微調的方式,它不需要改變模型的權重,從而更加的高效。

ICL 理論工作的進展

近期的很多理論工作都是基於 [1] 所提出的研究框架,即人們可以直接使用 prompt 的格式來對 Transformer 進行訓練 (這一步也可以理解為在模擬一種簡化的 LLM 預訓練模式),從而使得模型具有 ICL 能力。已有的理論工作聚焦於模型的表達能力 (expressive power) 的角度 [2]。他們發現,人們能夠找到一個有著 “完美” 的引數的 Transformer 可以透過前向運算執行 ICL,甚至隱含地執行梯度下降等經典機器學習演算法。但是這些工作無法回答為什麼 Transformer 可以被訓練成這樣 “完美” 的,具有 ICL 能力的引數。因此,還有一些工作試圖從 Transformer 的訓練或泛化的角度理解 ICL 機制 [3,4]。不過,受制於分析 Transformer 結構的複雜性,這些工作目前止步於研究線性迴歸任務,而所考慮的模型通常會略去 Transformer 中的非線形部分。

本文從最佳化和泛化理論的角度分析了帶有非線性 attention 和 MLP 的 Transformer 的 ICL 能力和機制:

  • 基於一個簡化的分類模型,本文具體量化了資料的特徵如何影響了一層單頭 Transformer 的域內 (in-domain) 和域外 (out-of-domain, OOD) 的 ICL 泛化能力。

  • 本文進一步闡釋了 ICL 是如何透過被訓練的 Transformer 來實現了。

  • 基於被訓練的 Transformer 的特點,本文還分析了在 ICL 推斷的時候使用基於幅值的模型剪枝 (magnitude-based pruning) 的可行性。

理論部分

問題描述

本文考慮一個二分類問題,即將圖片透過一個任務圖片對映到圖片。為了解決這樣的一個問題,本文構建了 prompt 來進行學習。這裡的 prompt 被表示為:

圖片

訓練網路為一個單層單頭 Transformer:

圖片

預訓練過程是求解一個對所有訓練任務的經驗風險最小化 (empirical risk minimization)。損失函式使用的是適合二分類問題的 Hinge loss,訓練演算法是隨機梯度下降。

本文定義了兩種 ICL 泛化的情況。一個是 in-domain 的,即泛化的時候測試資料的分佈和訓練資料一樣,注意這個情況裡面測試任務不必和訓練任務一樣,即這裡已經考慮了對未見任務 (unseen task) 的泛化。另一個是 out-of-domain 的,即測試、訓練資料分佈不一樣。

本文還涉及了在 ICL 推斷的時候進行 magnitude-based pruning 的分析,這裡的剪枝方式是指對於訓練得到的中的各個神經元,根據其幅值大小,進行從小到大的刪除。

對資料和任務的構建

這一部分請參考原文的 Section 3.2,這裡只做一個概述。本文的理論分析是基於最近比較火熱的 feature learning 路線,即通常將資料假設為可分(通常是正交)的 pattern,從而推匯出基於不同 pattern 的梯度變化。本文首先定義了一組 in-domain-relevant (IDR) pattern 用於決定 in-domain 任務的分類,和一組與任務無關的 in-domain-irrelevant (IDI) pattern,這些 pattern 之間互相正交。IDR pattern 有圖片個,IDI pattern 有圖片個。一個圖片被表示為一個 IDR pattern 和一個 IDI pattern 的和。一個 in-domain 任務就被定義為基於某兩個 IDR pattern 的分類問題。

類似地,本文透過定義 out-of-domain-relevant (ODR) pattern 和 out-of-domain-irrelevant (ODI) pattern,可以刻畫 OOD 泛化時候的資料和任務。

本文對 prompt 的表示可以用下圖的例子來闡述,其中圖片是 IDR pattern,圖片是 IDI pattern。這裡在做的任務是基於 x 中的圖片做分類,如果是圖片那麼其標籤為 + 1,對應於 +q,如果是圖片那麼其標籤為 - 1,對應於 -q。α,α' 分別被定義為訓練和測試 prompt 中跟 query 的 IDR/ODR pattern 一樣的上下文示例。下圖中的例子裡面,圖片

圖片

理論結果

首先,對於 in-domain 的情況,本文先給了一個 condition 3.2 來規定訓練任務需要滿足的條件,即訓練任務需要覆蓋所有的 IDR pattern 和標籤。然後 in-domain 的結果如下:

圖片

這裡表明:1,訓練任務的數量只需要在全部任務中佔比達到滿足 condition 3.2 的小比例,我們就可以對 unseen task 實現很好的泛化;2,跟當前任務相關的 IDR pattern 在 prompt 中的比例越高,就可以以更少的訓練資料,訓練迭代次數,以及更短的 training/testing prompt 實現理想的泛化。

接下來是 out-of-domain 泛化的結果。

圖片

這裡說明,如果 ODR pattern 是 IDR pattern 的線性組合且係數和大於 1,那麼此時 OOD ICL 泛化可以達到理想的效果。這個結果給出了在 ICL 的框架下,好的 OOD 泛化所需要的訓練和測試資料之間的內在聯絡。該定理也透過 GPT-2 的實驗得到了驗證。如下圖所示,當 (12) 中的係數和圖片大於 1 的時候,OOD 分類可以達到理想的結果。與此同時,當圖片,即 prompt 中和分類任務相關的 ODR/IDR pattern 比例越高的時候,所需要的 context 長度越小。

圖片

然後,本文給出了帶有 magnitude-based pruning 的 ICL 泛化結果。

圖片

這個結果表明,首先,訓練得到的圖片中有一部分(常數比例)神經元的幅值很小,而剩下的相對比較大(公式 14)。當我們只枝剪小神經元的時候,對泛化結果基本沒有影響,而當枝剪比例增加到要剪大神經元的時候,泛化誤差會隨之顯著變大(公式 15,16)。以下實驗驗證了定理 3.7。下圖 A 中淺藍色的豎線表示訓練得到的圖片呈現出了公式 14 的結果。而對小神經元進行枝剪不會使泛化變差,這個結果符合理論。圖 B 反映出當 prompt 中和任務相關的上下文越多的時候,我們可以允許更大的枝剪比例以達到相同的泛化效能。

圖片

ICL 機制

透過對預訓練過程的刻畫,本文得到了單層單頭非線性 Transformer 做 ICL 的內在機制,這一部分在原文的 Section 4。該過程可以用下圖表示。

圖片

簡而言之,attention 層會選擇和 query 的 ODR/IDR pattern 一樣的上下文,賦予它們幾乎全部 attention 權重,然後 MLP 層會重點根據 attention 層輸出中的標籤嵌入來作出最後的分類。

總結

本文講解了在 ICL 當中,非線性 Transformer 的訓練機制,以及對於新任務和分佈偏移資料的泛化能力。理論結果對於設計 prompt 選擇演算法和 LLM 剪枝演算法有一定實際意義。

參考文獻

[1] Garg, et al., Neurips 2022. "What can transformers learn in-context? a case study of simple function classes."

[2] Von Oswald et al., ICML 2023. "Transformers learn in-context by gradient descent."

[3] Zhang et al., JMLR 2024. "Trained transformers learn linear models in-context."

[4] Huang et al., ICML 2024. "In-context convergence of transformers."

相關文章