編輯 | 蘿蔔皮
通常,矩陣乘法 (MatMul) 在大型語言模型(LLM)總體計算成本中佔據主導地位。隨著 LLM 擴充套件到更大的嵌入維度和上下文長度,這方面的成本只會增加。
加州大學、LuxiTech 和蘇州大學的研究人員聲稱開發出一種新方法,透過消除過程中的矩陣乘法來更有效地執行人工智慧語言模型。這從根本上重新設計了目前由 GPU 晶片加速的神經網路操作方式。
研究人員描述瞭如何在不使用 MatMul 的情況下建立一個自定義的 27 億引數模型,效能與當前最先進的 Transformer 模型相當。
該研究以「Scalable MatMul-free Language Modeling」為題,於 2024 年 6 月 4 日釋出在 arXiv 預印平臺。
矩陣乘法是當今大多數神經網路計算任務的核心,而 GPU 特別擅長快速執行數學運算,因為它們可以並行執行大量乘法運算。
這種能力甚至讓 Nvidia 在兩週前短暫地成為了全球最有價值的公司;該公司目前佔據資料中心 GPU 市場約 98% 的份額,這些 GPU 通常用於為 ChatGPT 和 Google Gemini 等 AI 系統提供支援。
圖示:370M 中無 MatMul 的 Transformer++ 和新方法的訓練步驟損失。(來源:論文)
在最新的研究中,加州大學、LuxiTech 和蘇州大學的研究人員展示了 LLM 中可以完全消除 MatMul 操作,同時在十億引數規模下保持強勁效能。
他們透過在密集層中使用加性運算和逐元素 Hadamard 積來實現類似自注意的功能,開發了第一個可擴充套件的無 MatMul 語言模型 (Matmul-free LM)。
具體而言,研究人員利用三元權重消除了密集層中的 MatMul,類似於 BNN。為了從自注意力中移除 MatMul,研究人員最佳化了門控迴圈單元 (GRU),使其僅依賴於元素級乘積。
為了評估他們的方法,研究人員將他們的 MatMul-free LM 與複製的 Llama-2 樣式模型(他們稱之為「Transformer++」)進行了比較,涉及三種模型大小:3.7 億、13 億和 27 億引數。所有模型均在 SlimPajama 資料集上進行了預訓練,其中較大的模型分別在 1000 億個標記上進行了訓練。
不含 MatMul 的 LM 在多個基準任務上與 Llama 2 基線相比取得了具有競爭力的效能,包括回答問題、常識推理和物理理解。
實驗表明,該團隊提出的無 MatMul 模型的效能與最先進的 Transformer 模型相當,後者在推理過程中需要更多記憶體。
為了量化輕量級模型的硬體優勢,除了定製的 FPGA 加速器外,研究人員還提供了最佳化的 GPU 實現。透過在三元密集層的 GPU 實現中使用融合核心,與 GPU 上未最佳化的基線相比,訓練速度加快了 25.6%,記憶體消耗減少了高達 61.0%。
此外,透過採用低位最佳化的 CUDA 核心,當模型擴充套件到 13B 引數時,推理速度提高了 4.57 倍,記憶體使用量減少了 10 倍。
為了正確量化該架構的效率,研究人員在 FPGA 上構建了一個自定義硬體解決方案,該解決方案利用了 GPU 無法處理的輕量級操作。
研究人員演示瞭如何在 GPU 上以每秒 23.8 個 token 的速度執行 13 億個引數的模型;該方法以 13 瓦的功耗(不計算 GPU 的功耗)處理了十億引數規模的模型,超出了人類可讀的吞吐量,使 LLM 更接近類似大腦的效率。
這項工作不僅展示了 LLM 在保持有效執行的情況下可以被剝離到何種程度,而且還指出了未來加速器在處理下一代輕量級 LLM 時應該最佳化的操作型別。
不過需要明確的是,擁有 27 億個引數的 Llama-2 模型與目前市場上最好的 LLM(例如 GPT-4)相差甚遠,據估計 GPT-4 總共擁有超過 1 萬億個引數。因此,這裡還沒有在這裡討論 ChatGPT 級別的處理能力。
引數數量通常意味著模型的複雜性(以及大致上的能力)更高,研究人員一直在尋找用更少的引數實現更高階別 LLM 效能的方法。
研究人員表示,他們在實驗中觀察到的縮放規律表明,無 MatMul 的 LM 在非常大規模下的表現也可能優於傳統 LLM。
研究人員預測,他們的方法在理論上可以與標準 LLM 相媲美,並且超越其在 10²³ FLOPS 左右的規模上的效能,這大致相當於 Meta 的 Llama-3 8B 或 Llama-2 70B 等模型所需的訓練計算量。
然而,該團隊也指出他們的工作有侷限性。由於計算限制,無 MatMul 的 LM 尚未在超大規模模型(例如 1000 億多個引數)上進行測試。他們呼籲擁有更多資源的機構投資擴大規模並進一步開發這種輕量級的語言建模方法。
論文連結:https://arxiv.org/abs/2406.02528
相關報導:https://arstechnica.com/information-technology/2024/06/researchers-upend-ai-status-quo-by-eliminating-matrix-multiplication-in-llms/