深度學習運算元優化-FFT

MegEngine發表於2021-08-10

作者:嚴健文 | 曠視 MegEngine 架構師

背景

在數字訊號和數字影像領域, 對頻域的研究是一個重要分支。
我們日常“加工”的影像都是畫素級,被稱為是影像的空域資料。空域資料表徵我們“可讀”的細節。如果我們將同一張影像視為訊號,進行頻譜分析,可以得到影像的頻域資料。 觀察下面這組圖 (來源),頻域圖中的亮點為低頻訊號,代表影像的大部分能量,也就是影像的主體資訊。暗點為高頻訊號,代表影像的邊緣和噪聲。從組圖可以看出,Degraded Goofy 與 Goofy 相比,近似的低頻訊號保留住了 Goofy 的“輪廓”,而其高頻訊號的增加使得背景噪點更加明顯。頻域分析使我們可以瞭解影像的組成,進而做更多的抽象分析和細節處理。

Goofy and Degraded Goofy

實現影像空域和頻域轉換的工具,就是傅立葉變換。由於影像資料在空間上是離散的,我們使用傅立葉變換的離散形式 DFT(Discrete Fourier Transform)及其逆變換 IDFT(Inverse Discrete Fourier Transform)。Cooley-Tuckey 在 DFT 的基礎上,開發了更快的演算法 FFT(Fast Fourier Transform)。

DFT/FFT 在數字影像領域還有一些延伸應用。比如基於 DFT 的 DCT(Discrete Cosine Transform, 離散餘弦變換)就用在了影像壓縮 JPEG 演算法 (來源) 和影像水印演算法(來源)。

JPEG 編碼是通過色彩空間轉換、抽樣分塊、DCT 變換、量化編碼實現的。其中 DCT 變換的使用將影像低頻資訊和高頻資訊區分開,在量化編碼過程中壓縮了少量低頻資訊、大量高頻資訊從而獲得尺寸上壓縮。從貓臉圖上可看出隨著壓縮比增大畫質會變差,但是主體資訊還是得以保留。

貓臉不同 jpeg 畫質(壓縮比)

影像水印演算法通過 DCT 將原圖轉換至頻域,選取合適的位置嵌入水印影像資訊,並通過 IDCT 轉換回原圖。這樣對原影像的改變較小不易察覺,且水印通過操作可以被提取。

DFT/FFT 在深度學習領域也有延伸應用。 比如利用 FFT 可以降低卷積計算量的特點,FFT_Conv 演算法也成為常見的深度學習卷積演算法。本文我們就來探究一下頻域演算法的原理和優化策略。

DFT 的原理及優化

公式

無論是多維的 DFT 運算,還是有基於 DFT 的 DCT/FFT_Conv, 底層的計算單元都是 DFT_1D。 因此,DFT_1D 的優化是整個 FFT 類運算元優化的基礎。
DFT_1D 的計算公式:
\(X_{k}=\sum_{n=0}^{\mathrm{N}-1} x_{n} e^{-j 2 \pi k \frac{n}{N}} \quad k=0, \ldots, N-1\)

其中 \(x_{n}\)為長度為 N 的輸入訊號,\(e^{-j 2 \pi k \frac{n}{N}}\)是 1 的 N 次根, \(X_{k}\)為長度為 N 的輸出訊號。
該公式的矩陣形式為:

\(\left[\begin{array}{c}X(0) \\ X(1) \\ \vdots \\ X(N-1)\end{array}\right]=\left[W_{N}^{n k}\right]\left[\begin{array}{c} \left.x(0\right) \\ x(1) \\ \vdots \\ x(N-1)\end{array}\right]\)

單位復根的性質

DFT_1D 中的\(W_{N}^{nk} = e^{-j 2 \pi k \frac{n}{N}}\)是 1 的單位復根。直觀地看,就是將複平面劃分為 N 份,根據 k * n 的值逆時針掃過複平面的圓周。

