🌑

小羊儿的心情天空

快速傅立叶变换

Aug 23, 2018 由 小羊

FFT原理

对于 N1N-1 次多项式,将系数表达式

f(x)=0i<Naixif(x) = \sum_{0 \le i < N} a_i x^i

和点值表达式

f(x)=0i<Nyiji(xxj)ji(xixj), yi=0n<Ncixinf(x) = \sum_{0 \le i < N} y_i \frac{ \prod_{j \neq i}(x-x_j) }{ \prod_{j \neq i}(x_i-x_j) }, \ y_i = \sum_{0 \le n < N} c_i x_i^n

之间互相转化。其中,点取值为N次单位根。

考虑 N=2KN=2^K,方便分治计算呀。

根据一系列推导,可以得知 DFT 计算是这样的:

(111111ωω2ω3ωn1ω2ω4ω6ω2n1ω3ω6ω9ω3n1ωnω2nω3nωn2)(a0a1a2a3an)=(y0y1y2y3yn)\left( \begin{array}{cccccc} 1 &1 &1 &1 &\cdots &1 \\ 1 &\omega &\omega^2 &\omega^3 &\cdots &\omega^{n}\\ 1 &\omega^2 &\omega^4 &\omega^6 &\cdots &\omega^{2n}\\ 1 &\omega^3 &\omega^6 &\omega^9 &\cdots &\omega^{3n}\\ \vdots &\vdots &\vdots &\vdots &\ddots &\vdots\\ 1 &\omega^{n}&\omega^{2n}&\omega^{3n}&\cdots&\omega^{n^2} \end{array} \right) \left( \begin{array}{c} a_0\\a_1\\a_2\\a_3\\\vdots\\a_{n} \end{array} \right) = \left( \begin{array}{c} y_0\\y_1\\y_2\\y_3\\\vdots\\y_{n} \end{array} \right)

而IDFT的计算相当于在右边左乘它的逆矩阵,可以证明,对矩阵每一项取倒数并除以N,就是它的逆矩阵。

Xk=0n<NxnωnkX_k = \sum_{0 \le n < N} x_n \omega^{nk}

xn=1N0k<NXkωknx_n = \frac{1}{N} \sum_{0 \le k < N} X_k \omega^{-kn}

NTT原理

考虑上述 N次单位根群 {1,ω,ω2,ω3,}\lbrace 1, \omega, \omega^2, \omega^3, \cdots \rbrace 到某个数论乘法群的同构。

我们令 P=2KQ+1P = 2^K * Q + 1,则寻找它的原根 gg,有 2K2^K 阶循环群 {1,gQ,g2Q,g3Q,}\lbrace 1, g^Q, g^{2Q}, g^{3Q}, \cdots \rbrace,易知他们同构。

由于FFT涉及到的是浮点数运算,可能会增大误差,所以考虑使用整数。毕竟整数乘法更加精确。

在这里,利用 gQg^Q 替代 ω\omega 即可。

一般用NTT的题目中,会取 p=998244353p = 998244353,因为 p=119223+1p = 119 * 2^{23} + 1。此时 g=3g = 3

Xk=0n<NxngQnk (mod P)X_k = \sum_{0 \le n < N} x_n g^{Qnk} \ (mod\ P)

xn=1N0k<NXkgQkn (mod P)x_n = \frac{1}{N} \sum_{0 \le k < N} X_k g^{-Qkn} \ (mod\ P)

变换动机

多项式乘法

多项式 $A(x) = \sum a_n x^n $ 和 B(x)=bnxnB(x) = \sum b_n x^n,求 A(x)B(x)A(x)B(x)

朴素的多项式乘法,计算 $c_n = \sum a_k b_{n-k} $,时间复杂度约为 O(n2)O(n^2)

由变换计算出的点值表示法,如果点是相同的,那么

(AB)(xi)=A(xi)B(xi)(A \cdot B)(x_i) = A(x_i) \cdot B(x_i)

可以在 O(n)O(n) 的时间内得到,还原 (AB)(A \cdot B)为系数表示就完成了多项式乘法。

在实现FFT时,通常可以通过分治达到 O(nlogn)O(n logn) 的时间复杂度,所以通过 FFT 的多项式乘法的总时间复杂度为 O(nlogn)O(n logn)。完成了一次优化。

卷积

定义对于两个函数的运算

f(x)g(x)=0xf(u)g(xu)duf(x) \circledast g(x) = \int_0^x f(u) * g(x-u) du

则上述 cnc_n 符合了卷积的定义。

考虑连续傅立叶变换 F(ω)=F[f(t)]=f(t)eiωtdtF(\omega) = \mathfrak{F}[f(t)] = \int_{-\infty}^{\infty}f(t)e^{-i\omega t}dt 下象函数和象原函数有这样的关系,

