快速傅立葉變換(FFT)隨筆

Roor發表於2021-12-02

終於學會了FFT,水一篇隨筆記錄一下

前置知識網上一大堆,這裡就不多贅述了,直接切入正題

 

01 介紹FFT

這裡僅指出FFT在競賽中的一般應用,即優化多項式乘法

一般情況下,計算兩個規模為$n$的多項式相乘的結果,複雜度為$O(n^2)$,但是神奇的FFT可以將其優化至$O(nlogn)$

FFT的過程一般為:

多項式的係數表示$\longrightarrow$多項式的點值表示$\longrightarrow$多項式的係數表示

網上對每一步的叫法都有一定出入,這裡稱第一步變換為快速傅立葉變換,第二步為快速傅立葉逆變換

 

02快速傅立葉變換

先指出,接下來的每個$n$都是$2$的整數次冪

首先我們有一個已知係數表達的$n$項的多項式

$A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$

要確定其的點值表達$(y_0,y_1,y_2,\dots,y_{n-1})$,樸素的做法就是取$n$個不同值代進去,這麼做顯然是$O(n^2)$

下面介紹快速傅立葉變換的做法

首先將多項式按照奇偶分類

$A(x)=(a_0+a_2x^2+\dots+a_{n-2}x^{n-2})+(a_1x+a_3x^3+\dots+a_{n-1}x^{n-1})$

$A(x)=(a_0+a_2x^2+\dots+a_{n-2}x^{n-2})+x\cdot(a_1+a_3x^2+\dots+a_{n-1}x^{n-2})$

$A_1(x)=a_0+a_2x+\dots+a_{n-2}x^{\tfrac{n-2}{2}}$

$A_2(x)=a_1+a_3x+\dots+a_{n-1}x^{\tfrac{n-2}{2}}$

 不難發現

$A(x)=A_1(x^2)+xA_2(x^2)$

令$k<\frac{n}{2}$

將$\omega_{n}^k$代入得

$A(\omega_{n}^k)=A_1(\omega_{n}^{2k})+\omega_{n}^{k}A_2(\omega_{n}^{2k})$

$A(\omega_{n}^k)=A_1(\omega_{\tfrac{n}{2}}^{k})+\omega_{n}^{k}A_2(\omega_{\tfrac{n}{2}}^{k})$

將$\omega_{n}^{k+\frac{n}{2}}$代入得

$A(\omega_{n}^{k}+\tfrac{n}{2})=A_1(\omega_{n}^{2k+n})+\omega_{n}^{k+\tfrac{n}{2}}A_2(\omega_{n}^{2k+n})$

$A(\omega_{n}^{k}+\tfrac{n}{2})=A_1(\omega_{n}^{2k}\cdot\omega_{n}^{n})-\omega_{n}^{k}A_2(\omega_{n}^{2k}\cdot\omega_{n}^{n})$

$A(\omega_{n}^{k}+\tfrac{n}{2})=A_1(\omega_{n}^{2k})-\omega_{n}^{k}A_2(\omega_{n}^{2k})$

$A(\omega_{n}^k)=A_1(\omega_{\tfrac{n}{2}}^{k})-\omega_{n}^{k}A_2(\omega_{\tfrac{n}{2}}^{k})$

顯然的,這兩個式子只有常數項不同

當$k$取遍$[0,\frac{n}{2}-1]$中所有值時$k+\dfrac{n}{2}$也取遍$[\dfrac{n}{2},n-1]$中所有值

因此,問題的規模縮小了一半,我們只需要在$[0,\dfrac{n}{2}-1]$中列舉$k$,這樣就可以算出$A(\omega_{n}^i)\quad(i\in[0,n-1])$的所有值

如果我們已知$A_1(x),A_2(x)$在$\omega_{\tfrac{n}{2}}^0,\omega_{\tfrac{n}{2}}^1,\dots,\omega_{\tfrac{n}{2}}^{\tfrac{n}{2}-1}$的值,通過上面的兩個式子就可以在$O(n)$的時間內求出$A(x)$

而求$A_1(x),A_2(x)$正好是求$A(x)$的子問題,並且可以遞迴求解

 

03快速傅立葉逆變換

在上面我們將一個多項式的係數表示轉換成了點值表示,這裡我們要研究將一個多項式的點值表示轉換成係數表示

