快速傅立葉變換
快速傅立葉變換(Fast Fourier Transform, FTT)在ACM/OI中最主要的應用是計算多項式乘法。
多項式的係數表示和點值表示
假設\(f(x)\)為\(x\)的\(n\)階多項式,則其可以表示為:
這裡的\(n+1\)個係數\(\{a_0,a_1,\cdots,a_n\}\)就稱為多項式\(f(x)\)的係數表示。
另一方面,我們也可以把\(f(x)\)看成是一個關於\(x\)的函式,我們可以取\(n+1\)個不同的\(x_i\),用\(\{(x_0,f(x_0)),(x_1,f(x_1)),\cdots(x_n,f(x_n))\}\)這\(n+1\)個數值對來唯一確定\(f(x)\),這種表示形式就稱為多項式\(f(x)\)的點值表示。
點值表示與多項式乘法的關係
假設我們現在要求的是\(F(x)=f(x)\cdot g(x)\),如果我們已知\(f(x)\)和\(g(x)\)的點值表示,那麼我們可以非常容易地得到\(F(x)\)的點值表示為
注意這裡的\(n\)實際上要取到\(f(x)\)和\(g(x)\)的階數之和。
現在的關鍵問題是,如何快速將這一點值表示轉換為係數表示。
FFT的實現
為了解決這一問題,我們首先考慮其逆問題,也即:如何從係數表示快速計算點值表示。
FFT
暴力計算\(n\)對點值的總時間複雜度為\(O(n^2)\)。如何優化呢?我們希望我們選擇的\(n\)個\(x_i\)之間存在一定的關係,使得我們可以複用\(x_i^k\)的計算結果。那麼,應該如何選擇呢?
前人的經驗告訴我們,可以選擇單位復根\(\omega_n^i\)。它有三個重要的性質:
利用上述這三個性質,我們可以實現計算過程的簡化。
不妨考慮一個最高階為7階的多項式
可以把奇偶項分別處理
從而
這時把單位復根\(\omega_n^k\)(\(k<n/2\))代入,可以得到
而另一方面,代入\(\omega_n^{k+n/2}\)可以得到
因此,我們只要求得\(\text{DFT}(G(\omega_{n/2}^k))\)和\(\text{DFT}(H(\omega_{n/2}^k))\),就可以同時求得\(\text{DFT}(f(\omega_n^k))\)和\(\text{DFT}(f(\omega_n^{k+n/2}))\),這樣就把問題規模縮小了一半。
使用同樣的方法對\(\text{DFT}(G(\omega_{n/2}^k))\)和\(\text{DFT}(H(\omega_{n/2}^k))\)進行遞迴求解,我們有
可知總的時間複雜度為\(O(n\log n)\)。
在這一過程中,我們預設\(n/2\)總是整數,因此我們需要\(n=2^k\)。所以在計算之前,我們要先對係數補0,使得總的項數變為2的冪次。
逆FFT
將FFT的運算過程看做一個矩陣乘法,逆FFT,也即從點值表示求取係數表示的過程,可以視為左乘逆矩陣。在點值表示的點選取為\(\omega_n^k\)時,FFT矩陣\(\mathbb{A}(\omega_n^k)\)的逆矩陣恰好為\(\frac{1}{n}\mathbb{A}(\omega_n^{-k})\),因此可以複用FFT的計算過程,只需要加上一個標誌變數來表示當前是在進行FFT還是IFFT。
模板題:洛谷 P3803 - 多項式乘法(FFT)
下面給出了本題的遞迴實現。
Code(C++)
#include <cmath>
#include <complex>
#include <iostream>
#define MAXN (1 << 22)
using namespace std;
typedef complex cd;
const cd I{0, 1};
cd tmp[MAXN], a[MAXN], b[MAXN];
void fft(cd *f, int n, int rev) {
if (n == 1) return;
for (int i = 0; i < n; ++i) tmp[i] = f[i];
for (int i = 0; i < n; ++i) {
if (i & 1) f[n / 2 + i / 2] = tmp[i];
else
f[i / 2] = tmp[i];
}
cd *g = f, *h = f + n / 2;
fft(g, n / 2, rev), fft(h, n / 2, rev);
cd omega = exp(I * (2 * M_PI / n * rev)), now = 1;
for (int k = 0; k < n / 2; ++k) {
tmp[k] = g[k] + now * h[k];
tmp[k + n / 2] = g[k] - now * h[k];
now *= omega;
}
for (int i = 0; i < n; ++i) f[i] = tmp[i];
}
int main() {
int n, m;
cin >> n >> m;
int k = 1 << (32 - __builtin_clz(n + m + 1));
for (int i = 0; i <= n; ++i) cin >> a[i];
for (int j = 0; j <= m; ++j) cin >> b[j];
fft(a, k, 1);
fft(b, k, 1);
for (int i = 0; i < k; ++i) a[i] *= b[i];
fft(a, k, -1);
for (int i = 0; i < k; ++i) a[i] /= k;
for (int i = 0; i < n + m + 1; ++i) cout << (int)round(a[i].real()) << " ";
}
上述遞迴方法的常數較大,不能通過洛谷P3803的最後兩個測試點。
為了改寫非遞迴方法,我們引入蝴蝶變換的概念。
蝴蝶變換
繼續使用前面的例子,經過第一步分治,將原來的係數分為兩組:
繼續進行第二步分治,得到四組係數:
最後一步分治,得到八組係數:
所謂蝴蝶變換,指的就是從\({a_0,a_1,\cdots,a_{n-1}}\)這一原始係數序列,變換得到最後一步分治後的係數序列。
觀察後可以發現,在蝴蝶變換的最終結果中,係數下標的二進位制表示恰好是其所在位置二進位制表示的逆序,因此,可以利用這一規律來求取蝴蝶變換的結果。
直接利用規律來計算的複雜度是\(O(n\log n)\),如果從小到大遞推實現,複雜度則為\(O(n)\)。
FFT的非遞迴實現
下面給出了洛谷P3803的非遞迴實現。
Code(C++)
#include <cmath>
#include <complex>
#include <iostream>
#define MAXN (1 << 22)
using namespace std;
typedef complex cd;
const cd I{0, 1};
cd a[MAXN], b[MAXN];
void change(cd *f, int n) {
int i, j, k;
for (int i = 1, j = n / 2; i < n - 1; i++) {
if (i < j) swap(f[i], f[j]);
k = n / 2;
while (j >= k) {
j = j - k;
k = k / 2;
}
if (j < k) j += k;
}
}
void fft(cd *f, int n, int rev) {
change(f, n);
for (int len = 2; len <= n; len <<= 1) {
cd omega = exp(I * (2 * M_PI / len * rev));
for (int j = 0; j < n; j += len) {
cd now = 1;
for (int k = j; k < j + len / 2; ++k) {
cd g = f[k], h = now * f[k + len / 2];
f[k] = g + h, f[k + len / 2] = g - h;
now *= omega;
}
}
}
if (rev == -1)
for (int i = 0; i < n; ++i) f[i] /= n;
}
int main() {
int n, m;
cin >> n >> m;
int k = 1 << (32 - __builtin_clz(n + m + 1));
for (int i = 0; i <= n; ++i) cin >> a[i];
for (int j = 0; j <= m; ++j) cin >> b[j];
fft(a, k, 1);
fft(b, k, 1);
for (int i = 0; i < k; ++i) a[i] *= b[i];
fft(a, k, -1);
for (int i = 0; i < n + m + 1; ++i) cout << (int)round(a[i].real()) << " ";
}
學習資源
Matters Computational
- 第二十一章 快速傅立葉變換
練習題
裸FFT並不可怕,本身FFT的碼量並不算大,背一背也不是多大的事,關鍵是如何看出一道題目是FFT。
SPOJ - ADAMATCH
如果暴力列舉子串,時間複雜度為\(O(|r|^2)\),顯然不行。如何降低複雜度呢?
提示一
首先考慮字母'A'
。不妨把字串為'A'
的位置設為\(1\),其餘位置設為\(0\)。看起來似乎可以進行多項式乘法,但乘法的結果似乎沒有明顯的意義。
提示二
如果把r
串逆序呢?看看此時乘積的每一項有怎樣的含義。
參考程式碼(C++)
#include <cmath>
#include <complex>
#include <cstring>
#include <iostream>
#include <vector>
#define MAXN (1 << 22)
using namespace std;
typedef complex<double> cd;
const cd I{0, 1};
cd a[MAXN], b[MAXN];
void change(cd *f, int n) {
for (int i = 1, j = n / 2; i < n - 1; i++) {
if (i < j) swap(f[i], f[j]);
int k = n / 2;
while (j >= k) {
j = j - k;
k = k / 2;
}
if (j < k) j += k;
}
}
void fft(cd *f, int n, int rev) {
change(f, n);
for (int len = 2; len <= n; len <<= 1) {
cd omega = exp(I * (2 * M_PI / len * rev));
for (int j = 0; j < n; j += len) {
cd now = 1;
for (int k = j; k < j + len / 2; ++k) {
cd g = f[k], h = now * f[k + len / 2];
f[k] = g + h, f[k + len / 2] = g - h;
now *= omega;
}
}
}
if (rev == -1)
for (int i = 0; i < n; ++i) f[i] /= n;
}
int main() {
string s, r;
cin >> s >> r;
int n = s.size(), m = r.size();
int k = 1 << (32 - __builtin_clz(n + m + 1));
vector<int> cnt(k);
for (char c : "ACGT") {
memset(a, 0, sizeof(a));
memset(b, 0, sizeof(b));
for (int i = 0; i < n; ++i) a[i] = s[i] == c;
for (int i = 0; i < m; ++i) b[i] = r[m - i - 1] == c;
fft(a, k, 1);
fft(b, k, 1);
for (int i = 0; i < k; ++i) a[i] *= b[i];
fft(a, k, -1);
for (int i = 0; i < k; ++i) cnt[i] += (int)round(a[i].real());
}
int ans = m;
for (int i = m - 1; i < n; ++i) ans = min(ans, m - cnt[i]);
cout << ans;
}
SPOJ - TSUM
如果暴力列舉,時間複雜度為\(O(n^3)\),顯然不行。如何降低複雜度呢?
提示一
加法可以變為多項式的乘法。
提示二
如何去除包含重複元素的項?
參考程式碼(C++)
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>
#define MAXN 131072
#define OFFSET 20000
using namespace std;
typedef complex<double> cd;
const cd I{0, 1};
void change(vector<cd> &f, int n) {
for (int i = 1, j = n / 2; i < n - 1; i++) {
if (i < j) swap(f[i], f[j]);
int k = n / 2;
while (j >= k) {
j = j - k;
k = k / 2;
}
if (j < k) j += k;
}
}
void fft(vector<cd> &f, int n, int rev) {
change(f, n);
for (int len = 2; len <= n; len <<= 1) {
cd omega = exp(I * (2 * M_PI / len * rev));
for (int j = 0; j < n; j += len) {
cd now = 1;
for (int k = j; k < j + len / 2; ++k) {
cd g = f[k], h = now * f[k + len / 2];
f[k] = g + h, f[k + len / 2] = g - h;
now *= omega;
}
}
}
if (rev == -1)
for (int i = 0; i < n; ++i) f[i] /= n;
}
int main() {
int n;
cin >> n;
vector<cd> a(MAXN), a2(MAXN);
vector<int> a3(MAXN);
for (int i = 0; i < n; ++i) {
int m;
cin >> m;
a[m + OFFSET] = cd{1, 0};
a2[(m + OFFSET) << 1] = cd{1, 0};
a3[(m + OFFSET) * 3] = 1;
}
vector<cd> tot(a), b(a);
fft(tot, MAXN, 1);
fft(b, MAXN, 1);
fft(a2, MAXN, 1);
for (int i = 0; i < MAXN; ++i) tot[i] *= b[i] * b[i], a2[i] *= b[i];
fft(tot, MAXN, -1);
fft(a2, MAXN, -1);
for (int i = 0; i < MAXN; ++i) {
int cnt1 = round(tot[i].real()); // ABC, with permutation
int cnt2 = round(a2[i].real()); // AAB, no permutation
int cnt3 = a3[i]; // AAA
int cnt = (cnt1 - cnt2 * 3 + cnt3 * 2) / 6;
if (cnt > 0) cout << i - OFFSET * 3 << " : " << cnt << endl;
}
}