JIT in MegEngine

MegEngine發表於2021-08-06

作者:王彪 | 曠視框架部異構計算組工程師

一、背景

什麼是天元

曠視天元(MegEngine)是一個深度學習框架,它主要包含訓練和推理兩方面內容。訓練側一般使用 Python 搭建網路;而推理側考慮到產品效能的因素,一般使用 C++ 語言整合天元框架。無論在訓練側還是推理側,天元都擔負著將訓練和推理的程式碼執行到各種計算後端上的任務。目前天元支援的計算後端有 CPU、GPU、ARM 和一些領域專用的加速器,覆蓋了雲、端、芯等各個場景。

天元主要有三大特徵:

1.訓推一體,不管是訓練任務還是推理任務都可以由天元一個框架來完成。

2.動靜結合,天元同時支援動態圖和靜態圖,並且動靜之間的轉換也非常方便。

3.多平臺的高效能支援。

圖 1. 天元架構

如圖 1 所示,我們可以看到天元提供了 Python 和 C++ 兩種介面。在圖表示上分為動態圖和靜態圖。運算層元件包括自動求導器、圖優化和圖編譯等。天元的執行時模組包括記憶體管理和計算排程,其中記憶體管理包括靜態記憶體管理和動態記憶體管理,以及亞線性記憶體優化技術。計算核心層包含了天元支援的所有計算後端,我們後續會開源出更多的計算後端。除此之外,天元還包含了一個高效能異構通訊庫,它一般會在多機多卡的場景下被用到。

圖 2. 計算圖

動態圖和靜態圖是相對的,在動態圖下是沒有計算圖的概念的。但在靜態圖下,天元會維護一張計算圖。如圖 2 所示為天元中的計算圖表示,圖中圓形表示運算元(operator),三角形表示輸入。在天元框架中,動態圖和靜態圖之間的轉換隻需要一條簡單的語句即可完成,如下方程式碼所示:

if __name__ == '__main__’: 
    gm = ad.GradManager().attach(model.parameters()) 
    opt = optim.SGD(model.parameters(), lr=0.0125, momentum=0.9, weight_decay=1e-4) 
   <em> # 通過 trace 轉換為靜態圖</em>
    @trace(symbolic=True) 
    def train(): 
        with gm: 
            logits = model(image) 
            loss = F.loss.cross_entropy(logits, label) 
            gm.backward(loss) 
        opt.step() 
        opt.clear_grad() 
        return loss 
    loss = train() 
    loss.numpy()

什麼是 AOT 和 JIT

AOT(Ahead Of Time) 和 JIT(Just In Time) 都是編譯中的概念。以傳統的 C/C++ 語言為例,我們寫完程式碼之後,一般會通過編譯器編譯生成可執行檔案,然後再執行該可執行檔案獲得執行結果。如果我們將從原始碼編譯生成可執行檔案的過程稱為 build 階段,將執行可執行檔案叫做 runtime 階段的話,JIT 是沒有build 階段的,它只有 runtime 階段。JIT 一般被用在解釋執行的語言如 Python 中,JIT 會在程式碼執行的過程中檢測熱點函式,隨後對熱點函式進行重編譯,下次執行時遇到熱點函式則直接執行編譯結果即可。這樣做可以顯著加快程式碼執行的速度。

什麼是 MLIR

隨著各種程式語言的出現,現代編譯器也日趨多樣化。特別是近年來隨著深度學習的興起,深度學習軟體框架和 AI 領域專用硬體呈爆發式增長。不斷增加的軟體框架和 AI 硬體之間逐漸形成了一個越來越大的溝壑,如何將框架層對深度學習模型的描述精準高效的翻譯成適應各類硬體的語言成為難點。MLIR(Multi-Level Intermediate Representation) 是一種可以在統一的基礎架構下滿足多樣化需求的混合 IR。MLIR 可以滿足包括但不限於以下的需求:

1.表達資料流圖(如靜態圖模式下的 MegEngine Graph)

2.表達對該圖做的優化和變換操作

3.進行各種運算元優化如運算元融合(kernel fusion)、迴圈融合、運算元分塊和記憶體格式(memory layout)轉換等

4.自動程式碼生成、顯式快取管理、自動向量化

