偷懶才能更好地工作。
Llama 3.1 剛剛釋出,你是否已經嘗試了呢?就算你的個人計算機是最近的頂尖配置,執行其中最小的 8B 版本可能也依然會有明顯延遲。為了提升模型的推理效率,研究者想出了多種多樣的方法,但其中很多都會讓模型犧牲一些準確度。
近日,蘋果和 Meta AI 的一個研究團隊提出了一種新方法,可在保證準確度不明顯下降的同時,將 Llama 2 預填充階段的推理速度提升到原來的 2 倍以上,這或許能為 Llama 3.1 的加速提供一些啟發。他們把這種方法稱為 LazyLLM,即懶惰大型語言模型。
論文標題:LazyLLM: Dynamic Token Pruning for Efficient Long Context LLM Inference
論文地址:https://arxiv.org/abs/2407.14057
那麼他們是怎麼讓 LLM 偷懶的呢?要理解他們的方法,我們首先需要知道標準的基於 prompt 的 LLM 推理過程是怎樣的。簡單來說,該過程分為兩個階段:預填充和解碼,如圖 1 所示。
在預填充階段,模型計算和儲存 prompt 中每個 token 的 KV 快取,並預測首個 token。我們將預填充階段所耗費的時間稱為「首個 token 時間(TTFT)」。
預填充階段之後是解碼階段。在這個階段,模型再次使用快取的 KV 來迭代式地解碼下一個 token,直到滿足停止標準。
在預填充階段,所有 Transformer 層都會使用 prompt 中的所有 token。當 prompt 較長時,TTFT 可能很慢,因為當前最佳的基於 Transformer 的 LLM 既深又寬,並且計算注意力的成本會隨 prompt 中 token 數量而呈二次增長。舉個例子,Llama 2(7B 版本)堆疊了 32 層 Transformer,模型維度為 4096。在這種情況下,TTFT 需要的 walltime 是每個後續解碼步驟的 21 倍,在 LongBench 基準上這些時間大約佔用了總生成時間的 23%。
因此,要讓 LLM 推理高效進行,最佳化 TTFT 是非常關鍵的步驟。
儘管 LLM 推理最佳化方面是一個活躍的研究領域,但很多方法關注的重心都是提升解碼階段的推理速度。研究者很少關注 TTFT 的改進。一些基於壓縮的研究成果可透過減少 LLM 的大小隱式地提升 TTFT。
另一個研究方向是在靜態的 Transformer 架構下實現對 TTFT 的改進。對於這個研究方向,很自然會引出一個問題:在生成首個 token 時,所有 prompt token 都必不可少嗎?
圖 2 給出了在 LongBench 基準上的 LLM 分析結果。
可以看到,對於首個生成的 token,輸入 token 的注意力分數非常稀疏,這說明輸入 prompt 中的許多 token 是多餘的,就算移除也不會影響到下一 token 預測。這一觀察正是該團隊提出 LazyLLM 的基礎。
LazyLLM 的優勢包括適用範圍廣、無需訓練、效果好。圖 3 對比了標準 LLM 與 LazyLLM。
LazyLLM
圖 4 展示了 LazyLLM 的整體框架。
從完整上下文開始,LazyLLM 會逐漸對 token 進行剪枝,從而逐漸減少得到最終模型所使用的計算數量。請注意,LazyLLM 允許模型在不同的生成步驟選取不同的 token 子集,即便它們中的一些可能在之前的步驟中被剪枝了。相比於靜態剪枝(一次性對所有 token 進行剪枝),動態剪枝會在每個生成步驟對下一 token 預測進行最佳化,這有助於維持模型的效能表現。
漸進式 token 剪枝
之前也有一些研究成功使用過 token 剪枝來最佳化 LLM 推理。但是,這些方法需要積累預測前幾個 token 的完整注意力圖,以便在剪枝開始之前分析 prompt token 的重要性。也因此,它們不適合用於降低 TTFT,因為它們在預填充階段仍需要計算所有 KV 快取。
相較之下,LazyLLM 「很懶」,會從推理的第一輪迭代(預填充步驟)開始,只計算對預測下一 token 重要的 token。
在第一輪迭代中,一大關鍵難題是確定各個 token 的重要性。受之前已有研究(其中表明 token 隱藏狀態會在穿過 Transformer 層時發生演進)的啟發,該團隊的解決方案是在每個生成步驟使用逐層 token 剪枝。具體來說,他們是使用各層的注意力圖來確定輸入 token 對將要預測的 token 的重要性。
在計算了 token 的置信度分數之後,另一個難題是確定剪枝 token 的閾值。
具體來說,對於不同的層和不同的任務,該閾值可能會隨注意力分數的變化而改變。該團隊的解決思路是使用 top-k 百分位數選取策略。具體來說,如果一個 token 的置信度分數小於輸入 token 中的第 k 個百分位數,便將其剪枝掉。一旦 token 被剪枝去掉了,它就不再參與所有後續層的計算。
也就是說,後續層使用的 token 是之前層所使用 token 的子集。
後面的實驗表明,剪枝層的位置和剪枝的 token 數量不同時,也會導致效能發生變化。具體來說,對於同一 Transformer 層,隨著被剪枝去掉的 token 越來越多,模型的效能也會逐漸下降。
他們還發現,相比於早期層的剪枝,在後期層執行剪枝時會得到更好的效能,這說明後期層對 token 剪枝的敏感度更低。為了更好地平衡速度與準確度,該團隊使用瞭如圖 4 所示的漸進式剪枝法,從而在早期層保留更多 token,然後在 token 流向後期層的過程中逐漸減少 token 的數量。
Aux Cache(輔助快取)
預填充階段沒有 KV 快取,每個 token 都表示成隱藏狀態。因此,可透過移除已被剪枝 token 的隱藏狀態來實現漸進式 token 剪枝。但是,要將漸進式 token 剪枝擴充套件到後續的解碼步驟,卻並不簡單。原因是每個解碼步驟都會使用預填充階段計算的 KV 快取來計算注意力。由於 LazyLLM 是在預填充階段執行漸進式 token 剪枝,因此在某一層被剪枝的 token 的 KV 不會出現在下一層的 KV 快取中。
這裡提醒一下,LazyLLM 框架允許在每一步讓每個生成步驟從完整的輸入 token 序列中挑選一個不同的 token 子集,無論它們是否已在之前的步驟中被剪枝。舉個例子,在接下來的解碼步驟中,那些在 KV 快取中不存在的已被剪枝的 token 可能會被重新選取出來用於計算注意力。在這種情況下,模型無法檢索到這些 token 的 KV 快取。
對此,一個基於直覺的解決方案是再讓這些 token 透過該 Transformer 的起點。但是,這會導致對同一 token 的重複計算,並最終減慢整體的生成速度。
為解決這個難題,該團隊在原有的 KV 快取之外引入了另一種快取:Aux Cache(輔助快取)。
如果已被剪枝 token(如圖 4 中 T4 和 T7)的 KV 並未出現在後續層的 KV 快取中,則會由 Aux Cache 儲存它們的隱藏狀態以供後續迭代檢索。
如圖 4 所示,在每個解碼步驟,每個 Transformer 層首先會檢索過去 token 的 KV 快取(如果存在的話)。對於那些不在 KV 快取中的 token,則直接從其前一層的 Aux Cache 中檢索它們的隱藏狀態,而不必再次經過之前的層。Aux Cache 可確保每個 token 在每個 Transformer 層中最多被計算一次,還能確保 LazyLLM 最慢時也比標準 LLM 快。
實驗
該團隊在兩個大型語言模型上檢驗了這種「懶惰」新方法:Llama 2 7B 和 XGen 7B。作為對比的標準 LLM 是同樣的公開發布的預訓練檢查點模型,同時不進行任何附加訓練。
實驗基準是 LongBench,這是一個針對長內容理解的多工基準。LongBench 基準包含 16 個資料集,涉及 6 個任務,包括單文件問答、多文件問答、總結、少樣本學習、合成任務和程式碼補全。
評估指標是每種方法在 TTFT 加速與準確度權衡方面的效果和效率。
結果
表 1 給出了 LazyLLM、標準 LLM 和其它基線方法的 TTFT 加速和準確度結果。
在此表中,baseline 是指標準 LLM 推理。random token drop 是指對 token 執行隨機剪枝。static token pruning 是指在預填充階段基於前面幾個 Transformer 層的注意力方法來對輸入 token 執行一次性剪枝。Prompt Compression 就是 prompt 壓縮方法,也就是使用 LLM 去除輸入上下文中的冗餘。
從表 1 可以看到,LazyLLM 在 TTFT 加速方面全面優勝,同時準確度方面的下降基本可以忽略不計。需要指出,使用 LLM 來壓縮 prompt 需要大量計算。因此,即使 Prompt Compression 能讓推理速度更快,但其實際的 TTFT 卻比標準 LLM 還長。
對總體生成速度的影響
為了評估新方法對總體生成速度的影響,該團隊分析了計算使用的 prompt token 百分比和生成加速情況,見表 2。
可以看到,LazyLLM 計算使用的 token 的佔比總是低於 100%,這說明 LazyLLM 在生成結束時也沒有用完 prompt 中的所有 token,但理論上講該模型可以使用所有 token。這能為不同任務的整體生成過程提供額外的加速。
不同層的丟棄率
該團隊也分析了剪枝層的位置和被剪枝 token 的數量的影響。結果見圖 6。
可以看到,當在同一 Transformer 層進行剪枝時,留下的 token 越少,模型的效能越差。這也符合我們的直觀認知。此外,相比於在更前期 Transformer 層執行剪枝,在後期層進行剪枝會得到更好的效能,這說明後期層對 token 剪枝的敏感度更低。
基於這些觀察,可以說漸進式 token 剪枝的效果得到了證明。
漸進式 KV 增長
最後,該團隊也嘗試了理解使用 token 剪枝邏輯的模型的內部情況。具體來說,他們想要了解 prompt token 中的累積使用比例以及相應的不被使用的比例。這種「累積 token 使用量」可以等價地定義成每一步的 KV 快取 大小。圖 7 給出了 LazyLLM 的每個階段這些累積的 prompt token 使用量。
該結果支援這一假設:許多 token 永遠不會被模型選擇(即便理論上講模型可以使用 prompt 中的所有 token。
考慮到模型依然能維持執行任務的準確度,因此可以得出結論:模型可以有效地丟棄不影響輸出質量的 token。