線性擴散模型LiT來了,用極簡線性注意力助力擴散模型AIPC時代端側部署
机器之心發表於2025-01-31
AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
王家豪,香港大學計算機系二年級博士,導師為羅平教授,研究方向為神經網路輕量化。碩士畢業於清華大學自動化系,已在 NeurIPS、CVPR 等頂級會議上發表了數篇論文。 太長不看版:香港大學聯合上海人工智慧實驗室,華為諾亞方舟實驗室提出高效擴散模型 LiT:探索了擴散模型中極簡線性注意力的架構設計和訓練策略。LiT-0.6B 可以在斷網狀態,離線部署在 Windows 膝上型電腦上,遵循使用者指令快速生成 1K 解析度逼真圖片。 圖 1:LiT 在 Windows 膝上型電腦的離線端側部署:LiT 可以在端側,斷網狀態,以完全離線的方式遵循使用者指令,快速生成 1K 解析度圖片- 論文名稱:LiT: Delving into a Simplified Linear Diffusion Transformer for Image Generation
- 論文地址:https://arxiv.org/pdf/2501.12976v1
- 專案主頁:https://techmonsterwang.github.io/LiT/
為了提高擴散模型的計算效率,一些工作使用 Sub-quadratic 計算複雜度的模組來替代二次計算複雜度的自注意力(Self-attention)機制。這其中,線性注意力的主要特點是:1) 簡潔;2) 並行化程度高。這對於大型語言模型、擴散模型這樣的大尺寸、大計算的模型而言很重要。就在幾天前,MiniMax 團隊著名的《MiniMax-01: Scaling Foundation Models with Lightning Attention》已經在大型語言模型中驗證了線性模型的有效性。而在擴散模型中,關於「線性注意力要怎麼樣設計,如何訓練好基於純線性注意力的擴散模型」的討論仍然不多。本文針對這個問題,該團隊提出了幾條「拿來即用」的解決方案,向社群讀者報告了可以如何設計和訓練你的線性擴散 Transformer(linear diffusion Transformers)。列舉如下:- 使用極簡線性注意力機制足夠擴散模型完成影像生成。除此之外,線性注意力還有一個「免費午餐」,即:使用更少的頭(head),可以在增加理論 GMACs 的同時 (給模型更多計算),不增加實際的 GPU 延遲。
- 線性擴散 Transformer 強烈建議從一個預訓練好的 Diffusion Transformer 裡做權重繼承。但是,繼承權重的時候,不要繼承自注意力中的任何權重 (Query, Key, Value, Output 的投影權重)。
- 可以使用知識蒸餾(Knowledge Distillation)加速訓練。但是,在設計 KD 策略時,我們強烈建議不但蒸餾噪聲預測結果,同樣也蒸餾方差預測結果 (這一項權重更小)。
LiT 將上述方案彙總成了 5 條指導原則,方便社群讀者拿來即用。在標準 ImageNet 基準上,LiT 只使用 DiT 20% 和 23% 的訓練迭代數,即可實現相當 FID 結果。LiT 同樣比肩基於 Mamba 和門控線性注意力的擴散模型。在文生圖任務中,LiT-0.6B 可以在斷網狀態,離線部署在 Windows 膝上型電腦上,遵循使用者指令快速生成 1K 解析度逼真圖片,助力 AIPC 時代降臨。 Diffusion Transformer 正在助力文生圖應用的商業化,展示出了極強的商業價值和潛力。但是,自注意力的二次計算複雜度也成為了 Diffusion Transformer 的一個老大難問題。因為這對於高解析度的場景,或者端側裝置的部署都不算友好。常見的 Sub-quadratic 計算複雜度的模組有 Mamba 的狀態空間模型(SSM)、門控線性注意力(GLA)、線性注意力等等。目前也有相關的工作將其用在基於類別的(class-conditional)影像生成領域 (非文生圖),比如使用了 Mamba 的 DiM、使用了 GLA 的 DiG 。但是,雖然這些工作確實實現了 Sub-quadratic 的計算複雜度,但是,這些做法也存在明顯的不足:- 其一,SSM 和 GLA 模組都依賴遞迴的狀態 (State) 變數,需要序列化迭代計算,對於並行化並不友好。
- 其二,SSM 和 GLA 模組的計算圖相對於 線性注意力 而言更加複雜,而且會引入一些算數強度 (arithmetic-intensity) 比較低的操作,比如逐元素乘法。
而線性注意力相比前兩者,如下圖 2 所示,不但設計簡單,而且很容易實現並行化。這樣的特點使得線性注意力對於高解析度極其友好。比如對於 2048px 解析度圖片,線性注意力比自注意力快約 9 倍,對於 DiT-S/2 生成所需要的 GPU 記憶體也可以從約 14GB 降低到 4GB。因此,訓練出一個效能優異的基於線性注意力的擴散模型很有價值。 圖 2:與 SSM 和 GLA 相比,線性注意力同樣實現 sub-quadratic 的計算複雜度,同時設計極其簡潔,且不依賴遞迴的狀態變數但是,對於有挑戰性的影像生成任務,怎麼快速,有效地訓練好基於線性注意力的擴散模型呢?這個問題很重要,因為一方面,儘管線性注意力在視覺識別領域已經被探索很多,可以取代自注意力,但是在影像生成中仍然是一個探索不足的問題。另一方面,從頭開始訓練擴散模型成本高昂。比如訓練 RAPHAEL 需要 60K A100 GPU days ( 中報告)。因此,針對線性擴散 Transformer 的高價效比訓練策略仍然值得探索。LiT 從架構設計和訓練策略中系統地研究了純線性注意力的擴散 Transformer 實現。LiT 是一種使用純線性注意力的 Diffusion Transformer。LiT 訓練時的成本效率很高,同時在推理過程中保持高解析度友好屬性,並且可以在 Windows 11 膝上型電腦上離線部署。在基於類別的 ImageNet 256×256 基準上面,100K 訓練步數的 LiT-S/B/L 在 FID 方面優於 400K 訓練步數的 DiT-S/B/L。對於 ImageNet 256×256 和 512×512,LiT-XL/2 在訓練步驟只有 20% 和 23% 的條件下,實現了與 DiT-XL/2 相當的 FID。在文生圖任務中,LiT-0.6B 可以在斷網狀態,離線部署在 Windows 膝上型電腦上,遵循使用者指令快速生成 1K 解析度逼真圖片。鑑於對生成任務上的線性擴散 Transformer 的探索不多,LiT 先以 DiT 為基礎,構建了一個使用線性注意力的基線模型。基線模型與 DiT 共享相同的宏觀架構,唯一的區別是將自注意力替換為 線性注意力。所有實驗均在基於類別的 ImageNet 256×256 基準上進行,使用 256 的 Batch Size 訓練了 400K 迭代次數。Guideline 1:Simplified 線性注意力對於基於 DiT 的影像生成擴散模型完全足夠。我們首先嚐試了在通用視覺基礎模型中成功驗證的常見線性注意力的架構設計,比如 ReLU 線性注意力 (使用 ReLU 啟用函式作為線性注意力的 Kernel Function)。對於效能參考,將其與 DiT 進行比較,其中任何效能差異都可以歸因於線性注意力對生成質量的影響。如圖 4 中所示。與 DiT 相比,使用 ReLU 線性注意力的 LiT-S/2 和 B/2 效能下降很大。結果表明,視覺識別中常用的線性注意力在噪聲預測任務中有改進的空間。- 簡化型線性注意力 (圖 3,相當於在 ReLU 線性注意力的基礎上加上 Depth-wise 卷積)。
- Focused 線性注意力 (使用 GELU 替換 ReLU)。
這些選擇中的每一個都保持了線性複雜度,保持了 LiT 在計算效率方面的優勢。我們使用相對較大的卷積核 (Kernel Size 5) 來確保在預測噪聲時足夠大的感受野。 圖 3:在 Simplified 線性注意力中使用更少的 heads實驗結果如圖 4 所示。加了 DWC 的模組都可以取得大幅的效能提升,我們認為這是因為模型在預測給定畫素的噪聲時關注相鄰畫素資訊。同時,我們發現 Focused Function 的有效性有限,我們將其歸因於其設計動機,以幫助線性注意聚焦於特定區域。此功能可能適合分類模型,但可能不是噪聲預測所必需的。為了簡單起見,最後使用簡化 線性注意力。Guideline 2:線上性注意力中建議使用很少的頭,可以在增加計算的同時不增加時延。直覺上似乎使用更多頭可以減少計算壓力。但相反,我們建議使用更少的頭,因為我們觀察到線性注意力存在 Free Lunch 效應,如圖 5 所示。圖 5 展示了使用線性注意力的 Small,Base,Large,XLarge 模型使用不同頭數量的延遲和 GMACs 變化。 圖 5:線性注意力中的 Free Lunch 效應:不同頭數量線性注意的延遲與理論 GMACs 比較我們使用 NVIDIA A100 GPU 生成 256×256 解析度的影像,批次大小為 8 (NVIDIA V100 GPU 出現類似現象)。結果表明,減小頭數量會導致理論 GMACs 穩定增加,實際延遲卻並沒有呈現出增加的趨勢,甚至出現下降。我們將這種現象總結為線性注意力的「免費午餐(Free Lunch)」效應。我們認為線上性注意力中使用更少的頭之後,允許模型有較高的理論計算,根據 scaling law,允許模型在生成效能上達到更高的上限。實驗結果如圖 6 所示,對於不同的模型尺度,線性注意力中使用更少的頭數 (比如,2,3,4) 優於 DiT 中的預設設定。相反,使用過多的頭(例如,S/2 的 96 或 B/2 的 192),則會嚴重阻礙生成質量。LiT 與 DiT 共享一些相同的結構,允許權重繼承自預訓練的 DiT 架構。這些權重包含豐富的與噪聲預測相關的知識,有望以成本高效的方式轉移到 LiT。因此,在這個部分我們探索把預訓練的 DiT 權重 (FFN 模組、adaLN、位置編碼和 Conditional Embedding 相關的引數) 繼承給線性 DiT,除了線性注意力部分。 圖 6:線性擴散 Transformer 的權重繼承策略Guideline 3:線性擴散 Transformer 的引數應該從一個預訓練到收斂的 DiT 初始化。我們首先預訓練 DiT-S/2 不同的訓練迭代次數:200K、300K、400K、600K 和 800K,並且在每個實驗中,分別將這些預訓練的權重載入到 LiT-S/2 中,同時線性注意力部分的引數保持隨機。然後將初始化的 LiT-S/2 在 ImageNet 上訓練 400K 迭代次數,結果如圖 6 所示。- DiT 的預訓練權重,即使只訓練了 200K 步,也起著重要作用,將 FID 從 63.24 提高到 57.84。
- 使用預訓練權重的指數移動平均 (EMA) 影響很小。
- DiT 訓練更收斂時 (800K 步),更適合作為 LiT 的初始化,即使架構沒有完全對齊。
我們認為這種現象的一種可能解釋是 Diffusion Transformer 中不同模組的功能是解耦的。儘管 DiT 和 LiT 具有不同的架構,但它們的共享元件 (例如 FFN 和 adaLN) 的行為非常相似。因此,可以遷移這些元件預訓練引數中的知識。同時,即使把 DiT 訓練到收斂並遷移共享元件的權重,也不會阻礙線性注意力部分的最佳化。 圖 7:ImageNet 256×256 上的權重繼承消融實驗結果Guideline 4:線性注意力中的 Query、Key、Value 和 Output 投影矩陣引數應該隨機初始化,不要繼承自自注意力。在 LiT 中,線性注意力中的一些權重與 DiT 的自注意力中的權重重疊,包括 Query、Key、Value 和 Output 投影矩陣。儘管計算正規化存在差異,但這些權重可以直接從 DiT 載入到 LiT 中,而不需要從頭訓練。但是,這是否可以加速其收斂性仍然是一個懸而未決的問題。我們使用經過 600K 次迭代預訓練的 DiT-S/2 進行消融實驗。探索了 5 種不同型別的載入策略,包括:- 載入 Query,Key 和 Value 投影矩陣。
結果如圖 7 所示。與沒有載入自注意力權重的基線相比,沒有一個探索的策略顯示出更好的生成效能。這種現象可歸因於計算正規化的差異。具體來說,線性注意力直接計算鍵和值矩陣的乘積,但是自注意力就不是這樣的。因此,自注意力中的 Key 和 Value 相關的權重對線性注意力的好處有限。我們建議繼承除線性注意力之外的所有預訓練引數從預訓練好的 DiT 中,因為它易於實現並且非常適合基於 Transformer 架構的擴散模型。 圖 8:混合知識蒸餾訓練線性擴散 TransformerGuideline 5:使用混合知識蒸餾訓練線性擴散 Transformer 很關鍵,不僅蒸餾噪聲預測結果,還蒸餾方差的預測結果。知識蒸餾通常採用教師網路來幫助訓練輕量級學生網路。對於擴散模型,蒸餾通常側重於減少目標模型的取樣步驟。相比之下,我們專注於在保持取樣步驟的前提下,從複雜的模型蒸餾出更簡單的模型。 圖 9:ImageNet 256×256 上的知識蒸餾實驗結果,帶有下劃線的結果表示不使用知識蒸餾到目前為止,LiT 遵循 DiT 的宏觀 / 微觀設計,但採用了高效的線性注意力。使用我們的訓練策略,LiT-S/2 顯著地提高了 FID。接下來,我們在更大的變體 (例如 B/L/XL) 和具有挑戰性的任務 (比如 T2I) 上驗證它。我們首先在 ImageNet 256×256 基準上驗證 LiT。LiT-S/2、B/2、L/2、XL/2 配置與 DiT 一致,只是線性注意力的頭分別設定為 2/3/4/4。對於所有模型變體,DWC Kernel Size 都設定為 5。我們以 256 的 Batch Size 訓練 400K 步。對於 LiT-XL/2,將訓練步數擴充套件到 1.4M 步 (只有 DiT-XL/2 7M 的 20%)。我們使用預訓練的 DiT 初始化 LiT 的引數。Lambda_1 和 lambda_2 在混合知識蒸餾中設定為 0.5 和 0.05。圖 10 和 11 比較了 LiT 和 DiT 的不同尺寸模型的結果。值得注意的是,僅 100K 訓練迭代次數訓練的 LiT 已經在各種評估指標和不同尺寸的模型中優於 400K 訓練迭代次數訓練的 DiT。使用 400K 訓練迭代次數的額外訓練,模型的效能繼續提高。儘管訓練步驟只有 DiT-XL/2 的 20%,但 LiT-XL/2 仍然取得與 DiT 相當的 FID 結果 (2.32 對 2.27)。此外,LiT 與基於 U-Net 的基線效能相當。這些結果表明,當線性注意力結合合適的最佳化策略時,可以可靠地用於影像生成應用。 圖 10:ImageNet 256×256 基準實驗結果,與基於自注意力的 DiT 和基於門控線性注意力的 DiG 的比較 圖 11:ImageNet 256×256 基準實驗結果我們繼續在 ImageNet 512×512 基準上進一步驗證了 LiT-XL/2。使用預訓練的 DiT-XL/2 作為教師模型,使用其權重初始化 LiT-XL/2。對於知識蒸餾,分別設定 Lambda_1 和 lambda_2 為 1.0 和 0.05,並且只訓練 LiT-XL/2 700K 訓練迭代次數 (是 DiT 3M 訓練迭代次數的 23%)。值得注意的是,與使用 256 的 Batch Size 的 DiT 不同,我們採用 128 的較小 Batch Size。這其實不佔便宜,因為 128 的 Batch Size 相比 256 的情況,完成 1 Epoch 需要 2 倍的訓練迭代次數。也就是說,我們 700K 的訓練迭代次數其實只等效為 256 Batch Size 下的 350K。儘管如此,使用純線性注意力的 LiT 實現了 3.69 的 FID,與 3M 步訓練的 DiT 相當,將訓練步驟減少了約 77%。此外,LiT 優於幾個強大的 Baseline。這些結果證明了我們提出的成本高效的訓練策略在高解析度資料集上的有效性。實驗結果如圖 12 所示。 圖 12:ImageNet 512×512 基準實驗結果文生圖對於擴散模型的商業應用極為重要。LiT 遵循 PixArt-α 的做法,將交叉注意力新增到 LiT-XL/2 中使其支援文字嵌入。LiT 將線性注意力的頭數設定為 2,DWC Kernel Size 設定為 5。遵循 PixArt-Σ 的做法,使用預訓練的 SDXL VAE Encoder 和 T5 編碼器 (即 Flan-T5-XXL) 分別提取影像和文字特徵。LiT 使用 PixArt-Σ 作為教師來監督其訓練,分別設定 Lambda_1 和 lambda_2 為 1.0 和 0.05。LiT 從 PixArt-Σ 繼承權重,除了自注意力的引數。隨後在內部資料集上訓練,學習率為 2e-5,僅訓練 45400 步,明顯低於 PixArt-α 的多階段訓練。圖 13 為 LiT 生成的 512px 影像取樣結果。儘管在每個 Block 中都使用了線性注意力,以及我們的成本高效的訓練策略,LiT 仍然可以產生異常逼真的影像。 圖 13:LiT 根據使用者指令生成的 512px 圖片我們還將解析度進一步增加到 1K。更多的實驗細節請參閱原論文。圖 14 是生成的結果取樣。儘管用廉價的線性注意力替換所有自注意力,但 LiT 仍然能夠以高解析度生成逼真的影像。 圖 14:LiT 根據使用者指令生成的 1K 解析度圖片我們還將 1K 解析度的 LiT-XL/2 模型部署到一臺 Windows 11 作業系統驅動的膝上型電腦上,以驗證其 On-device 的能力。考慮到膝上型電腦的 GPU 記憶體的限制,我們將文字編碼器量化為 8-bit,同時線上性注意力計算期間保持 fp16 精度。圖 1 顯示了我們的部署結果。預訓練的 LiT 可以在離線設定 (沒有網路連線) 的情況下快速生成照片逼真的 1K 解析度影像。這些結果說明 LiT 作為一種 On-device 的擴散模型的成功實現,推進邊緣裝置上的高解析度文生圖任務。下面提供了一個影片 Demo:
展示了在斷網狀態下離線使用 LiT 完成 1K 解析度文生圖任務的過程。