作為一個公用的 IR,MLIR 具有非常優秀的表達能力和可擴充套件性。MLIR 可以表達圖層面的運算,同時可以表達傳統編譯器中的 IR 資訊,也可以表示硬體專用的運算。這種不同屬性,不同型別的運算的集合構成了 MLIR 中的方言(Dialect)。MLIR 還提供方便的機制實現不同方言之間的轉換(Lowering Down),因此 MLIR 的一個通用優化將會在多個方面產生收益。接入 MLIR 也將有更大可能享受到它的生態好處,包括效能和擴充套件性等方面。

二、動機

為什麼做

眾所周知,深度學習模型中有很多 element-wise 操作,例如加減乘除算術運算和神經網路中的啟用函式一般都是 element-wise 操作。天元將 element-wise 操作分為一元操作、二元操作和多元操作。一元操作主要有 RELU、ABS、SIN 和 COS 等等;二元操作有加法、減法、乘法和除法以及 MAX 等;多元操作有 FUSE-MUL-ADD3 和 FUSE-MUL-ADD4 等,它們分別計算的是 “ab+c” 以及 “ab+c*d”。

表 1 卷積神經網路中的 element-wise 操作

element-wise 操作在卷積神經網路中所佔的地位不可忽視。如表 1 所示,我們選擇公開的卷積神經網路訓練模型,以純 device kernel 的執行時間為基準統計卷積神經網路中element-wise 操作的重要性。

首先可以清晰的看到,element-wise 的計算量的佔比相比於執行時間佔比要低 1-2 個數量級。它的計算量佔的非常少,但是它的執行時間佔比非常多,這個結論是比較反直覺的。並且隨著 batch size 的增加,這個現象也越來越明顯。這是因為 element-wise 操作計算量較低但是訪存量較高,即計算訪存比較低,是一種典型的訪存受限 (memory bound) 的操作。以 “a+b” 為例,我們首先要將 a 讀到記憶體中,再將 b 讀到記憶體中,做完一次加法之後,我們將結果 c 再寫到記憶體中。整個過程要經過兩次讀和一次寫才能完成一次計算,所以它的計算反應訪存比非常低。針對訪存受限的操作,優化計算時間實際上是沒有沒有太多的意義的,而應該集中精力優化訪存,訪存優化的常見的優化手段是融合 (fusion)。如果我們能將網路中連在一起的 element-wise 操作融合成一個運算元,則將減少 element-wise 操作的訪存量,增加計算訪存比從而加速網路的整體效能。

為什麼用 JIT 做

卷積神經網路有兩個鮮明的特徵。一個是靜態圖模式下的模型訓練過程中模型的結構一般是不會變的跑;另一個是在模型訓練的過程中,一般會經過很多個 iter/min-batch,不同的 iter/min-batch 之間輸入張量形狀(tensor shape)一般也不會變。基於卷積神經網路的這兩個特徵,我們決定應用 JIT 技術,原因如下:

1.只需要在首次執行的時候編譯一次,隨後的不同 iter/mini-batch 可以重用第一次編譯出來的結果。

2.JIT 具有較強的可移植性,因為它在執行時獲取平臺資訊,然後生成可以在該平臺執行的程式碼。

3.JIT 可以解決 element-wise 模式組合爆炸的問題。

三、技術方案

我們通過 Element-wise Fusion 可以把多個 element-wise 操作融合成一個,減少了運算元數量也就減少了運算元之間的讀寫次數。如圖 3 所示計算圖算的是 “a*b+c”,它需要 4 次讀,2 次寫。4 次讀分別是乘法在讀 a 和 b 兩個輸入,乘法其實還要寫一個隱藏的輸出,加法會讀乘法的輸出作為輸入,以及加法讀 c 作為輸入。兩次寫分別是乘法和加法對它們結果的兩次寫操作,總共加起來是 4 次讀,2 次寫。

我們將其融合成一個運算元 FUSE_MUL_ADD3,由於天元現在已經支援 FUSE_MUL_ADD3 這個 element-wise 模式,所以我們可以直接做模型手術將計算圖從圖 3 左側形式轉到圖 3 右側形式。對於融合之後的計算圖,我們只需要 3 次讀和 1 次寫就可以完成等價計算,相比於融合前減少了 1 次讀和 1 次寫操作。

圖3 融合優化減少訪存次數

