DiT架構大一統:一個框架整合影像、影片、音訊和3D生成,可編輯、能試玩

机器之心發表於2024-05-13

基於 Diffusion Transformer(DiT)又迎來一大力作「Flag-DiT」,這次要將影像、影片、音訊和 3D「一網打盡」。


今年 2 月初,Sora 的釋出讓 AI 社群更加看到了基礎擴散模型的潛力。連同以往出現的 Stable Diffusion、PixArt-α 和 PixArt-Σ,這些模型在生成真實影像和影片方面取得了顯著的成功。這意味著開始了從經典 U-Net 架構到基於 Transformer 的擴散主幹架構的正規化轉變。

值得注意的是,透過這種改進的架構,Sora 和 Stable Diffusion 3 可以生成任意解析度的樣本,並表現出對 scaling 定律的嚴格遵守,即增加引數大小可以實現更好的結果。

不過,推出者們只對自家模型的設計選擇提供有限的指導,並且缺乏詳細的實現說明和公開的預訓練檢查點,限制了它們在社群使用和復刻方面的效用。並且,這些方法是針對特定任務(例如影像或影片生成任務)量身定製的,這阻礙了潛在的跨模態適應性。

為了彌補這些差距,上海 AI Lab、港中文和英偉達的研究者聯合推出了 Lumina-T2X 系列模型,透過基於流(Flow-based)的大型擴散 Transformers(Flag-DiT)打造,旨在將噪聲轉換為影像、影片、多檢視 3D 物件和基於文字描述的音訊。

其中,Lumina-T2X 系列中最大的模型包括具有 70 億引數的 Flag-DiT 和一個多模態大語言模型 SPHINX。SPHINX 是一個文字編碼器,它具有 130 億引數,能夠處理 128K tokens。

圖片

  • 論文地址:https://arxiv.org/pdf/2405.05945
  • GitHub 地址:https://github.com/Alpha-VLLM/Lumina-T2X
  • 模型下載地址:https://huggingface.co/Alpha-VLLM/Lumina-T2I/tree/main
  • 論文標題:Lumina-T2X: Transforming Text into Any Modality, Resolution, and Duration via Flow-based Large Diffusion Transformers

具體來講,基礎的文字到影像模型 Lumina-T2I 利用流匹配框架,在精心整理的高解析度真實影像文字對資料集上進行訓練,只需要使用很少的計算資源就能取得真實感非常不錯的結果。

如圖 1 所示,Lumina-T2I 可以生成任意解析度和寬高比的高質量影像,並進一步實現高階功能,包括解析度外推、高解析度編輯、構圖生成和風格一致生成,所有這些都以免訓練的方式無縫整合到框架中。

圖片

此外,為了增強跨各種模態的生成能力,Lumina-T2X 從頭開始對影片 - 文字、多檢視 - 文字和語音 - 文字對進行獨立訓練,從而可以合成影片、多檢視 3D 物件以及文字語音指示。例如,Lumina-T2V 僅用有限的資源和時間進行訓練,可以生成任何寬高比和時長的 720p 影片,顯著縮小了開源模型與 Sora 之間的差距。

我們先來一睹實現效果如何。比如生成影片:

圖片

生成單張影像:

圖片

3D 生成:

圖片

語音生成:

圖片

圖生成

圖片

風格一致性生成:

圖片

更大解析度外推:

圖片

影像編輯:

圖片

可以說,Lumina-T2X 系列模型真正實現了影像、影片、3D 和語音的「大一統」。

圖片

目前,研究者已經推出了分別使用 Flag-DiT 2B 和 Gemma 2B 作為文字編碼器的 Lumina-Next-T2I 模型,可以在 gradio 上試玩。

  • 試用地址 1:http://106.14.2.150:10021/
  • 試用地址 2:http://106.14.2.150:10022/

方法概覽

Flag-DiT 架構

Flag-DiT 是 Lumina-T2X 框架的主幹,它具有顯著的穩定性、靈活性和可擴充套件性。

首先是穩定性。Flag-DiT 建立在 DiT 之上,並結合 ViT-22B 和 LLaMa 來修改,以提高訓練穩定性。具體來說,Flag-DiT 將所有 LayerNorm 替換為 RMSNorm,使得訓練穩定性增強。

此外,Flag-DiT 在鍵查詢點積注意力計算之前結合鍵查詢歸一化(KQ-Norm)。KQ-Norm 的引入旨在透過消除注意力 logits 中極大值來防止損失發散。這種簡單的修改可以防止混合精度訓練下的發散損失,並有助於以更高的學習率進行最佳化。Flag-DiT 的詳細計算流如圖 2 所示。

圖片

