Meta 內部都在用的 FX 工具大起底:利用 Graph Transformation 最佳化 PyTorch 模型

超神經HyperAI發表於2022-11-30
PyTorch 中的 graph mode 在效能方面表示更為出色,本文介紹 Torch.FX 這個強大工具,可以捕捉和最佳化 PyTorch 程式 graph。

一、簡介

PyTorch 支援兩種執行模式:eager mode 和 graph mode。

eager mode 中,模型中的運算子在讀取時會立即執行,它易於使用,對機器學習從業者更友好,因此被設定為預設的執行模式。

graph mode 中,運算子先被合成一個 graph,然後作為一個整體進行編譯和執行,它的效能更高,因此在實際生產中大量使用。

具體來說,graph mode 支援運算元融合,兩個運算元透過合併,可以降低或本地化記憶體讀取以及核心啟動總開銷。

融合可以是橫向 (horizontal) 的:採取應用於多個 operand 的單一操作(如 BatchNorm),並將這些 operand 合併到一個陣列中。

融合也可以是縱向 (vertical) 的:將一個核心與另一個核心合併,後者需要使用第一個核心的輸出(如 ReLU 後接卷積)。

Torch.FX(縮寫為 FX)是一個公開可用的工具包,作為 PyTorch 軟體包的一部分,支援 graph mode 的執行。它可以:

  1. 從 PyTorch 程式中獲取 graph
  2. 允許開發者在獲取的 graph 上編寫 transformation**

Meta 內部先前已經在用 FX 來最佳化生產模型 (production model) 的訓練吞吐量 (training throughput)。本文將透過介紹 Meta 開發的基於 FX 的最佳化,來展示利用圖結構轉換 (graph transformation) 最佳化 PyTorch 部署模型效能的方法

二、背景

embedding table 廣泛存在於推薦系統中,本節將介紹 FX 和 embedding table 的背景知識。

2.1. FX

圖 1 是一個簡單示例,演示瞭如何用 FX 轉換 PyTorch 程式,它包含三個步驟:

  • 從程式中獲取 graph
  • 修改 graph(在本例中,我們用 GELU 代替 RELU)
  • 從修改後的 graph 中生成一個新程式

圖1:在 PyTorch 模組中用 GELU 取代 RELU 的 FX

FX API 為檢查和轉換 PyTorch 程式 graph 還提供了許多其他功能。

2.2. embedding table

圖2:批尺寸=1 的稀疏特徵 embedding table 示意圖

在推薦系統中,稀疏特徵(例如,User ID,Story ID)由 embedding table 表示。

embedding table E 是一個 HxD 矩陣,其中 H 是雜湊大小,D 是嵌入向量維度。E 的每一行都是一個浮點數向量。

feature hashing 的作用是將一個稀疏特徵對映到 E的索引列表中,例如 [S1,S2,...,Sk],其中 0≤Si<H。它的輸出值計算為 f(E[S1], E[S2],...,E[Sk]),其中 E[Si] 是 Si 行的向量,f 是池化函式,通常是 sum,average,max 三個函式之一。

為了充分利用 GPU,稀疏特徵通常為批處理。批處理中的每個實體都有自己的索引列表。如果一個批次有 B 個實體,可以簡單理解為一個表徵有 B 個索引列表。

更為嚴謹的表示方法是將 B 個索引列表合併成一個索引列表,並新增一個索引長度的列表(該批中的每個實體都有一個長度 length)。

例如,如果一批包含 3 個實體,其索引列表如下:

  • Entity 1: indices = [10, 20]
  • Entity 2: indices = [5, 9, 77, 81]
  • Entity 3: indices = [15, 20, 45]

則完整批尺寸的 indice 和 length 將是:

  • Indices = [10, 20, 5, 9, 77, 81, 15, 20, 45]
  • Lengths = [2, 4, 3]

而整個 batch 的 embedding table 查詢,輸出為是一個 BxD 矩陣。

三、3 種 FX Transformation

PyTorch 更新了 3 個 FX transformation,以加速對 embedding table 的訪問,本節將逐一介紹。

下文 3.1 關於將多個小輸入張量結合成一個大張量的轉換;3.2 關於將多個平行計算鏈融合成一個計算鏈的轉換;3.3 關於將通訊與計算重疊的轉換。

3.1 結合輸入稀疏特徵

batch 中的每個輸入稀疏特徵,都可以表示為兩個列表:一個索引列表和一個 B length 列表,其中 B 表示批尺寸。

在 PyTorch 中,這兩個列表都可以以張量的形式存在。當 PyTorch 模型在 GPU 上執行時,embedding table 通常儲存在 GPU 記憶體中(它更接近 GPU,讀寫頻寬比 CPU 記憶體更高)。

需要使用輸入稀疏特徵時,兩個張量都要先從 CPU 複製到 GPU。然而每個主機到裝置的記憶體複製都需要啟動核心,這對於實際的資料傳輸來說,會更加耗費時間。

如果一個模型使用了多個輸入稀疏特徵,這種複製可能成為效能瓶頸(例如,1000 個輸入稀疏特徵將需要從主機到裝置複製 2000 個張量)。