我們無法預測使用者將搭出來怎樣的一張計算圖,考慮圖 4 所示的計算圖,其中 element-wise 的個數和順序都不固定,顯然我們不可能提前將各種 element-wise 模式的組合都寫進天元的。在這種情況下,天元會建立一個虛擬的運算元來表示整個可被融合的子圖。有了虛擬運算元的存在,接下來我們還要解決兩個問題,一個是用虛擬運算元替換原始計算圖中可以被融合的子圖,這個工作會在圖優化階段做;另一個是我們要動態生成虛擬運算元的程式碼並執行。如果我們解決了這兩個問題,我們就解決了整個問題。

圖4 子圖融合優化

圖優化

為了將一張計算圖中的可被融合的子圖融合成一個運算元,天元將進行檢測(detection)和融合(fusion)兩步操作,如下步驟 1-3 屬於檢測,步驟 4 則屬於融合:

1.對原始計算圖進行檢測後生成 internal graph generator,一個 internal graph generator 對應一個唯一的子圖

2.internal graph generator 稍後會生成 internal graph

3.由 internal graph 建立 JITExcutor 運算元

4.將 JITExcutor 寫回原始的計算圖

檢測

檢測演算法的主要功能是找出可以被融合的子圖。為了方便描述,設 G 是計算圖,opr 是圖 G 中的運算元,var 是 opr 的輸入和輸出。檢測演算法的輸入是原始的計算圖 G,輸出是一個雜湊表 M,表中存放的是檢測出的可被融合子圖的輸出 var(記作 endpoint)與其對應的 internal graph generator。演算法步驟如下:

1.按照逆拓撲序列遍歷圖 G 中的運算元 opr

2.如果 opr 不是 Elemwise/PowC/TypeCvt/Reduce/Dimshuffle/JITExecutor,返回步驟1

3.如果 opr 的 input/output 資料型別不是 float32/float16,返回步驟1

4.process_opr(opr)

5.轉到步驟 1

圖5 process_opr 流程圖

拓撲序列要求所有的父節點要先於它的子節點被訪問到,與之對應的,逆拓撲序列就是所有的子節點要先於它的父節點被訪問到。演算法第 1 步中我們之所以按照逆拓撲序列遍歷計算圖,是因為要保證遍歷到某個 opr 時,它的子節點都已經被遍歷到了。這樣演算法可以檢視該 opr 的所有的子節點是不是都在同一張子圖中,如果是,那麼當前 opr 就有很大的可能也在該子圖中。演算法的第 2 步和第 3 步實際上說明了天元中的 JIT 的限制。目前天元 JIT 僅支援 Elemwise/PowC/TypeCvt/Reduce/Dimshuffle 這幾種 opr,而且只支援輸入輸出是 float32/float16 的資料型別。第 4 步詳細流程如圖 5 所示。需要注意的是演算法會經過如下三個判斷語句:

1.該 opr 的子節點是不是都已經在當前的這張子圖中了?

2.該 opr 的輸出的計算節點(compute node)是不是跟子圖匹配?天元支援跨計算節點的計算圖,例如計算圖中一些 opr 可以執行在 CPU 上,一些 opr 可以執行在 GPU上。但目前天元不支援跨計算節點融合。

3.該 opr 的輸出的 shape 是不是跟子圖匹配?因為最終生成的程式碼本質上是一個大的迴圈,迴圈的維度就是 opr 輸出的 shape,所以如果 shape 不匹配是不能被融合的。

圖 6 檢測演算法檢測出的可被融合的子圖

圖 6 中虛線框出來的即為檢測演算法檢測出的兩個可被融合的子圖。

融合

融合演算法的主要功能是將檢測出來的子圖融合成一個運算元。融合演算法的輸入是原始的計算圖和檢測演算法輸出的那張雜湊表 M,它的輸出是經過融合的計算圖 G‘。演算法流程如下:

1.按照拓撲序列遍歷圖 G 中的運算元 opr

2.若 opr 的輸入 var 不是 endpoint, 返回步驟 1

3.從 M 中拿到 var 對應的 internal graph generator, 生成 internal graph

4.從 internal graph 建立 JITExecutor

5.寫回原始的計算圖 G

6.轉到步驟 1 步驟 2 中如果一個 opr 的輸入 var 不是 endpoint 則表示它是一個子圖中的中間節點而不是子圖的輸出節點。步驟 3 中從 internal graph generator 到 internal graph 需要將子圖的輸入 var 替換為一個新的 opr JITPlaceholder。JITPlaceholder 中會存諸如子圖的輸入順序這些額外資訊,因為某些 element-wise 操作是對輸入順序敏感的。例如 a 對 b 取餘和 b 對 a 取餘顯然具有不同的語義。