單位復根有著週期性和對稱性,我們依據這兩個性質可以對 W 矩陣做大量的簡化,構成 DFT_1D 的快速演算法的基礎。
週期性:\(W_{N}^{k +N}=W_{N}^{k}\)
對稱性:\(W_{N}^{k+N / 2}=-W_{N}^{k}\)

Cooley-Tuckey FFT 演算法

DFT_1D 的多種快速演算法中,使用最頻繁的是 Cooley-Tuckey FFT 演算法。演算法採用用分治的思想,將輸入尺寸為 N 的序列,按照不同的基 radix,分解為 N/radix 個子序列,並對每個子序列再劃分,直到不能再被劃分為止。每一次劃分都可以得到一級 stage,將所有的級自下而上組合在一起,計算得到最後的輸出序列。
這裡以 N = 8, radix=2 為例展示推理過程。
其中\(x(k)\)為 N=8 的序列, \(X^{F}(k)\)為 DFT 輸出序列。
根據 DFT 的計算公式
\(X^{F}(k)=W_{8}^{0} x_{0}+W_{8}^{k} x_{1}+W_{8}^{2 k} x_{2}+W_{8}^{3k} x_{3}+W_{8}^{4k} x_{4} + W_{8}^{5k} x_{5}+W_{8}^{6k} x_{6} +W_{8}^{7k} x_{7}\)

根據奇偶項拆開,分成兩個長度為 4 的序列\(G(k)\), \(H(k)\)

$ X^{F}(k)= W_{8}^{0} x_{0}+W_{8}^{2 k} x_{2}+W_{8}^{4 k} x_{4}+W_{8}^{6 k} x_{6}$

\(+W_{8}^{k}\left(W_{8}^{0} x_{1}+W_{8}^{2 k} x_{3}+W_{8}^{4 k} x_{5}+W_{8}^{6 k} x_{7}\right)\)

\(=G^{F}(k)+W_{8}^{k} H^{F}(k)\)

\(X^{F}(k+4)=W_{8}^{0} x_{0}+W_{8}^{2(k+4)} x_{2}+W_{8}^{4(k+4)} x_{4}+W_{8}^{6(k+4)} x_{6}\)
\(+W_{8}^{(k+4)}\left(W_{8}^{0} x_{1}+W_{8}^{2(k+4)} x_{3}+W_{8}^{4(k+4)} x_{5}+W_{8}^{6(k+4)} x_{7}\right)\)
\(=G^{F}(k)+W_{8}^{k+4} H^{F}(k)\)
\(=G^{F}(k)-W_{8}^{k} H^{F}(k)\)

\(G^{F}(k)\)\(H^{F}(k)\)\(G(k)\)\(H(k)\)的 DFT 結果。\(G^{F}(k)\)\(H^{F}(k)\)乘以對應的旋轉因子\(W_{8}^{k}\),進行簡單的加減運算可以得到輸出\(X^{F}(k)\)
同理,對\(G(k)\)\(H(k)\)也做一樣的迭代,\(A(k)\)\(B(k)\), \(C(k)\), \(D(k)\) 都是 N=2 的序列,用他們的 DFT 結果進行組合運算可以得到\(G^{F}(k)\)\(H^{F}(k)\)

\(\begin{aligned} &G^{F}(k)=A^{F}(k) + W_{4}^{k}B^{F}(k)\\ \end{aligned}\)
\(\begin{aligned} &G^{F}(k+2)=A^{F}(k)-W_{4}^{k}B^{F}(k)\\ \end{aligned}\)
\(\begin{aligned} &H^{F}(k)=C^{F}(k)+W_{4}^{k}D^{F}(k)\\ \end{aligned}\)
\(\begin{aligned} &H^{F}(k+2)=C^{F}(k)-W_{4}^{k}D^{F}(k)\\ \end{aligned}\)