記$(a_0,a_1,\dots,a_{n-1})$是$A(x)$的係數向量,而我們已知$A(x)$的點值表達為$(A(x_0),A(x_1),\dots,A(x_{n-1}))$

設向量$(d_0,d_1,\dots,d_{n-1})$是以$(a_0,a_1,\dots,a_{n-1})$為係數向量,快速傅立葉變換求得的點值表示

構造一個多項式$F(x)=d_0+d_1x+d_2x^2+\dots+d_{n-1}x^{n-1}$

設$(c_0,c_1,\dots,c_{n-1})$是$F(x)$在$x=\omega_n^{-k}$時的點值表示,即$c_k=F(\omega_n^k)$,也就是$c_k=\sum_{i=0}^{n-1}d_i(\omega_n^{-k})^i$

我們知道$d_k=A(\omega_n^k)$,也就是$d_k=\sum_{j=0}^{n-1}a_j(\omega_n^k)^j$

聯立上面兩個和式得

$c_k=\sum_{i=0}^{n-1} [\sum_{j=0}^{n-1}a_j(\omega_n^i)^j] (\omega_n^{-k})^i$

$\quad \:=\sum_{i=0}^{n-1} \sum_{j=0}^{n-1}a_j(\omega_n^j)^i (\omega_n^{-k})^i$

$\quad \:=\sum_{j=0}^{n-1} a_j \sum_{i=0}^{n-1} (\omega_n^j \omega_n^{-k})^i$

$\quad \:=\sum_{j=0}^{n-1} a_j \sum_{i=0}^{n-1} (\omega_n^{j-k})^i$

我們分情況討論後面的一個和式$\sum_{i=0}^{n-1} (\omega_n^{j-k})^i$

$j \neq\ k$

那麼後面的一個和式就轉換為一個等比求和

$\sum_{i=0}^{n-1} (\omega_n^{j-k})^i=\frac{\omega_n^0 [1-(\omega_n^{j-k})^n]}{1-\omega_n^{j-k}}$

$\qquad \qquad \quad \: \: \:=\frac{1-(\omega_n^{j-k})^n}{1-\omega_n^{j-k}}$

$\qquad \qquad \quad \: \: \:=\frac{1-(\omega_n^n)^{j-k}}{1-\omega_n^{j-k}}$

$\qquad \qquad \quad \: \: \:=\frac{1-1^{j-k}}{1-\omega_n^{j-k}}$

$\qquad \qquad \quad \: \: \:=\frac{0}{1-\omega_n^{j-k}}$

$\qquad \qquad \quad \: \: \:=0$

$j = k$

那麼$\omega_n^{j-k} = 1$

$\sum_{i=0}^{n-1} (\omega_n^{j-k})^i = n$

由上面兩種情況,我們知道當且僅當$j = k$時,整個式子才有值,其餘情況都為$0$

所以有

$c_j=a_jn$

$a_j = \frac{c_j}{n}$

到這裡,我們就求出了$A(x)$的係數表達

從整個分析過程看,我們是將$A(x)$的點值表示$(A(x_0),A(x_1),\dots,A(x_{n-1}))$當作一個新的多項式$F(x)$的係數表示,再對$F(x)$做快速傅立葉變換得到$(c_0,c_1,\dots,c_{n-1})$,然後再除以$n$就得到$A(x)$的係數表示了。需要指出的是,快速傅立葉變換中$x=\omega_n^k$但是在逆變換中代入的是$\omega_n^{-k}$

 

04實現

學會了前面的方法,具體實現就不難了

對於求$C(x)=A(x) \cdot B(x)$

將$A(x)$和$B(x)$都轉化成點值表達,即$(a_0,a_1,\dots,a_{n-1})$和$(b_0,b_1,\dots,b_{n-1})$

對應相乘$(a_0b_0,a_1b_1,\dots,a_{n-1}b_{n-1})$,再將這一結果變換成$C(x)$的係數表達就完成了