一個減少主機到裝置 memcpy 數量的最佳化方法,就是在多個輸入稀疏特徵傳送到裝置之前,先將其進行組合。

例如,給定以下三個輸入特徵:

  • Feature_A: indices = [106, 211, 7], lengths = [2, 1]
  • Feature_B: indices = [52, 498, 616, 870, 1013], lengths = [3, 2]
  • Feature_C: indices = [2011, 19, 351, 790], lengths = [1, 3]

組合後的形式為:

Features_A_B_C: indices = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790], lengths = [2, 1, 3, 2, 1, 3]

所以不需要從主機到裝置複製 3x2=6 個張量,只需要複製 2 個張量。

圖 3(b) 描述了這種最佳化的實現,它包含兩個元件:

  • CPU 端:輸入 pipeline 被修改為將所有稀疏特徵的 indices 組合成一個張量,所有 length 組合成另一個張量。然後將這兩個張量複製到 GPU 上。
  • GPU 端:使用 FX,在模型 graph 中插入一個Permute_and_Split 運算元,從合併的張量中恢復單個特徵 indices 和 length 張量,並將其傳送至下游的相應節點。


最佳化前:兩個張量都要從 CPU 複製到 GPU


最佳化後:將輸入稀疏特徵進行組合

3.2 從訪問 embedding table 開始的計算鏈橫向融合

在一個生產模型中,每個 GPU 上有 10 個 embedding table 很常見。出於效能方面的考慮,對這些 table 的查詢被分到一組,這樣它們的輸出就被串聯在一個大張量中(見圖 4(a)中的紅色部分)。

為了對單個特徵輸出進行計算,使用 Split 運算元將大張量分成 N 個小張量(其中 N 為特徵的數量),然後將所需的計算應用於每個張量。

如圖 4(a) 所示,應用於每個特徵輸出 O 的計算是Tanh(LayerNorm(O))。所有的計算結果都被串聯成一個大的張量,然後傳遞給下游的運算元(圖 4(a) 中的 Op1)。

這裡主要的 runtime cost 是 GPU 核心啟動的開銷。例如,圖 4(a) 中的 GPU 核心的啟動次數為 2*N+3(圖中的每個橢圓都表示一個 GPU 核心)。這會影響效能,因為 LayerNorm 和 Tanh 在 GPU 上的執行時間,與它們的核心啟動時間相比很短。

此外,Split 運算元可能會建立一個額外的嵌入向量輸出張量的副本,消耗額外的 GPU 記憶體。

用 FX 來實現一種叫做橫向融合 (horizontal fusion) 的最佳化,可以大大減少 GPU 核心的啟動次數(在這個例子中,最佳化後的 GPU 核心啟動次數為 5,見圖 4(b))。

使用 Add_middle_dim 運算元代替顯式 Split,將 shape 為 (B, NxD) 的 2D 嵌入張量重塑為 shape 為 (B, N, D) 的 3D 張量。接下來將一個單一的 LayerNorm 應用到它的最後一維。對 LayerNorm 的結果應用一個 Tanh。最後,用 Remove_middle_dim 運算元將 Tanh 的結果恢復成 2D 張量。

由於 Add_middle_dim 和 Remove_middle_dim 只是重塑張量,並沒有建立額外的副本,所以也可以減少 GPU 記憶體的消耗。


最佳化前:所有輸出被串聯到一個大張量中


進行橫向融合最佳化後

3.3 計算與通訊間的重疊 (overlap)

面向投產的推薦模型的訓練,通常是在分散式 GPU 系統上完成的。由於每個 GPU 的裝置記憶體容量不足以容納模型中的所有 embedding table,因此需要將其分佈在多個 GPU 上。

在訓練步驟中,GPU 需要從其他 GPU 上的 embedding table 中讀取/寫入特徵值。這被稱為 all-to-all 通訊,可能是影響效能的重要原因。

透過 FX 實現一個 transformation,可以將計算與 all-to-all 通訊重疊。圖 5(a) 顯示了一個具備嵌入向量 table 訪問 (EmbeddingAllToAll) 及其他運算元的模型 graph 例項。如圖 5(b) 所示,在沒有任何最佳化的情況下,它們會在一個 GPU 流上順序執行。

使用FX將 EmbeddingAllToAll 分成 EmbeddingAllToAll_Request和EmbeddingAllToAll_Wait,並在它們之間安排獨立的運算元。

圖5:計算與通訊的重疊

3.4 總結


表1:本節討論的最佳化及解決的相應效能瓶頸

為了發現哪些模型會從這些 transformation 中受益,開發人員對 MAIProf 收集的執行在 Meta 資料中心的模型的效能資料進行分析。得出與 eager mode 相比,這些 transformation 在一組生產模型上實現了 2-3 倍的速度提升。

四、結語

從效能角度考量,PyTorch 中的 graph mode 比生產環境中使用的 eager mode 更受歡迎。FX 是一個強大的工具,可以捕捉和最佳化 PyTorch 程式 graph。本文展示了三種 FX transformation,用於最佳化 Meta 內部的生產推薦模型。

最後希望更多 PyTorch 開發者可以使用 graph transformation 來提升模型的效能。

—— 完 ——

相關文章