快速傅立叶变换
Aug 23, 2018 由 小羊
FFT原理
对于 N−1 次多项式,将系数表达式
f(x)=0≤i<N∑aixi
和点值表达式
f(x)=0≤i<N∑yi∏j=i(xi−xj)∏j=i(x−xj), yi=0≤n<N∑cixin
之间互相转化。其中,点取值为N次单位根。
考虑 N=2K,方便分治计算呀。
根据一系列推导,可以得知 DFT 计算是这样的:
⎝⎜⎜⎜⎜⎜⎜⎜⎛1111⋮11ωω2ω3⋮ωn1ω2ω4ω6⋮ω2n1ω3ω6ω9⋮ω3n⋯⋯⋯⋯⋱⋯1ωnω2nω3n⋮ωn2⎠⎟⎟⎟⎟⎟⎟⎟⎞⎝⎜⎜⎜⎜⎜⎜⎜⎛a0a1a2a3⋮an⎠⎟⎟⎟⎟⎟⎟⎟⎞=⎝⎜⎜⎜⎜⎜⎜⎜⎛y0y1y2y3⋮yn⎠⎟⎟⎟⎟⎟⎟⎟⎞
而IDFT的计算相当于在右边左乘它的逆矩阵,可以证明,对矩阵每一项取倒数并除以N,就是它的逆矩阵。
Xk=0≤n<N∑xnωnk
xn=N10≤k<N∑Xkω−kn
NTT原理
考虑上述 N次单位根群 {1,ω,ω2,ω3,⋯} 到某个数论乘法群的同构。
我们令 P=2K∗Q+1,则寻找它的原根 g,有 2K 阶循环群 {1,gQ,g2Q,g3Q,⋯},易知他们同构。
由于FFT涉及到的是浮点数运算,可能会增大误差,所以考虑使用整数。毕竟整数乘法更加精确。
在这里,利用 gQ 替代 ω 即可。
一般用NTT的题目中,会取 p=998244353,因为 p=119∗223+1。此时 g=3。
Xk=0≤n<N∑xngQnk (mod P)
xn=N10≤k<N∑Xkg−Qkn (mod P)
变换动机
多项式乘法
多项式 $A(x) = \sum a_n x^n $ 和 B(x)=∑bnxn,求 A(x)B(x)。
朴素的多项式乘法,计算 $c_n = \sum a_k b_{n-k} $,时间复杂度约为 O(n2)。
由变换计算出的点值表示法,如果点是相同的,那么
(A⋅B)(xi)=A(xi)⋅B(xi)
可以在 O(n) 的时间内得到,还原 (A⋅B)为系数表示就完成了多项式乘法。
在实现FFT时,通常可以通过分治达到 O(nlogn) 的时间复杂度,所以通过 FFT 的多项式乘法的总时间复杂度为 O(nlogn)。完成了一次优化。
卷积
定义对于两个函数的运算
f(x)⊛g(x)=∫0xf(u)∗g(x−u)du
则上述 cn 符合了卷积的定义。
考虑连续傅立叶变换 F(ω)=F[f(t)]=∫−∞∞f(t)e−iωtdt 下象函数和象原函数有这样的关系,
F(ω)∗G(ω)=F[f(t)⊛g(t)]
暗示DFT是具有循环卷积性质的。
由于卷积可以表达成为多项式乘法的项系数,所以可以利用上述变换达到目的。
变换细节
分治求点值
分治地求当 x=ωn 时多项式的值。
F(x)=a0+a1x+a2x2+a3x3+a4x4+a5x5+a6x6+a7x7=(a0+a2x2+a4x4+a6x6)+x(a1+a3x2+a5x4+a7x6)=G(x2)+x∗H(x2)
把当前单位复根的平方分别以DFT的方式带入G函数和H函数求值
DFT(F(x))k=DFT(G(x2))k+ωk∗DFT(H(x2))k
将递归消成迭代
从分治的角度来说,会逐步拆分成如下序列:
{x0,x1,x2,x3,x4,x5,x6,x7}
{x0,x2,x4,x6},{x1,x3,x5,x7}
{x0,x4},{x2,x6},{x1,x5},{x3,x7}
{x0},{x4},{x2},{x6},{x1},{x5},{x3},{x7}
将下标化为二进制,会发现拆分后的序列下标,恰好为长度为3的二进制数的反转。容易证明这个结论的普遍性。
模板
const size_t max_len = 17;
#define _rev(i) (rev[i>>1]>>1)|((i&1)<<len-1)
#ifdef NTT
const int mod = 998244353, g = 3;
typedef long long lld;
struct mint { 整数取模运算+-*/ };
typedef mint num_t;
inline num_t create(int m)
{
int k = (mod-1)/m;
if (k < 0) k += mod-1;
return pow(g, k);
}
#else
const double PI = acos(-1.0);
typedef complex<double> num_t;
inline num_t create(int m)
{
return num_t(cos(2*PI/m), sin(2*PI/m));
}
#endif
class Fourier
{
public:
int len, N;
private:
num_t wmk[1 << max_len];
int rev[1 << max_len];
void dft(num_t a[], int DFT)
{
for (int i = 0; i < N; i++)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int m = 2; m <= N; m <<= 1)
{
int m2 = m >> 1;
num_t wm = create(DFT * m);
wmk[0] = 1;
for (int j = 1; j < m2; j++)
wmk[j] = wmk[j-1] * wm;
num_t t, u;
for (int k = 0; k < N; k += m)
{
for (int j = 0; j < m2; j++)
{
t = wmk[j] * a[k+j+m2];
u = a[k+j];
a[k+j] = u + t;
a[k+j+m2] = u - t;
}
}
}
if (DFT == -1)
for (int i = 0; i < N; i++)
a[i] = a[i] / N;
}
public:
void init(int _len)
{
len = _len, N = 1 << _len;
for (int i = 0; i < N; i++)
rev[i] = _rev(i);
}
void DFT(num_t a[]) { dft(a, 1); }
void IDFT(num_t a[]) { dft(a, -1); }
};