COLING24|自適應剪枝讓多模態大模型加速2-3倍,哈工大等推出SmartTrim

机器之心發表於2024-03-18

基於 Transformer 結構的視覺語言大模型(VLM)在各種下游的視覺語言任務上取得了巨大成功,但由於其較長的輸入序列和較多的引數,導致其相應的計算開銷地提升,阻礙了在實際環境中進一步部署。為了追求更為高效的推理速度,前人提出了一些針對 VLM 的加速方法,包括剪枝和蒸餾等,但是現有的這些方法大都採用靜態架構,其針對不同輸入例項採用同樣的計算圖進行推理,忽略了不同例項之間具有不同計算複雜性的事實:針對複雜的跨模態互動例項,自然需要更多計算才能完全理解影像和相關問題的複雜細節;相反,簡單的例項則可以用更少的計算量解決。這也導致較高加速比下的 VLM 的效能嚴重下降。

為了解決上述這些問題,哈工大聯合度小滿推出針對多模態模型的自適應剪枝演算法 SmartTrim,論文已被自然語言處理頂級會議 COLING 24 接收。

圖片

前期探究和研究動機

本文首先針對 VLM 中每一層的 token 表示和 attention head 的冗餘情況進行分析,如下圖所示。我們有了以下發現:(1)無論是哪種模態的 token 或者 head,層內相似性始終很高,說明模型是存在顯著冗餘。(2)Token 的冗餘度隨著深度而逐漸增加。(3)不同例項之間的冗餘程度差異較大,進一步說明依賴於輸入的自適應剪枝對於 VLM 加速的重要性。

圖片

在基於 VQA 微調的 METER 的跨模態編碼器中,層內不同 token(上)和 attention head(下)表示的相似性。

方法介紹

基於上述發現,本文提出針對 VLM 的自適應剪枝框架:SmartTrim,從 token 和 attention head 兩方面同時對模型冗餘部分進行剪枝

圖片

SmartTrim 框架結構圖

跨模態感知的 Token 修剪器:

文字和影像各自的 Token 序列首先經過各自編碼器進行編碼,對於得到的序列表示,經過基於 MLP 結構的跨模態感知 Token 修剪器識別對於當前層不重要的 Token:在識別過程中模型不僅考慮 token 在當前模態序列的重要性,同時還要引入其在跨模態互動中的重要性。最終 token 的重要性分數轉化成一個 0/1 的二值 mask 用來去除冗餘 token。

模態自適應的注意力頭修剪器:

VLM 分別透過 MSA(multi-head self-attention module) 和 MCA (multi-head cross-attention module)捕獲模態內和模態間互動。正如前文分析,注意力部分計算開銷根據輸入的複雜性而變化,導致注意力模組出現的冗餘會產生較大的開銷。為此,我們將模態自適應注意力頭修剪器整合到注意力模組中。該修剪器用以衡量各個注意力頭的顯著性,根據此對冗餘的注意力頭做修剪。

模型訓練

在模型的訓練過程中,我們在最佳化任務相關的訓練目標的同時,還引入了計算開銷相關的訓練目標圖片,讓模型在訓練過程中對效能和效率進行權衡。針對上述修剪器生成的二值 mask(M)在訓練中不可導的問題,我們採用了基於重引數化的技巧從而進行端到端的訓練:

圖片

自蒸餾與課程訓練策略:

我們還引入一種自蒸餾的訓練策略來提高透過自適應剪枝得到的小模型:透過對齊剪枝後的小模型和全容量模型之間輸出,使得剪枝模型的輸出與全容量模型更為一致,進一步提高小模型的能力。另外我們利用課程學習的訓練方式指導模型的訓練,使模型稀疏度逐步減低到目標比例,從而保證了最佳化過程的穩定性。

最終的模型訓練目標為:

圖片

實驗結果

我們基於 METER 和 BLIP 這兩個 VLM 作為原始模型並在一系列下游 VL 任務上評估 SmartTrim 以及其他方法的效能和效率,如下表所示:我們的方法將原始模型加速了 2-3 倍,同時效能下降最小。

圖片

具有不同加速比下的 VLM 加速方法結果。

與前人方法相比,SmartTrim 不需要額外的預訓練,而且還透過 token 和 head 兩個方面提供了更細粒度地控制模型的計算開銷,以更好地探索效率與效能之間的權衡,下面的帕累託圖顯示我們的方法在 1.5x 的加速比下甚至相比原始模型效能有所提升,而在高加速比下的相比其他加速方法具有顯著優勢。

圖片

不同 VLM 加速方法在 NLVR2 上的效率與效能權衡的帕累託前沿。

我們進一步展示了一些隨著深度增加 SmartTrim 逐步裁剪不同模態的冗餘 token 的例子:

圖片

Token 的逐步裁剪修剪過程。

上圖 (a)-(c) 是由我們提出的跨模態感知 Token 修剪器獲得的,可以看到針對不同的問題我們的修剪器網路可以合適地選擇更為相關的 patch。(d) 為去掉跨模態資訊指導的基線模型地輸出,我們也可以觀察到其只保留了圖片的主體部分但與問題並不相關的 patch token,並最終產生錯誤的答案。

我們還統計了在 vqa 資料的測試集上我們的 SmartTrim 為不同例項分配的計算量情況,如下圖所示。可以發現 SmartTrim 可以自適應地根據跨模態互動的複雜性分配不同的計算開銷,為簡單例項(圖左)分配更少的計算,為困難例項(圖右)分配更多計算。

圖片

VQA 上 SmartTrim 的 FLOPs 直方圖。

更多詳細內容可以參考論文原文。論文提出的方法未來將結合到度小滿軒轅大模型中,大模型專案地址:https://github.com/Duxiaoman-DI/XuanYuan,歡迎大家訪問!

相關文章