圖 7 融合後的計算圖

圖 7 即為經過融合演算法之後的計算圖,截止到目前為止,我們已經完成了圖優化方面的所有工作。

圖編譯

經過圖優化之後,我們成功的將計算圖中可被融合的子圖融合成為一個新的運算元,剩下的工作就是為這個新的運算元生成程式碼了。JITExecutor 運算元的執行時程式碼非常簡單,先判斷一下當前的可執行物件是不是已經存在,如果不存在則先編譯出一個可執行物件,如已存在則直接執行。這段程式碼在執行時才會被執行到,所以稱之為 JIT。當前天元支援三種 JIT 編譯器後端,分別是 NVRTC(支援英偉達 GPU),Halide 和 MLIR。其中後兩個編譯後端支援的平臺眾多,但是 MLIR 具有更優秀的表達能力和擴充套件性,所以我們接下來以 MLIR 為例介紹程式碼生成、編譯和執行的過程。

要想使用 MLIR 作為編譯後端,首先我們需要定義和實現天元自己的方言(MGE Dialect),隨後我們將 MGE Dialect 轉換到 MLIR 既有的 Dialect 上,接下來的絕大部分工作都可以複用 MLIR 中的基礎元件和工具完成。圖 8 描述了 CPU 和 GPU 上大概的執行流程。

圖 8 JIT 編譯器工作流

天元首先將 JITExecutor 運算元內部的 internal graph 翻譯成 MGE Dialect。在 CPU 上,MGE Dialect 會先 Lowering 到 Affine Dialect 上,然後會通過 LLVM 的元件 Lowering 到 LLVM Dialect 上,LLVM Dialect 可以被直接翻譯成 LLVM IR。在這一步之後,其他優化工作都可以直接複用 LLVM 的基礎元件。最後天元使用 MLIR ExecutionEngine 執行 LLVM IR 生成的程式碼。在 GPU 上,天元會先將 MGE Dialect Lowering 到 GPU Dialect上,隨後 Lowering 到 NVVM Dialect,NVVM 會被翻譯成 PTX 彙編。最後通過英偉達提供的 CUmodule 和 CUfunction 兩個機制執行。

四、實驗和分析

首先參考這篇文件 在天元中開啟 JIT 支援。本次實驗選了 resnet50, mobilenetV2 和 vgg16 三個業界廣泛使用的模型,batch size 分別設定了 1, 8 和 16。測試硬體環境為 NVIDIA T4,軟體環境為 MegEngine v1.2.0。

圖 9 開啟 JIT 相比於不開 JIT 的加速比

由圖 9 可知,和不開啟 JIT 支援相比,開啟 JIT 支援後 resnet50 最高可以獲得 16% 的加速比,mobilenet V2 則能獲得 6% 到 7% 的加速比,而 vgg16 其實上沒有明顯加速效果。這是因為 vgg16 模型很大,可以被優化的 element-wise 操作比較少。JIT 的優化效果跟具體的模型是有緊密關係的。

圖 10 JIT 編譯耗時

如果開啟了 JIT 支援,那麼天元首次執行的時候會有一次 JIT 編譯的過程。JIT 編譯耗時跟具體的編譯的後端以及模型有關,如圖 10 所示 resnet50 耗時 2.7 毫秒,mobilenetV2 耗時 3.9 毫秒。

五、總結和展望

本篇文章介紹了天元使用 JIT 實現將任意多個相鄰的 element-wise 運算元融合成一個運算元的優化。我們在 T4 上用 MegEngine v1.2.0 實驗,相比於優化前,resnet 50 最高可以獲得 16% 的加速比。

以此為基,展望未來我們可能做的事情如下:

1.將 JIT 編譯的結果先離線儲存,線上直接將線下編譯好的可執行物件讀進記憶體。這種做法可以解決線上第一次執行慢的問題,但它可能會損失一部分可移植性,因為在一種裝置上編譯產生的可執行物件一般不能適配所有線上裝置。 2.JIT 支援更多的運算元。 3.JIT支援更多的資料型別,天元 JIT 優化暫時只支援 float32/float16 這兩種資料型別。 4.動態圖 JIT,也就是傳統意義上的檢測熱點程式碼,重編譯後再執行。

相關文章