MIT三人團隊:用Transformer解決經驗貝葉斯問題,比經典方法快100倍

机器之心發表於2025-02-28
Transformer 很成功,更一般而言,我們甚至可以將(僅編碼器)Transformer 視為學習可交換資料的通用引擎。由於大多數經典的統計學任務都是基於獨立同分布(iid)採用假設構建的,因此很自然可以嘗試將 Transformer 用於它們。

針對經典統計問題訓練 Transformer 的好處有兩個:
  • 可以得到更好的估計器;

  • 可以在一個有比 NLP 更加容易和更好理解的統計結構的領域中闡釋 Transformer 的工作機制。

近日,MIT 的三位研究者 Anzo Teh、Mark Jabbour 和 Yury Polyanskiy 宣稱找到了一個可以滿足這種需求 「可能存在的最簡單的這類統計任務」,即 empirical Bayes (EB) mean estimation(經驗貝葉斯均值估計)。
image.png
  • 論文標題:Solving Empirical Bayes via Transformers

  • 論文地址:https://arxiv.org/pdf/2502.09844

該團隊表示:「我們認為 Transformer 適用於 EB,因為 EB 估計器會自然表現出收縮效應(即讓均值估計偏向先驗的最近模式),而 Transformer 也是如此,注意力機制會傾向於關注聚類 token。」對注意力機制的相關研究可參閱論文《The emergence of clusters in self-attention dynamics》。

此外,該團隊還發現,EB 均值估計問題具有置換不變性,無需位置編碼。

另一方面,人們非常需要這一問題的估計器,但麻煩的是最好的經典估計器(非引數最大似然 / NPMLE)也存在收斂速度緩慢的問題。

MIT 這個三人團隊的研究表明 Transformer 不僅效能表現勝過 NPMLE,同時還能以其近 100 倍的速度執行!

總之,本文證明了即使對於經典的統計問題,Transformer 也提供了一種優秀的替代方案(在執行時間和效能方面)。對於簡單的 1D 泊松 - EB 任務,本文還發現,即使是引數規模非常小的 Transformer(< 10 萬引數)也能表現出色。

定義 EB 任務

泊松 - EB 任務:透過一個兩步式過程以獨立同分布(iid)方式生成 n 個樣本 X_1, . . . , X_n.

第一步,從某個位於實數域 ℝ 的未知先驗 π 取樣 θ_1, . . . , θ_n。這裡的 π 的作用是作為一個未曾見過的(非引數)隱變數,並且對其不做任何假設(設定沒有連續性和平滑性假設)。

第二步,給定 θ_i,透過 X_i ∼ Poi (θ_i) 以 iid 方式有條件地對 X_i 進行取樣。

這裡的目標是根據看到的 X_1, . . . , X_n,透過image.png估計 θ_1, . . . , θ_n,以最小化期望的均方誤差(MSE)image.png。如果 π 是已知的,則這個最小化該 MSE 的貝葉斯估計器便是 θ 的後驗均值,其形式如下:
image.png
其中 圖片是 x 的後驗密度。由於 π 是未知的,於是估計器 π 只能近似 圖片。這裡該團隊的做法是將估計器的質量量化為後悔值,定義成了圖片多於圖片的 MSE:
image.png
透過 Transformer 求解泊松 - EB

簡單來說,該團隊求解泊松 - EB 的方式如下:首先,生成合成資料並使用這些資料訓練 Transformer;然後,凍結它們的權重並提供要估計的新資料。

該團隊表示,這應該是首個使用神經網路模型來估計經驗貝葉斯的研究工作。

理解 Transformer 是如何工作的

論文第四章試圖解釋 Transformer 是如何工作的,並從兩個角度來實現這一目標。首先,他們建立了關於 Transformer 在解決經驗貝葉斯任務中的表達能力的理論結果。其次,他們使用線性探針來研究 Transformer 的預測機制。

本文從 clipped Robbins 估計器開始,其定義如下:
image.png
得出:transformer 可以學習到任意精度的 clipped Robbins 估計器。即:
image.png
類似地,本文證明了 transformer 還可以近似 NPMLE。即:
image.png
完整的證明過程在附錄 B 中,論文正文只提供了一個大致的概述。

接下來,研究者探討了 Transformer 模型是如何學習的。他們透過線性探針(linear probe)技術來研究 Transformer 學習機制。

這項研究的目的是要了解 Transformer 模型是否像 Robbins 估計或 NPMLE 那樣工作。圖 1 中的結果顯示,Transformer 模型不僅僅是學習這些特徵,而是在學習貝葉斯估計器圖片是什麼。
image.png
總結而言,本章證明了 Transformer 可以近似 Robbins 估計器和 NPMLE(非引數最大似然估計器)。

此外,本文還使用線性探針(linear probes)來證明,經過預訓練的 Transformer 的工作方式與上述兩種估計器不同。

合成資料實驗與真實資料實驗

表 1 為模型引數設定,本文選取了兩個模型,並根據層數將它們命名為 T18 和 T24,兩個模型都大約有 25.6k 個引數。此外,本文還定義了 T18r 和 T24r 兩個模型。
image.png
在這個實驗中,本文評估了 Transformer 適應不同序列長度的能力。圖 2 報告了 4096 個先驗的平均後悔值。
image.png
圖 6 顯示 transformer 的執行時間與 ERM 的執行時間相當。
image.png
合成實驗的一個重要意義在於,Transformer 展示了長度泛化能力:即使在未見過的先驗分佈上,當測試序列長度達到訓練長度的 4 倍時,它們仍能實現更低的後悔值。這一點尤為重要,因為多項研究表明 Transformer 在長度泛化方面的表現參差不齊 [ZAC+24, WJW+24, KPNR+24, AWA+22]。

最後,本文還在真實資料集上對這些 Transformer 模型進行了評估,以完成類似的預測任務,結果表明它們通常優於經典基線方法,並且在速度方面大幅領先。
image.png
從表 3 可以看出,在大多數資料集中,Transformer 比傳統方法有顯著的改進。
image.png
總之,本文證明了 Transformer 能夠透過上下文學習(in-context learning)掌握 EB - 泊松問題。實驗過程中,作者展示了隨著序列長度的增加,Transformer 能夠實現後悔值的下降。在真實資料集上,本文證明了這些預訓練的 Transformer 在大多數情況下能夠超越經典基線方法。

相關文章