貼一份C++的程式碼,這是洛谷上的FFT板子題P3803

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#define MAXN 4000006
using namespace std;
class complex
{
public:
    complex(){}
    complex(double a,double b)
    {
        this->a=a;
        this->b=b;
    }
    double a,b;
}a[MAXN],b[MAXN];
complex operator+ (complex x,complex y)
{
    return complex(x.a+y.a,x.b+y.b);
}
complex operator- (complex x,complex y)
{
    return complex(x.a-y.a,x.b-y.b);
}
complex operator* (complex x,complex y)
{
    return complex(x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a);
}
const double pi=acos(-1.0);
void FFT(int l,complex *arr,int f)
{
    if(l==1) return;
    int dl=l>>1;
    complex a1[dl],a2[dl];
    for(int i=0;i<l;i+=2)
    {
        a1[i>>1]=arr[i];
        a2[i>>1]=arr[i+1];
    }
    FFT(dl,a1,f);
    FFT(dl,a2,f);
    complex wn=complex(cos(2.0*pi/l),sin(2.0*pi/l)*f),w=complex(1.0,0.0);
    for(int i=0;i<dl;i++,w=w*wn)
    {
        arr[i]=a1[i]+w*a2[i];
        arr[i+dl]=a1[i]-w*a2[i];
    }
}
int n,m,N;
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%lf",&a[i].a);
    for(int i=0;i<=m;i++)
        scanf("%lf",&b[i].a);
    N=1;
    while(N<n+m+1) N<<=1;
    FFT(N,a,1);
    FFT(N,b,1);
    for(int i=0;i<N;i++)
        a[i]=a[i]*b[i];
    FFT(N,a,-1);
    for(int i=0;i<n+m+1;i++)
        printf("%d ",(int)(a[i].a/N+0.5));
    puts("");
    return 0;
}

閒著沒事幹,再貼一份Python的

import numpy as np

pi = np.arccos(-1.0)


def read():
    def get_numbers():
        try:
            read.s = input().split()
            read.s_len = len(read.s)
            if read.s_len == 0:
                get_numbers()
            read.cnt = 0
            return 1
        except:
            return 0

    if not hasattr(read, 'cnt'):
        if not get_numbers():
            return 0
    if read.cnt == read.s_len:
        if not get_numbers():
            return 0
    read.cnt += 1
    return eval(read.s[read.cnt - 1])


n = int(read())
m = int(read())


class Complex:
    # 複數類

    def __init__(self, a=0.0, b=0.0):
        self.a = a
        self.b = b

    def __add__(self, other):
        return Complex(self.a + other.a, self.b + other.b)

    def __sub__(self, other):
        return Complex(self.a - other.a, self.b - other.b)

    def __mul__(self, other):
        return Complex(self.a * other.a - self.b * other.b, self.a * other.b + self.b * other.a)


def fft(num, f, args):
    if num == 1:
        return
    div_num = num >> 1
    a1 = []
    a2 = []
    for i in range(0, num, 2):
        a1.append(args[i])
        a2.append(args[i + 1])
    fft(div_num, f, a1)
    fft(div_num, f, a2)
    wn = Complex(np.cos(2.0 * pi / num), np.sin(2.0 * pi / num) * f)
    w = Complex(1.0, 0.0)

    for i in range(0, div_num):
        args[i] = a1[i] + w * a2[i]
        args[i + div_num] = a1[i] - w * a2[i]
        w = w * wn


aa = []
bb = []
for j in range(0, n + 1):
    aa.append(Complex(float(read()), 0.0))
for j in range(0, m + 1):
    bb.append(Complex(float(read()), 0.0))

nn = 1
while nn < n + m + 1:
    nn <<= 1

for j in range(n + 1, nn):
    aa.append(Complex(0.0, 0.0))
for j in range(m + 1, nn):
    bb.append(Complex(0.0, 0.0))

fft(nn, 1, aa)
fft(nn, 1, bb)

for j in range(0, nn):
    aa[j] = aa[j] * bb[j]
fft(nn, -1, aa)

for j in range(0, n + m + 1):
    print(int(aa[j].a / nn + 0.5), end=' ')

無奈Python實在是太慢了……

 

05結語

總算是學會了快速傅立葉變換,某種程度上說是彌補了過去的某些遺憾吧。

這裡貼一張大佬的圖,解釋了FFT的思路

 

這裡也推薦一下大佬的部落格,以供參考

快速傅立葉變換(FFT)詳解 - 自為風月馬前卒 - 部落格園 (cnblogs.com)

一小時學會快速傅立葉變換(Fast Fourier Transform) - 知乎 (zhihu.com)

 

相關文章