計算 N=2 的序列\(A^{F}(k)\), \(B^{F}(k)\), \(C^{F}(k)\), \(D^{F}(k)\), 因為\(k=0\),旋轉因子\(W_{2}^{0}\)= 1。只要進行加減運算得到結果。
\(\left[\begin{array}{l} A^{F}(0) \\ A^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{0} \\ x_{4} \\ \end{array}\right]\)

\(\left[\begin{array}{l} B^{F}(0) \\ B^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{2} \\ x_{6} \\ \end{array}\right]\)

\(\left[\begin{array}{l} C^{F}(0) \\ C^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{1} \\ x_{5} \\ \end{array}\right]\)

\(\left[\begin{array}{l} D^{F}(0) \\ D^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{3} \\ x_{7} \\ \end{array}\right]\)

用演算法圖形表示,每一層的計算會產生多個蝶形,因此該演算法又被稱為蝶形演算法。
這裡我們要介紹碟形網路的基本組成,對下文的分析有所幫助。

N=8 碟形演算法圖

N=8 的計算序列被分成了 3 級,每一級 (stage) 有一個或多個塊 (section),每個塊中包含了一個或者多個蝶形(butterfly), 蝶形的計算就是 DFT 運算的 kernel。
每一個 stage 的計算順序:

  • 取輸入
  • 乘以轉換因子
  • for section_num, for butterfly_num,執行 radixN_kernel
  • 寫入輸出。

看 N=8 的蝶形演算法圖,stage = 1 時,運算被分成了 4 個 section,每個 section 的 butterfly_num = 1。stage = 2 時,section_num = 2,butterfly_num = 2。 stage = 3 時,section_num = 1,butterfly_num = 4。
可以觀察到,從左到右過程中 section_num 不斷減少,butterfly_num 不斷增加,蝶形群在“變大變密”,然而每一級總的碟形次數是不變的。
實際上,對於長度為 N,radix = r 的演算法,我們可以推得到:

\(\text { Sec_num }=N / r^{S}\)
\(\text { Butterfly_num }= r^{S-1}\)
\(\text { Sec_stride }=r^{S}\)
\(\text { Butterfly_stride }=1\)

S 為當前的 stage,sec/butterfly_stride 是每個 section/butterfly 的間隔。

這個演算法可以將複雜度從 O(n^2) 下降到 O(nlogn),顯得高效而優雅。我們基於蝶形演算法,對於不同的 radix 進行演算法的進一步劃分和優化,主要分為 radix - 2 的冪次的和 radix – 非 2 的冪次兩類。

radix-2 的冪次優化

DFT_1D 的 kernel 即為矩陣形式中的\(W_{N}^{nk}\)矩陣,我們對 radix_2^n 的 kernel 進行分析。

背景裡提到, DFT 公式的矩陣形式為:
\(\left[\begin{array}{c}X(0) \\ X(1) \\ \vdots \\ X(N-1)\end{array}\right]=\left[W_{N}^{n k}\right]\left[\begin{array}{c} \left.x(0\right) \\ x(1) \\ \vdots \\ x(N-1)\end{array}\right]\)
其中\(x(0)\) ~\(x(N-1)\)為乘以旋轉因子\(W_{N}^{kn}\)後的輸入

當 radix = 2 時,由於\(W_{2}^1\)  = -1, \(W_{2}^2\) = 1, radix_2 的 DFT 矩陣形式可以寫為:

\(\left[\begin{array}{c}\mathrm{X}_{\mathrm{k}} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 2}\end{array}\right]\) \(=\left[\begin{array}{cc}1 & 1 \\ 1 & -1\end{array}\right]\left[\begin{array}{l}\mathrm{W}_{\mathrm{N}}^{0} \mathrm{A}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{\mathrm{k}} \mathrm{B}_{\mathrm{k}}\end{array}\right]\)

