OpenAI攻克擴散模型短板,清華校友路橙、宋颺合作最新論文

机器之心發表於2024-10-17
多項改進實現規模空前的連續時間一致性模型。

擴散模型很成功,但也有一塊重大短板:取樣速度非常慢,生成一個樣本往往需要執行成百上千步取樣。為此,研究社群已經提出了多種擴充套件蒸餾(diffusion distillation)技術,包括直接蒸餾、對抗蒸餾、漸進式蒸餾和變分分數蒸餾(VSD)。但是,這些方法也有自己的問題,包括成本高、複雜性高、多樣性有限等。

一致性模型(CM)在解決這些問題方面具有巨大的優勢。這又進一步分為離散時間 CM 和連續時間 CM。其中離散時間 CM 會引入離散化誤差,並且需要仔細排程時間步長網格,這可能會導致樣本質量不佳。而連續時間 CM 雖可避免這些問題,但也會有訓練不穩定的問題。

近日,OpenAI 的研究科學家路橙(Cheng Lu)與戰略探索團隊負責人宋颺(Yang Song)釋出了一篇研究論文,提出了一些可簡化、穩定化和擴充套件連續時間一致性模型的技術。值得一提的是,這兩位作者都是清華校友,師從朱軍教授,在擴散機率模型領域做出過代表性工作。
圖片
  • 論文標題:Simplifying, Stabilizing & Scaling Continuous-Time Consistency Models
  • 論文地址:https://arxiv.org/pdf/2410.11081v1