其次是靈活性。DiT 僅支援具有簡單標籤條件和固定 DDPM 公式的固定解析度影像生成。為了解決這些問題,研究者首先探究為什麼 DiT 缺乏以任意解析度和比例生成樣本的靈活性。他們發現,專為視覺識別任務而設計的 APE 很難泛化到訓練之外未見過的解析度和規模。

因此,受最近展現出強大上下文外推能力的 LLM 的推動,他們用 RoPE 替換了 APE。RoPE 按照以下公式 1 和 2 以分層方式注入相對位置資訊。

圖片

最後,研究者根據經驗,使用更大的引數和更多的訓練樣本擴充套件了 Flag-DiT。具體來講,他們探索在標籤條件 ImageNet 生成基準上將引數大小從 600M 擴大到 7B。

圖片

Lumina-T2X 整體流程

如圖 3 所示,Lumina-T2X 在訓練過程中主要由四個元件組成,接下來進行一一介紹。

圖片

不同模態的逐幀編碼。在 Lumina-T2X 框架中統一不同模態的關鍵是將影像、影片、多檢視影像和語音訊譜圖視為長度為 T 的幀序列,然後利用特定模態的編碼器來將這些輸入轉換為形狀為 [H, W, T, C] 的潛在框架。

使用多種文字編碼器進行文字編碼。對於文字條件生成,研究者使用預先訓練的語言模型對文字提示進行編碼。他們結合了各種大小不一的文字編碼器,其中包括 CLIP、LLaMA、SPHINX 和 Phone 編碼器,針對各種需求和模態進行量身定製,以最佳化文字調整。

輸入和目標構建。Lumina-T2X 在流匹配中採用線性插值方案來構建輸入和目標,具體如下公式 4 和 6 所示,簡單靈活。並且,受到中間時間步對於擴散模型和流模型都至關重要的觀察啟發, 研究者在訓練期間採用時間重取樣策略從對數範數分佈中取樣時間步。

圖片

圖片

網路架構和損失。研究者使用 Flag-DiT 作為去噪主幹。給定噪聲輸入,Flag-DiT 塊透過調控機制注入新增了全域性文字嵌入的擴散時間步,並使用如下公式 9 透過零初始化注意力來進一步整合文字調整。

圖片

Lumina-T2X 系列

Lumina-T2X 系列模型包括了 Lumina-T2I、Lumina-T2V、LuminaT2MV 和 Lumina-T2Speech。對於每種模態,Lumina-T2X 都經過了針對不同場景最佳化的多配置獨立訓練,例如不同的文字編碼器、VAE 潛在空間和引數大小。具體如圖 17 所示。

圖片

Lumina-T2I 的高階應用

除了基本的文字生成影像之外,文字到影像版本的 Lumina-T2I 還支援更復雜的視覺創作,並作為基礎模型產生富有創造力的視覺效果。這包括解析度外推、風格一致性生成、高解析度影像編輯和構圖生成

與以往使用多種方法解決這些任務的策略不同,Lumina-T2I 可以透過 token 操作統一解決這些問題,如圖 4 所示。

圖片

免調整解析度外推。RoPE 的平移不變性增強了 Lumina-T2X 的解析度外推潛力,使其能夠生成域外解析度的影像。Lumina-T2X 解析度最高可以外推到 2K。

風格一致性生成。基於 Transformer 的擴散模型架構使得 Lumina-T2I 自然地適合風格一致性生成等自注意力操作應用。

圖生成。研究者只將此操作應用於 10 個注意力交叉層,以確保文字資訊被注入到不同的區域。同時保持自注意層不變,以確保最終影像的連貫、和諧。

高解析度編輯。除了高解析度生成之外,Lumina-T2I 還可以執行影像編輯,尤其是對於高解析度影像。

實驗結果

在 ImageNet 上驗證 Flag-DiT

研究者在有標籤條件的 256×256 和 512×512 ImageNet 上進行實驗,以驗證 Flag-DiT 相對於 DiT 的優勢。Large-DiT 是 Flag-DiT 的特化版本,採用了 DDPM 演算法 ,以便與原始 DiT 進行公平比較。研究者完全沿用了 DiT 的設定,但做了以下修改,包括混合精度訓練、大學習率和架構修改套件(如 QK-Norm、RoPE 和 RMSNorm)。

研究者將其與 SOTA 方法的比較,如表 2 所示,Large-DiT-7B 在不使用無分類指導(CFG)的情況下,在 FID 和 IS 分數上明顯超過了所有方法,將 FID 分數從 8.60 降至 6.09。這表明,增加擴散模型的引數可以顯著提高樣本質量,而無需依賴 CFG 等額外技巧。

圖片