當 radix = 4 時,由於\(W_{4}^1\) = -j, \(W_{4}^2\) = -1, \(W_{4}^3\) = j, \(W_{4}^4\)= 1,radix_4 的 DFT 矩陣形式可以寫為:

\(\left[\begin{array}{c}\mathrm{X}_{\mathrm{k}} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 4} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 2} \\ \mathrm{X}_{\mathrm{k}+3 \mathrm{~N} / 4}\end{array}\right]=\left[\begin{array}{cccc}1 & 1 & 1 & 1 \\ 1 & -\mathrm{j} & -1 & \mathrm{j} \\ 1 & -1 & 1 & -1 \\ 1 & \mathrm{j} & -1 & -\mathrm{j}\end{array}\right]\left[\begin{array}{c}\mathrm{W}_{\mathrm{N}}^{0} \mathrm{A}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{\mathrm{k}} \mathrm{B}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{2 \mathrm{k}} \mathrm{C}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{3 \mathrm{k}} \mathrm{D}_{\mathrm{k}}\end{array}\right]\)

同理推得到 radix_8 的 kernel 為:

\(\left[\begin{array}{cccccccc}1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & \mathrm{~W}_{8}^{1} & -j & \mathrm{~W}_{8}^{3} & -1 & -\mathrm{W}_{8}^{1} & j & -\mathrm{W}_{8}^{3} \\ 1 & -j & -1 & j & 1 & -j & -1 & j \\ 1 & \mathrm{~W}_{8}^{3} & j & \mathrm{~W}_{8}^{1} & -1 & -\mathrm{W}_{8}^{3} & -j & -\mathrm{W}_{8}^{1} \\ 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 \\ 1 & -\mathrm{W}_{8}^{1} & -j & -\mathrm{W}_{8}^{3} & -1 & \mathrm{~W}_{8}^{1} & j & \mathrm{~W}_{8}^{3} \\ 1 & j & -1 & -j & 1 & j & -1 & -j \\ 1 & -\mathrm{W}_{8}^{3} & j & -\mathrm{W}_{8}^{1} & -1 & \mathrm{~W}_{8}^{3} & -j & \mathrm{~W}_{8}^{1}\end{array}\right]\)

我們先來看訪存,現代處理器對於計算效能的優化要優於對於訪存的優化,在計算和訪存相近的場景下, 訪存通常是效能瓶頸。

DFT1D 中,對於不同基底的演算法 r-2/r-4/r-8, 每一個 stage 有著相等的存取量:2 * butterfly_num * radix = 2N,  而不同的基底對應的 stage 數有著明顯差異(\(\log_2N\) vs \(\log_4N\) vs \(\log_8N\))。

因此對於 DFT, 在不顯著增加計算量的條件下, 選用較大的 kernel 會在訪存上取得明顯的優勢。觀察推導的 kernel 圖, r-2 的 kernel 每個蝶形對應 4 次訪存操作和,2 次複數浮點加減運算。r-4 的 kernel 每個蝶形演算法對應 8 次 load/store、8 次複數浮點加減操作(合併相同的運算),在計算量略增加的同時 stage 由 \(\log_2N\) 下降到 \(\log_4N\) , 降低了總訪存的次數, 因此會有效能的提升。r-8 的 kernel 每個蝶形對應 16 次 load/store、24 次複數浮點加法和 8 次浮點乘法。浮點乘法的存在使得計算代價有所上升, stage 由 \(\log_4N\) 進一步下降到 \(\log_8N\) ,但由於 N 日常並不會太大, r-4 到 r-8 的 stage 減少不算明顯,所以優化有限

我們再來看計算的開銷。減少計算的開銷通常有兩種辦法:減少多餘的運算、並行化。

以 r-4 演算法為例,kernel 部分的計算為:

  • radix_4_first_stage(src, dst, sec_num, butterfly_num)
  • radix_4_other_stage(src, dst, sec_num, butterfly_num)
    • for Sec_num
      • for butterfly_num
        • raidx_4_kernel