他們的貢獻包括:

  • TrigFlow,一個將 EDM(arXiv:2206.00364)與流匹配(Flow Matching)統一起來的公式,其能極大簡化擴充套件模型、相關的機率流 ODE 和一致性模型(CM。
  • 在此基礎上,他們分析了一致性模型訓練不穩定的根本原因,並提出了一種完整的緩解方案。他們的方法包括改進網路架構中的時間調節和自適應分組歸一化。
  • 此外,他們還重新構建了連續時間 CM 的訓練目標,其中整合了關鍵項的自適應加權和歸一化以及漸進退火,以實現穩定且可擴充套件的訓練。

簡化連續時間一致性模型

作為前提,這裡先給出離散時間和連續時間一致性模型的公式:

離散時間 CM:
圖片
連續時間 CM:
圖片
圖片
此前的一致性模型採用了 EDM 中的模型引數化和擴散過程。具體來說,一致性模型會被引數化以下形式:
圖片
其中,F 是一個神經網路,θ 是其引數;c_skip、c_out、c_in 都是固定的係數,用以確保在所有時間步驟上初始化時擴散目標的方差相等;c_noise 是對 t 的一個變換運算,以便更好地實現時間調節。

由於在 EDM 擴散過程中,方差會爆炸式增長,也就意味著 x_t = x_0 + tz_t,基於此可以推匯出下面三式:
圖片
雖然這些係數對於訓練效率很重要,但由於它們與 t 和 σ_d 之間存在複雜的算術關係,因此會使得對一致性模型的理論分析變得複雜。

為了簡化 EDM 及隨之的一致性模型,他們提出了 TrigFlow。這種擴散模型形式保留了 EDM 性質,但滿足 c_skip (t) = cos (t)、c_out (t) = sin (t)、c_in (t) ≡ 1/σ_d。

TrigFlow 是流匹配(也稱為隨機插值或整流)和 v 預測引數化的一種特例。它與之前一些研究團隊提出的三角插值非常相似,但經過修改從而納入了對資料分佈 p_d 的標準差 σ_d 的考量。

由於 TrigFlow 是流匹配的一個特例,同時滿足 EDM 原理,因此其集兩者之長,同時還讓擴散過程、擴散模型引數化、PF-ODE、擴散訓練目標和一致性模型引數化全都變得更簡單了。
圖片
讓連續時間一致性模型變得穩定

連續時間 CM 的訓練一直都高度不穩定。因此,它們的表現一直不及之前研究中的離散時間 CM。

為了解決這個問題,該團隊在 TrigFlow 框架的基礎上,引入了幾項基於理論研究的改進措施,其中重點關注的是引數化、網路架構和訓練目標。
引數化和網路架構

連續時間 CM 的訓練的關鍵是 (2) 式,其取決於正切函式圖片,而在 TrigFlow 框架下,該正切函式由以下公式給出:
圖片
其中 圖片表示 PF-ODE,其要麼在一致性蒸餾中使用預訓練的擴散模型估計得出,要麼就在一致性訓練中使用從噪聲和乾淨樣本計算得到的無偏估計器估計得出。

為了讓訓練過程穩定下來,必須確保 (6) 式中的正切函式在不同時間步驟中保持穩定。該團隊在實踐中發現 σ_dF_θ、PF-ODE 和噪聲樣本 x_t 能保持相對穩定。進一步分析後,他們發現圖片通常經過良好調節。因此不穩定的來源是時間導數 圖片,而其可被分解為:
圖片
其中,emb (・) 是指時間嵌入,通常以位置嵌入或傅立葉嵌入的形式出現在擴散模型和 CM 的相關文獻中。

該團隊穩定 (7) 中每個元素的方法包括恆等時間變換(c_noise (t) = t)、位置時間嵌入和自適應雙重歸一化。詳見原論文。

圖 4 視覺化地展示了在 CIFAR-10 上訓練 CM 時穩定時間導數的情況。研究表明,這些改進可在不損害擴散模型訓練的前提下穩定 CM 的訓練動態。
圖片
訓練目標

使用 TrigFlow 和前述的最佳化技術,(2) 式中連續時間 CM 訓練的梯度就會變為:
圖片
之後,該團隊又使用了另外一些技術來顯式地控制該梯度,以提升穩定性,其中包括正切歸一化、自適應加權、擴散微調和正切預熱。詳見原論文。

有了這些技術,離散時間和連續時間 CM 訓練的穩定性都能得到顯著改善。

該團隊在相同的設定下訓練了連續時間 CM 和離散時間 CM。如圖 5 (c) 所示,增加離散時間 CM 中的離散化步驟數 N 可提高樣本質量,原因是這樣做可減少離散化誤差來;但一旦 N 變得太大(N > 1024 之後),樣本質量就會降低,這是因為會出現數值精度問題。
圖片
相較之下,在所有 N 值上,連續時間 CM 的表現都顯著優於離散時間 CM。這能為我們提供選擇連續時間 CM 的強有力依據。

該團隊將他們的模型稱為 sCM,其中 s 代表 simple、stable、scalable,即簡單、穩定和可擴充套件。下面是 sCM 訓練的詳細虛擬碼
圖片
擴充套件連續時間一致性模型

在這部分,研究者透過在各種具有挑戰性的資料集上訓練大規模 sCM 來測試上述內容中提出的所有改進措施。

大規模模型中的正切計算

訓練大規模擴散模型的常見設定包括使用半精度(FP16)和 Flash Attention。由於訓練連續時間 CM 需要精確計算正切 圖片,我們需要提高數值精度,同時支援高效記憶注意力計算,詳情如下。

計算圖片時,需要計算 圖片,這可透過圖片與輸入向量 圖片和正切向量 圖片的雅可比向量積(JVP)高效獲得。然而根據經驗,當 t 接近 0 或 π/2 時,正切可能會在中間層溢位。為了提高數值精度,研究者認為應該重新安排正切的計算。

具體來說,由於公式 (8) 中的目標函式包含圖片,而 圖片圖片成正比,因此可以這樣計算 JVP:
圖片
圖片的 JVP,輸入為圖片,正切為圖片這種重新排列大大緩解了中間層的溢位問題,使 FP16 的訓練更加穩定。

Flash Attention 被廣泛用於大規模模型訓練中的注意力計算,既能節省 GPU 記憶體,又能加快訓練速度。然而,Flash Attention 並不計算雅可比向量積(JVP)。為了填補這一空白,研究者提出了一種類似的演算法,它能以 Flash Attention 的風格在一次前向傳遞中高效計算 softmax 自注意力及其 JVP,從而顯著減少注意力層中 JVP 計算所需的 GPU 記憶體用量。

實驗
sCM 的訓練計算。研究者在所有資料集上使用與教師擴散模型相同的批大小。sCD 每次訓練迭代的有效計算量大約是教師模型的兩倍。他們觀察到,sCD 的兩步取樣質量收斂很快,只用了不到教師模型 20% 的訓練計算量,就獲得了與教師擴散模型相當的結果。在實踐中,只需使用 sCD 進行 20k 次微調迭代,就能獲得高質量的樣本。
基準。在表 1 和表 2 中,研究者透過 FID 和函式評估次數(NFE)的基準,將本文結果與之前的方法進行了比較。首先,sCM 優於之前所有不依賴與其他網路聯合訓練的幾步式方法,與對抗訓練取得的最佳結果相當,甚至超越。值得注意的是,sCD-XXL 在 ImageNet 512×512 上的一步 FID 超過了 StyleGAN-XL 和 VAR。此外,sCD-XXL 的兩步 FID 效能優於除擴散模型外的所有生成模型,可與需要 63 個連續步驟的最佳擴散模型相媲美。其次,兩步式 sCM 模型將與教師擴散模型的 FID 差距顯著縮小到 10% 以內。此外,sCT 在較小的擴充套件上更有效,但在較大擴充套件上的方差會增大,而 sCD 在小型擴充套件和大型擴充套件上都表現出一致的效能。
圖片
圖片
Scaling 研究。如圖 6 所示,首先,隨著模型 FLOPs 的增加,sCT 和 sCD 的樣本質量都有所提高,這表明這兩種方法都能從 Scaling 中獲益。其次,與 sCD 相比,sCT 在較小解析度下的計算效率更高,但在較大解析度下的效率較低。第三,對於給定的資料集,sCD 的 Scaling 是可預測的,在不同大小的模型中,FID 的相對差異保持一致。這表明,sCD 的 FID 下降速度與教師擴散模型相同,因此,sCD 與教師擴散模型一樣具有可擴充套件性。隨著教師擴散模型的 FID 隨規模的擴大而減小,sCD 與教師模型之間 FID 的絕對差異也隨之減小。最後,FID 的相對差異隨著取樣步驟的增加而減小,兩步式 sCD 的取樣質量與教師擴散模型相當。
圖片
與 VSD 的對比。如圖 7 所示,研究者對比了 sCD、VSD、sCD 和 VSD 的組合等(透過簡單地將兩種損失相加)並觀察到,VSD 具有與擴散模型中應用大 guidance scale 類似的人工效應:它提高了保真度(表現為更高的精確度分數),同時降低了多樣性(表現為更低的召回分數)。這種效應隨著 guidance scale 的增加而變得更加明顯,最終導致嚴重的模式崩潰。相比之下,兩步式 sCD 的精確度和召回分數與教師擴散模型相當,因此 FID 分數比 VSD 更高。

更多研究細節,可參考原論文。

相關文章