單GPU訓練一天,Transformer在100位數字加法上就達能到99%準確率

机器之心發表於2024-06-03

乘法和排序也有效。

自 2017 年被提出以來,Transformer 已成為 AI 大模型的主流架構,一直穩站 C 位。

但所有研究者都不得不承認的是,Transformer 在算數任務中表現非常糟糕,尤其是加法,這一缺陷在很大程度上源於 Transformer 無法跟蹤大範圍數字中每個數字的確切位置。

為了解決這個問題,來自馬里蘭大學、CMU 等機構的研究者向這一問題發起了挑戰,他們透過在每個數字中新增一個嵌入來解決這個問題,該嵌入編碼數字相對於開頭的位置。該研究發現,只用一天時間在單個 GPU 上訓練 20 位數字,就可以達到最新的效能水平,100 位數字加法問題高達 99% 的準確率。

圖片

論文地址:https://arxiv.org/pdf/2405.17399

專案地址:https://github.com/mcleish7/arithmetic

標題:Transformers Can Do Arithmetic with the Right Embeddings

具體而言,研究者建議對資料表示進行一個簡單的修改,就能解決這個缺點。他們提出了 Abacus 嵌入用於編碼每個數字符號 token 範圍內的位置。將 Abacus 嵌入與標準位置嵌入結合使用後,該研究觀察到 Transformer 在算數任務上的準確率有顯著提高,以至於最多隻訓練了 20 位數運算元的模型可以泛化到 120 位數運算元的問題。這一數字代表了 6 倍的 SOTA 泛化因子,而以前的最先進的泛化因子也只有 2.5 倍。據瞭解,這是迄今為止被證明的最長的學習加法序列。

此外,本文還研究了幾種其他方法來改善 transformer 在算術和泛化方面的效能,他們發現結合輸入注入(input injection),即在輸入層和每個解碼器層之間插入跳躍連線,可以在 Abacus 嵌入基線上減少 50% 的泛化誤差。本文還發現,與嵌入結合使用的 looped transformer 架構可以在加法問題上實現幾乎完美的泛化。

本文的貢獻可以總結如下:

  • 本文提出了一種新的位置嵌入,稱為 Abacus 嵌入,以更好地捕獲每個數字的重要性,從而實現近乎完美的分佈內泛化;

  • 研究表明,當將 Abacus 嵌入與輸入注入和 looped transformer 相結合時,效能會進一步提高,分佈外準確率從 92.9% 提高到 99.1%,與單獨使用標準架構的嵌入相比,誤差降低了 87%;

  • 研究者將這些發現擴充套件到更復雜的問題,包括乘法和排序,在這些領域也展現出了長度泛化。

實現加法的長度泛化

作者研究了一系列方法,旨在提高從頭開始訓練的語言模型在算術能力上的表現。他們主要關注兩個假設:1)數字內各個位數的位置資訊正在丟失;2)迴圈可以提高 Transformer 架構在多步算術推理問題上的推理能力。在詳細描述每項改進之前,作者簡要討論了訓練和評估設定。

實驗設定

作者訓練了僅包含解碼器的因果語言模型來解決加法問題。

他們考慮了兩種標準 transformer 架構。首先,他們使用一個標準的自迴歸 transformer 模型,多個解碼器層以前饋方式堆疊。其次,他們透過輸入注入(input injection)增強了這一標準 transformer 模型,即把嵌入的輸入新增到每個解碼器層的輸入中。作者在圖 20 中直觀地描述了這些架構。

圖片

Abacus 嵌入幫助對齊數字

透過之前的研究和初步實驗,作者發現,即使輸入的數字是先顯示最不重要的數字,訓練資料是分層的、豐富的(幾百萬個例子),標準 transformer 也很難學習多位數加法。他們還觀察到,人類在進行長加法運算時,會先將數位相同的數字排列成列。因此,作者的第一個假設是,對於 transformer 來說,每個數字的數位並不容易表示,而且這個子問題比實際加法本身帶來的障礙更大。

為了解決 transformer 在表示位置資訊方面的侷限性,作者設計了一種特殊的位置嵌入,它可以編碼每個數字相對於當前數字起始位置的位置。作者將其稱之為 Abacus 嵌入。他們將相同的位置嵌入應用於所有具有相同數位的數字,從而提供一個顯式的訊號,供模型用於對齊數字,如圖 2 所示。