radix4_first_stage 的資料由於 k=0, 旋轉因子都為 1,可以省去這部分複數乘法運算,單獨優化。 radix4_other_stage 部分, 從第 2 個 stage 往後, butterfly_num = 4^(s-1) 都為 4 的倍數,而每個 butterfly 陣列讀取/儲存都是間隔的。可以對最裡層的迴圈做迴圈展開加向量化,實現 4 個或更多 butterfly 並行運算。迴圈展開和 SIMD 指令的使用不僅可以提高並行性, 也可以提升 cacheline 利用的效率,可以帶來較大的效能提升。 以 SM8150(armv8) 為例,r-4 的並行優化可以達到 r2 的 1.6x 的效能。

尺寸:1 * 2048(r2c) 環境:SM8150 大核

總之,對於 radix-2^n 的優化,選用合適的 radix 以減少多 stage 帶來的訪存開銷,並且利用單位復根性質以及並行化降低計算的開銷,可以帶來較大的效能提升。

radix-非 2 的冪次優化

當輸入長度 N = radix1^m1 * radix2^m2... 且 radix 都不為 2 的冪次時,如果使用 naive 的 O(n^2) 演算法, 效能就會急劇下降。 常見的解決辦法對原長補 0、使用 radix_N 演算法、特殊的 radix_N 演算法 (chirp-z transform)。補 0 至 2 的冪次方法對於大尺寸的輸入要增加很多運算量和儲存量, 而 chirp-z transform 是用卷積計算 DFT, 演算法過於複雜。因此對非 2 的冪次 radix-N 的優化也是必要的。

radix-N 計算流程和 radix-2 冪次一樣,我們同樣可以利用單位復根的週期性和對稱性,對 kernel 進行計算的簡化。 以 radix-5 為例,radix-5 的 DFT_kernel 為:
\(\left[\begin{array}{cccc} 1&1&1&1&1\\ 1 &\mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{-1} \\ 1 &\mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{-2} \\ 1 &\mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{2} \\ 1 &\mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{1} \\ \end{array}\right]\)

\(W_5^k\)\(W_{5}^{-k}\)在複平面上根據 x 軸對稱,有相同的實部和相反的虛部。根據這個性質。如下圖所示,對於每一個 stage,可以合併公共項 A,B,C,D,再根據公共項計算出該 stage 的輸出。

\(\begin{array}{l} A=\left(x_{1}+x_{4}\right) * W_{5}^{1} \cdot r+\left(x_{2}+x_{3}\right) * W_{5}^{2} \cdot r\\\end{array}\)

$B=(-j) * \left[\left(x_{1}-x_{4}\right) * W_{5}^{1} \cdot i+\left(x_{2}-x_{3}\right) * W_{5}^{2} \cdot i\right] $

$C=\left(x_{1}+x_{4}\right) * W_{5}^{2} \cdot r+\left(x_{2}+x_{3}\right) * W_{5}^{1} \cdot r$

$D=j * \left[\left(x_{1}-x_{4}\right) * W_{5}^{2} \cdot i-\left(x_{2}-x_{3}\right) * W_{5}^{1} \cdot i\right] $

\(\begin{array}{l} X(k)=x_{0}+\left(x_{1}+x_{4}\right)+\left(x_{2}+x_{3}\right)\\ \end{array}\)
\(\begin{array}{l} X(k+N/5)=x_{0}+\mathrm{A}-\mathrm{B}\\ X(k+2N/5)=x_{0}+\mathrm{C}+\mathrm{D}\\ X(k+3N/5)=x_{0}+C-D\\ X(k+4N/5)=x_{0}+\mathrm{A}+\mathrm{B}\\ \end{array}\)

