只用1890美元、3700 萬張影像,就能訓練一個還不錯的擴散模型。
現階段,視覺生成模型擅長建立逼真的視覺內容,然而從頭開始訓練這些模型的成本和工作量仍然很高。比如 Stable Diffusion 2.1 花費了 200000 個 A100 GPU 小時。即使研究者使用最先進的方法,也需要在 8×H100 GPU 上訓練一個多月的時間。
此外,訓練大模型也對資料集提出了挑戰,這些資料基本以億為單位,同樣給訓練模型帶來挑戰。
高昂的訓練成本和對資料集的要求為大規模擴散模型的開發造成了難以逾越的障礙。
現在,來自 Sony AI 等機構的研究者僅僅花了 1890 美元,就訓練了一個不錯的擴散模型, 具有 11.6 億引數的稀疏 transformer。
論文地址:https://arxiv.org/pdf/2407.15811
論文標題:Stretching Each Dollar: Diffusion Training from Scratch on a Micro-Budget
專案(即將釋出):https://github.com/SonyResearch/micro_diffusion
具體而言,在這項工作中,作者透過開發一種低成本端到端的 pipeline 用於文字到影像擴散模型,使得訓練成本比 SOTA 模型降低了一個數量級還多,同時還不需要訪問數十億張訓練影像或專有資料集。
作者考慮了基於視覺 transformer 的潛在擴散模型進行文字到影像生成,主要原因是這種方式設計簡單,並且應用廣泛。為了降低計算成本,作者利用了 transformer 計算開銷與輸入序列大小(即每張影像的 patch 數量)的強依賴關係。
本文的主要目標是在訓練過程中減少 transformer 處理每張影像的有效 patch 數。透過在 transformer 的輸入層隨機掩蔽(mask)掉部分 token,可以輕鬆實現這一目標。
然而,現有的掩蔽方法無法在不大幅降低效能的情況下將掩蔽率擴充套件到 50% 以上,特別是在高掩蔽率下,很大一部分輸入 patch 完全不會被擴散 transformer 觀察到。
為了減輕掩蔽造成的效能大幅下降,作者提出了一種延遲掩蔽(deferred masking)策略,其中所有 patch 都由輕量級 patch 混合器(patch-mixer)進行預處理,然後再傳輸到擴散 transformer。Patch 混合器包含擴散 transformer 中引數數量的一小部分。
與 naive 掩蔽方法相比,在 patch mixing 處理之後進行掩蔽允許未掩蔽的 patch 保留有關整個影像的語義資訊,並能夠在非常高的掩蔽率下可靠地訓練擴散 transformer,同時與現有的最先進掩蔽相比不會產生額外的計算成本。
作者還證明了在相同的計算預算下,延遲掩蔽策略比縮小模型規模(即減小模型大小)實現了更好的效能。最後,作者結合 Transformer 架構的最新進展,例如逐層縮放、使用 MoE 的稀疏 Transformer,以提高大規模訓練的效能。
作者提出的低成本訓練 pipeline 減少了實驗開銷。除了使用真實影像,作者還考慮在訓練資料集中組合其他合成影像。組合資料集僅包含 3700 萬張影像,比大多數現有的大型模型所需的資料量少得多。
在這個組合資料集上,作者以 1890 美元的成本訓練了一個 11.6 億引數的稀疏 transformer,並在 COCO 資料集上的零樣本生成中實現了 12.7 FID。
值得注意的是,本文訓練的模型實現了具有競爭力的 FID 和高質量生成,同時成本僅為 stable diffusion 模型的 1/118 ,是目前最先進的方法(成本為 28,400 美元)的 1/15。
方法介紹
為了大幅降低計算成本,patch 掩蔽要求在輸入主幹 transformer 之前丟棄大部分輸入 patch,從而使 transformer 無法獲得被掩蔽 patch 的資訊。高掩蔽率(例如 75% 的掩蔽率)會顯著降低 transformer 的整體效能。即使使用 MaskDiT,也只能觀察到它比 naive 掩蔽有微弱的改善,因為這種方法也會在輸入層本身丟棄大部分影像 patch。
延遲掩蔽,保留所有 patch 的語義資訊
由於高掩蔽率會去除影像中大部分有價值的學習訊號,作者不禁要問,是否有必要在輸入層進行掩蔽?只要計算成本不變,這就只是一種設計選擇,而不是根本限制。事實上,作者發現了一種明顯更好的掩蔽策略,其成本與現有的 MaskDiT 方法幾乎相同。由於 patch 來自擴散 Transformer 中的非重疊影像區域,每個 patch 嵌入都不會嵌入影像中其他 patch 的任何資訊。因此,作者的目標是在掩蔽之前對 patch 嵌入進行預處理,使未被掩蔽的 patch 能夠嵌入整個影像的資訊。他們將預處理模組稱為 patch-mixer。
使用 patch-mixer 訓練擴散 transformer
作者認為,patch-mixer 是任何一種能夠融合單個 patch 嵌入的神經架構。在 transformer 模型中,這一目標自然可以透過注意力層和前饋層的組合來實現。因此,作者使用一個僅由幾個層組成的輕量級 transformer 作為 patch-mixer。輸入序列 token 經 patch-mixer 處理後,他們將對其進行掩蔽(圖 2e)。
圖 2:壓縮 patch 序列以降低計算成本。由於擴散 transformer 的訓練成本與序列大小(即 patch 數量)成正比,因此最好能在不降低效能的情況下縮減序列大小。這可以透過以下方法實現:b) 使用更大的 patch;c) 隨機簡單(naive)掩蔽一部分 patch;或者 d) 使用 MaskDiT,該方法結合了 naive 掩蔽和額外的自動編碼目標。作者發現這三種方法都會導致影像生成效能顯著下降,尤其是在高掩蔽率的情況下。為了緩解這一問題,他們提出了一種直接的延遲掩蔽策略,即在 patch-mixer 處理完 patch 後再對其進行掩蔽。除了使用 patch-mixer 之外,他們的方法在所有方面都類似於 naive 掩蔽。與 MaskDiT 相比,他們的方法無需最佳化任何替代目標,計算成本幾乎相同。
假定掩碼為二進位制掩碼 m,作者使用以下損失函式來訓練模型:
其中,M_ϕ 是 patch-mixer 模型,F_θ 是主幹 transformer。請注意,與 MaskDiT 相比,本文提出的方法還簡化了整體設計,不需要額外的損失函式,也不需要在訓練過程中在兩個損失之間進行相應的超引數調優。在推理過程中,該方法不掩蔽任何 patch。
未掩蔽微調
由於極高的掩蔽率會大大降低擴散模型學習影像全域性結構的能力,並在序列大小上引入訓練 - 測試分佈偏移,因此作者考慮在掩蔽預訓練後進行少量的未掩蔽微調。微調還可以減輕由於使用 patch 掩蔽而產生的任何生成瑕疵。因此,在以前的工作中,恢復因掩蔽而急劇下降的效能至關重要,尤其是在取樣中使用無分類器引導時。然而,作者認為這並不是完全必要的,因為即使在掩蔽預訓練中,他們的方法也能達到與基線未掩蔽預訓練相當的效能。作者只在大規模訓練中使用這種方法,以減輕由於高度 patch 掩蔽而產生的任何未知 - 未知生成瑕疵。
利用 MoE 和 layer-wise scaling 改進主幹 transformer 架構
作者還利用 transformer 架構設計方面的創新,在計算限制條件下提高了模型的效能。
他們使用混合專家層,因為它們在不顯著增加訓練成本的情況下增加了模型的引數和表現力。他們使用基於專家選擇路由的簡化 MoE 層,每個專家決定路由給它的 token,因為它不需要任何額外的輔助損失函式來平衡專家間的負載。他們還考慮了 layer-wise scaling,該方法最近被證明在大型語言模型中優於典型 transformer。該方法線性增加 transformer 塊的寬度,即注意力層和前饋層的隱藏層維度。因此,網路中較深的層比較早的層被分配了更多的引數。作者認為,由於視覺模型中的較深層往往能學習到更復雜的特徵,因此在較深層使用更高的引數會帶來更好的效能。作者在圖 3 中描述了他們提出的擴散 Transformer 的整體架構。
圖 3:本文提出的擴散 transformer 的整體架構。作者在骨幹 transformer 模型中加入了一個輕量級的 patch-mixer,它可以在輸入影像中的所有 patch 被掩蔽之前對其進行處理。根據當前的研究成果,作者使用注意力層處理 caption 嵌入,然後再將其用於調節。他們使用正弦嵌入來表示時間步長。他們的模型只對未掩蔽的 patch 進行去噪處理,因此只對這些 patch 計算擴散損失(論文中的公式 3)。他們對主幹 transformer 進行了修改,在單個層上使用了 layer-wise scaling,並在交替 transformer 塊中使用了混合專家層。
實驗
實驗採用擴散 Transformer(DiT)兩個變體 DiT-Tiny/2 和 DiT-Xl/2。
如圖 4 所示,延遲掩蔽方法在多個指標中都實現了更好的效能。此外,隨著掩蔽率的增加,效能差距會擴大。例如,在 75% 的掩蔽率下,naive 掩蔽會將 FID 得分降低到 16.5(越低越好),而本文方法可以達到 5.03,更接近沒有掩蔽的 FID 得分 3.79。
表 1 表明 layer-wise scaling 方法在擴散 transformer 的掩蔽訓練中具有更好的擬合效果。
比較不同的掩蔽策略。作者首先將本文方法與使用較大 patch 的策略進行比較。將 patch 大小從 2 增加到 4,相當於 75% 的 patch 掩蔽。與延遲掩蔽相比,其他方法表現不佳,分別僅達到 9.38、6.31 和 26.70 FID、Clip-FID 和 Clip-score。相比之下,延遲掩蔽分別達到 7.09、4.10 和 28.24 FID、Clip-FID 和 Clip-score。
下圖為延遲掩蔽 vs. 模型縮小以減少訓練成本的比較。在掩蔽率達到 75% 之前,作者發現延遲掩蔽在至少三個指標中的兩個方面優於網路縮小。但是,在極高的掩蔽率下,延遲掩蔽往往會實現較低的效能。這可能是因為在這些比率下掩蔽的資訊損失太高導致的。
表 5 提供了有關模型訓練超引數的詳細資訊。訓練過程分兩個階段。
計算成本。表 2 提供了每個訓練階段的計算成本明細,包括訓練 FLOP 和經濟成本。第 1 階段和第 2 階段訓練分別消耗了總計算成本的 56% 和 44%。模型在 8×H100 GPU 叢集上的總時鐘訓練時間為 2.6 天,相當於在 8×A100 GPU 叢集上為 6.6 天。
瞭解更多結果,請參考原論文。