圖片

Abacus 嵌入解決加法問題

對於標準 transformer 架構,Abacus 嵌入可將泛化效能提高到 100 位及以上。在圖 3(左)中,作者強調了 Abacus 嵌入與標準 transformer 架構和嵌入相比,在進行加法運算時所具有的比較優勢,取三種模型在所有情況下的平均準確度。

圖片

圖 1 還顯示了使用 FIRE 和 Abacus 訓練的標準 transformer 模型的準確度結果,這些模型經過了域內 (ID) 和域外 (OOD) 測試。圖片

Transformer 中的迴圈提高了效能

在解決位置嵌入問題後,接下來作者探討了迴圈架構能否進一步提高 transformer 執行多位數加法的能力。他們使用「迴圈塊(recurrent block)」一詞來指一組具有不同權重的解碼器層,而「迴圈(recurrence)」則指迴圈塊的重複次數。作者使用有效深度(effective depth)一詞來指 transformer 中使用的層數,無論其權重是否唯一。除非另有說明,否則他們使用的是最大迴圈架構,即只迴圈一個唯一層來達到有效深度。他們還採用了輸入注入、 殘差連線的方式,將輸入的副本傳播到網路中的每一層。

迴圈的優勢

在圖 3(右)中,作者比較了使用 FIRE 和 NoPE 嵌入對運算元多達 40 位的加法進行訓練的所有架構變體。儘管引數數量僅相當於其他模型的 1/10,但可以看到,looped transformer(迴圈的、有輸入注入和漸進損失)在使用任何一種位置嵌入時都取得了最佳的分佈外效能。在圖 8 中,作者展示了這一結果在多種訓練資料規模下的穩健性。

圖片

對於迴圈模型,可以選擇在訓練時改變每次前向傳遞的迴圈次數。這往往會提高模型測試時對較難任務的泛化能力,這也被稱為漸進損失計算(progressive loss computation)。這個損失函式是兩個前向傳遞的損失值的凸組合,一個使用字面上的迴圈數(1 × 16 模型為 16),另一個使用隨機的較小迴圈數。

接下來,作者探討了在保持有效深度固定的同時改變迴圈塊大小的效果。他們將迴圈塊中的層數減半,迴圈次數增加一倍,從塊中有 16 層、迴圈次數只有一次(16 × 1,即標準 transformer)的模型,過渡到塊中只有一層、迴圈次數有 16 次(1 × 16)的模型。

透過圖 4 分析這些結果,作者發現在某些情況下,結合迴圈和 Abacus 嵌入可以進一步提高效能。具體來說,在 OOD 問題上,有兩個迴圈的模型(8 × 2)產生的誤差是純非迴圈模型(16 × 1)的一半,而在 100 + 的 OOD 問題上,其準確率也有所提高。

最後,在附錄 A.7.3 中,作者改變了模型的有效深度,以分析引數數量對這項任務的影響,包括 Abacus、FIRE 和 NoPE 嵌入。雖然圖 4 中的實驗是對不同深度的公平比較,但純粹的標準 transformer 模型比相應的迴圈模型擁有更多的引數。在附錄的表 3 中,作者記錄了最接近百萬的引數量。

圖片

圖片

實驗

研究者不僅對加法問題進行了探討,還對乘法和排序進行了研究。

整數乘法

圖 5 展示了 Abacus 嵌入模型在 15 位數乘法的分佈內準確率超過了之前的工作,且不需要用零將每個運算元填充到相同長度。特別地,該研究強調,與僅使用 FIRE 的基線相比,將 Abacus 嵌入與 FIRE 相結合也提高了分佈問題中最難的分佈準確率 (右下)。

圖片

陣列排序

表 1 展示了使用不同嵌入 ——FIRE、Abacus 及其組合 —— 訓練的標準 transformer(八層)的效能。結果顯示,組合嵌入方法增強了模型的泛化能力。

圖片

如表 2 所示,研究者觀察到在將 Abacus+FIRE 嵌入組合與不同的模型架構(有效深度為 8)配對時,結果表現出混合性。

圖片

Abacus 和相關嵌入

圖 6 展示了將 Abacus 嵌入整合到更通用系統中的真正潛力,顯示出 Abacus 嵌入與 FIRE 結合可以解鎖遠超 FIRE 嵌入解決問題的能力。

圖片

更多研究細節,請參考原論文。

相關文章