這種演算法減少了很多重複的運算。同時,在 stage>=2 的時候,同樣對 butterfly 做迴圈展開加並行化,進一步減少計算的開銷。
radix-5 的優化思想可以外推至 radix-N。對於 radix_N 的每一個 stage, 計算流程為:

  • 取輸入
  • 乘以對應的轉換因子
  • 計算公共項, radix_N 有 N-1 個公共項
  • 執行並行化的 radix_N_kernel
  • 寫入輸出

其他優化

上述兩個章節描述的是 DFT_1D 的通用優化,在此基礎上還可以做更細緻的優化,可以參考本文引用的論文。

  • 對於全實數輸入的, 由於輸入的虛部為 0, 進行旋轉因子以及 radix_N_kernel 的複數運算時會有多餘的運算和多餘的儲存, 可以利用 split r2c 演算法, 視為長度為 N/2 的複數序列, 計算 DFT 結果並進行 split 操作得到 N 長實數序列的結果。
  • 對於 radix-2 的冪次演算法, 重新計算每個 stage 的輸入/輸出 stride 以取消第一級的位元翻轉可以進一步減少訪存的開銷。
  • 對於 radix-N 演算法, 在混合基框架下 N = radix1^m1 * radix2^m2, 合併較小的 radix 為大的 radix 以減少 stage。

DFT 延展演算法的原理及優化

DCT 和 FFT_conv 兩個典型的基於 DFT 延展的演算法,DFT_1D/2D 的優化可以很好的用在這類演算法中。

DCT

DCT 演算法(Discrete Cosine Transform, 離散餘弦變換)可以看作是 DFT 取其正弦分量並經過工業校正的演算法。DFT_1D 的計算公式為:

\(\begin{aligned} X[k] &=C(k) \sum_{n=0}^{N-1} x[n] \cos \left(\frac{(2 n+1) \pi k}{2 N}\right) \\ &C(k)=\sqrt{\frac{1}{n}} \\&k=1 \\ &C(k)=\sqrt{\frac{2}{n}} \\&k!=1 \\ \end{aligned}\)

該演算法 naive 實現是 O(n^2) 的,而我們將其轉換成 DFT_1D 演算法,可以將演算法複雜度降至 O(nlogn)。
基於 DFT 的 DCT 演算法流程為:

  • 對於 DCT 的輸入序列 x[n], 建立長為 2N 的輸入序列 y[n] 滿足 y[n] = x[n] + x[2N-n-1], 即做一個映象對稱。
  • 對輸入序列 y[n] 進行 DFT 運算,得到輸出序列 Y[K]。
  • 由 Y[K] 計算得到原輸入序列的輸出 X[K] 。

我們嘗試推導一下這個演算法:

$ {l}
y[n]=x[n]+x [2 N-1-n] $

\({l} Y[k]=\sum_{n=0}^{N-1} x[n]\cdot e^{-j \frac{2 \pi k n}{2 N}} +\sum_{n=N}^{2 N-1} x[2 N-1-n] \cdot e^{-j \frac{2 \pi k n}{2 N}}\)

\(=\sum_{n=0}^{N-1} x[n]\cdot e^{-j \frac{2 \pi k n}{2 N}} +\sum_{n=0}^{N-1} x[n] \cdot e^{-j \frac{2 \pi k (2N-1-n)}{2 N}}\)
\(=e^{-j \frac{2 \pi k }{2 N}} \cdot \sum_{n=0}^{N-1} x[n] (e^{-j \frac{2\pi}{2 N} kn} \cdot e^{-j \frac{\pi}{2 N}k}+e^{j \frac{2\pi}{2 N} kn} \cdot e^{j \frac{\pi}{2 N}k})\)
\(=e^{-j \frac{2 \pi k }{2 N}} \cdot \sum_{n=0}^{N-1} x[n] \cdot 2\cdot\cos(\frac{2n+1}{2N} k\pi)\)
\(=e^{-j \frac{2 \pi k }{2 N}} \cdot C(u) \cdot X[k]\)