F(ω)G(ω)=F[f(t)g(t)]F(\omega) * G(\omega) = \mathfrak{F}[f(t) \circledast g(t)]

暗示DFT是具有循环卷积性质的。

由于卷积可以表达成为多项式乘法的项系数,所以可以利用上述变换达到目的。

变换细节

分治求点值

分治地求当 x=ωnx = \omega^n 时多项式的值。

F(x)=a0+a1x+a2x2+a3x3+a4x4+a5x5+a6x6+a7x7=(a0+a2x2+a4x4+a6x6)+x(a1+a3x2+a5x4+a7x6)=G(x2)+xH(x2)\begin{aligned} F(x) &= a_0 + a_1 x + a_2 x^2 + a_3 x^3 + a_4 x^4 + a_5 x^5 + a_6 x^6 + a_7 x^7 \\\\ &= (a_0 + a_2 x^2 + a_4 x^4 + a_6 x^6) + x (a_1 + a_3 x^2 + a_5 x^4 + a_7 x^6) \\\\ &= G(x^2) + x * H(x^2) \end{aligned}

把当前单位复根的平方分别以DFT的方式带入G函数和H函数求值

DFT(F(x))k=DFT(G(x2))k+ωkDFT(H(x2))kDFT(F(x))_k = DFT(G(x^2))_k + \omega^k * DFT(H(x^2))_k

将递归消成迭代

从分治的角度来说,会逐步拆分成如下序列:

{x0,x1,x2,x3,x4,x5,x6,x7}\lbrace x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7 \rbrace
{x0,x2,x4,x6},{x1,x3,x5,x7}\lbrace x_0, x_2, x_4, x_6 \rbrace, \lbrace x_1, x_3, x_5, x_7 \rbrace
{x0,x4},{x2,x6},{x1,x5},{x3,x7}\lbrace x_0, x_4 \rbrace, \lbrace x_2, x_6 \rbrace, \lbrace x_1, x_5 \rbrace, \lbrace x_3, x_7 \rbrace
{x0},{x4},{x2},{x6},{x1},{x5},{x3},{x7}\lbrace x_0 \rbrace, \lbrace x_4 \rbrace, \lbrace x_2 \rbrace, \lbrace x_6 \rbrace, \lbrace x_1 \rbrace, \lbrace x_5 \rbrace, \lbrace x_3 \rbrace, \lbrace x_7 \rbrace

将下标化为二进制,会发现拆分后的序列下标,恰好为长度为3的二进制数的反转。容易证明这个结论的普遍性。

模板

/*******************************************
* 快速傅立叶变换:时间 O(nlogn)
* 多项式系数表示与点值表示转换
* - Fast Fourier Transform 虚数
* - Number Theory Transform 数论数
*******************************************/

const size_t max_len = 17;
#define _rev(i) (rev[i>>1]>>1)|((i&1)<<len-1)

#ifdef NTT // 整数运算
const int mod = 998244353, g = 3;
typedef long long lld;
struct mint { 整数取模运算+-*/ };
typedef mint num_t;

inline num_t create(int m)
{
   int k = (mod-1)/m;
   if (k < 0) k += mod-1;
   return pow(g, k);
}

#else // 浮点数运算
const double PI = acos(-1.0);
typedef complex<double> num_t;

inline num_t create(int m)
{
   return num_t(cos(2*PI/m), sin(2*PI/m));
}

#endif

class Fourier
{
public:
   int len, N;

private:
   num_t wmk[1 << max_len];
   int rev[1 << max_len];

   void dft(num_t a[], int DFT)
   {
      for (int i = 0; i < N; i++)
         if (i < rev[i])
            swap(a[i], a[rev[i]]);

      for (int m = 2; m <= N; m <<= 1)
      {
         int m2 = m >> 1;
         num_t wm = create(DFT * m);
         wmk[0] = 1;
         for (int j = 1; j < m2; j++)
            wmk[j] = wmk[j-1] * wm;
         num_t t, u;

         for (int k = 0; k < N; k += m)
         {
            for (int j = 0; j < m2; j++)
            {
               t = wmk[j] * a[k+j+m2];
               u = a[k+j];
               a[k+j] = u + t;
               a[k+j+m2] = u - t;
           }
         }
      }

      if (DFT == -1)
         for (int i = 0; i < N; i++)
            a[i] = a[i] / N;
   }

public:
   void init(int _len)
   {
      len = _len, N = 1 << _len;
      for (int i = 0; i < N; i++)
         rev[i] = _rev(i);
   }

   void DFT(num_t a[]) { dft(a, 1); }
   void IDFT(num_t a[]) { dft(a, -1); }
};