研究者比較了 Flag-DiT、Large-DiT 和 SiT 在 ImageNet 條件生成上的效能,為了進行公平比較,他們將引數大小固定為 600M。如圖 5 (a) 所示,在 FID 評估的所有歷時中,Flag-DiT 的效能始終優於 Large-DiT。這表明,與標準擴散設定相比,流匹配公式可以改善影像生成。此外,與 SiT 相比,Flag-DiT 的 FID 分數較低,這表明元架構修改(包括 RMSNorm、RoPE 和 K-Q norm)不僅能穩定訓練,還能提高效能。

圖片

透過混合精度訓練提高訓練速度 。Flag-DiT 不僅能提高效能,還能提高訓練效率和穩定性。如表 4 所示。Flag-DiT 每秒可多處理 40% 的影像。

圖片

ImageNet 初始化的影響 PixArt-α 利用 ImageNet 預訓練的 DiT(學習畫素依賴性)作為後續 T2I 模型的初始化。為了驗證 ImageNet 初始化的影響,研究者比較了使用 ImageNet 初始化和從頭開始訓練的 600M 引數模型的 Lumina-T2I 速度預測損失。如圖 5 (d) 所示,從頭開始訓練的損失水平更低,收斂速度更快。此外,從零開始可以更靈活地選擇配置和架構,而不受預訓練網路的限制。表 1 所示的簡單而快速的訓練配方也是根據這一觀察結果設計的。

Lumina-T2I 的結果

圖 6 中展示了基本的文字到影像生成能力。擴散主幹架構和文字編碼器的大容量允許生成逼真的高解析度影像,並能準確理解文字,只需使用 288 個 A100 GPU 天數。

圖片

解析度外推法不僅能帶來更大比例的影像,還能帶來更高的影像質量和更強的細節。如圖 7 所示,當解析度從 1K 外推至 1.5K 時,我們可以發現到生成影像的質量和文字到影像的對齊情況都得到了顯著提升。此外,Lumina-T2I 還能進行外推,生成解析度更低的影像,如 512 解析度,從而提供了更大的靈活性。

圖片

如圖 8 所示,透過利用一個簡單的注意力共享操作,我們可以觀察到生成批次中的強一致性。得益於完全注意力模型架構,研究者獲得了與參考文獻 [58] 中相媲美的結果,而無需使用任何技巧,如自適應例項規範化(AdaIN)。此外,研究者認為,正如先前的研究所示,透過適當的反轉技術,他們可以實現零成本的風格 / 概念個性化,這是未來探索的一個有前景的方向。

圖片

如圖 9 所示,研究者演示了組合生成 。在這其中可以定義任意數量的 prompt,併為每個 prompt 分配任意區域。Lumina-T2I 成功生成了各種解析度的高質量影像,這些影像與複雜的輸入 prompt 相一致,同時又保持了整體的視覺一致性。這表明,Lumina-T2I 的設計選擇提供了一種靈活有效的方法,在生成複雜的高解析度多概念圖像方面表現出色。

圖片

研究者對高解析度影像進行風格和主題編輯。如圖 10 所示,Lumina-T2I 可以無縫修改全域性樣式或新增主題,而無需額外的訓練。此外,他們還分析了影像編輯中的啟動時間和潛在特徵歸一化等各種因素,如圖 11 所示。

圖片

圖片

如圖 13 (a) 所示,研究者視覺化了各層和各頭部的門控值,發現大多數門控值接近零,只有少部分顯示出顯著的重要性。有趣的是,最關鍵的文字調節頭部主要位於中間層,這表明這些層在文字調節中起著關鍵作用。為了鞏固這一觀察結果,研究者對低於某一閾值的門控進行了截斷,發現 80% 的門控可以在不影響影像生成質量的情況下被停用,如圖 13 (b) 所示。這一觀察表明,在取樣過程中截斷大多數交叉注意力操作的可能性,這可以大大減少推理時間。

圖片

Lumina-T2V 的結果

研究者觀察到,使用大批次的 Lumina-T2V 能夠收斂,而小批次則難以收斂。如圖 14 (a) 所示,將批次大小從 32 增加到 1024 會導致損失收斂。另一方面,與 ImageNet 實驗中的觀察相似,增加模型引數會加速影片生成收斂速度。如圖 14 (b) 所示,當引數大小從 600M 增加到 5B 時,我們能夠在相同的訓練迭代次數下一致觀察到更低的損失。

圖片

如圖 15 所示,Lumina-T2V 的第一階段能夠生成具有場景動態變化(如場景轉換)的短影片,儘管生成的影片在解析度和持續時間上有限,總 token 數最多為 32K。經過對更長持續時間和更高解析度影片的第二階段訓練後,Lumina-T2V 能夠生成 128K token 的各種解析度和持續時間的長影片。如圖 16 所示,生成的影片展示了時間上的一致性和更豐富的場景動態,表明當使用更多的計算資源和資料時,展現出有希望的擴充套件趨勢。

圖片

圖片

更多詳細內容,請閱讀原論文。

相關文章