對 y[n] 依照 DFT 公式展開,整理展開的兩項並提取公共項\(e^{-j \frac{2 \pi k }{2 N}}\), 根據尤拉公式和誘導函式,整理非公共項\((e^{-j \frac{2\pi}{2 N} kn} \cdot e^{-j \frac{\pi}{2 N}k}+e^{j \frac{2\pi}{2 N} kn} \cdot e^{j \frac{\pi}{2 N}k})\)。可以看出得到的結果正是 x[k] 和與 k 有關的係數的乘積。這樣就可以通過先計算\(Y[k]\)得到 x[n] 的 DCT 輸出\(X[k]\)

在理解演算法的基礎上,我們對 DFT_1D 的優化可以完整地應用到 DCT 上。DCT_2D 的計算過程是依次對行、列做 DCT_1D, 我們用多執行緒對 DCT_1D 進行並行,可以進一步優化演算法。

FFT_conv

Conv 是深度學習最常見的運算,計算 conv 常用的方法有 IMG2COL+GEMM, Winograd, FFT_conv。三種演算法都有各自的使用場景。

FFT_conv 的數學原理是時域中的迴圈卷積對應於其離散傅立葉變換的乘積。如下圖所示, f 和 g 的卷積等同於將 f 和 g 各自做傅立葉變幻 F,進行點乘並通過傅立葉逆變換計算後的結果。
\(f \underset{\text { Circ }}{*} g=\mathcal{F}^{-1}(\mathcal{F}(f) \cdot \mathcal{F}(g))\)

直觀的理論證明可下圖(來源)。

將卷積公式和離散傅立葉變換展開, 改變積分的順序並且替換變數, 可以證明結論。
注意這裡的卷積是迴圈卷積, 和我們深度學習中常用的線性卷積是有區別的。 利用迴圈卷積計算線性卷積的條件為迴圈卷積長度 L⩾| f |+| g |−1。 因此我們要對 Feature Map 和 Kernel 做 zero-padding,並從最終結果中取有效的線性計算結果。

FFT_conv 演算法的流程:

  • 將 Feature Map 和 Kernel 都 zero-pad 到同一個尺寸,進行 DFT 轉換。
  • 矩陣點乘
  • 將計算結果通過 IDFT 計算出結果。

該演算法將卷積轉換成點乘, 演算法複雜度是 O(nlogn), 小於卷積的 O(n^2), 在輸入的尺寸比較大時可以減少運算量,適用於大 kernel 的 conv 演算法。

深度學習計算中, Kernel 的尺寸要遠小於 Feature Map, 因此 FFT_conv 第一步的 zero-padding 會有很大的開銷,參考論文 2 裡提到可以通過對 Feature map 進行分塊,分塊後的 Feature Map 和 Kernel 需要 padding 到的尺寸較小,可以大幅減小這一部分的開銷。 優化後 fft_conv 的計算流程為:

  • 合理安排快取計算出合適的 tile 尺寸,對原圖進行分塊
  • 分塊後的小圖和 kernel 進行 zero-padding, 並進行 DFT 運算
  • 小圖矩陣點乘
  • 進行逆運算並組合成大圖。

同時我們可以觀察到,FFT_conv 的核心計算模組還是針對小圖的 DFT 運算, 因此我們可以將前一章節對 DFT 的優化代入此處,輔以多執行緒,進一步提升 FFT_Conv 的計算效率。

參考資料

  1. 陳暾,李志豪,賈海鵬,張雲泉。基於 ARMV8 平臺的多維 FFT 實現與優化研究
  2. Qinglin Wang,Dongsheng Li. Optimizing FFT-Based Convolutionon ARMv8 Multi-core CPUs
  3. Aleksandar Zlateski, Zhen Jia, Kai Li, Fredo Durand. FFT Convolutions are Faster than Winograd onModern CPUs, Here’s Why

相關文章