快速傅立叶变换

FFT原理

对于 N-1 次多项式,将系数表达式

f(x) = \sum_{0 \le i < N} a_i x^i

和点值表达式

f(x) = \sum_{0 \le i < N} y_i \frac{ \prod_{j \neq i}(x-x_j) }{ \prod_{j \neq i}(x_i-x_j) }, \ y_i = \sum_{0 \le n < N} c_i x_i^n

之间互相转化。其中,点取值为N次单位根。

考虑 N=2^K,方便分治计算呀。

根据一系列推导,可以得知 DFT 计算是这样的:

\left(
\begin{array}{cccccc}
1 &1 &1 &1 &\cdots &1 \\
1 &\omega &\omega^2 &\omega^3 &\cdots &\omega^{n}\\
1 &\omega^2 &\omega^4 &\omega^6 &\cdots &\omega^{2n}\\
1 &\omega^3 &\omega^6 &\omega^9 &\cdots &\omega^{3n}\\
\vdots &\vdots &\vdots &\vdots &\ddots &\vdots\\
1 &\omega^{n}&\omega^{2n}&\omega^{3n}&\cdots&\omega^{n^2}
\end{array}
\right)
\left(
\begin{array}{c}
a_0\\a_1\\a_2\\a_3\\\vdots\\a_{n}
\end{array}
\right)
=
\left(
\begin{array}{c}
y_0\\y_1\\y_2\\y_3\\\vdots\\y_{n}
\end{array}
\right)

而IDFT的计算相当于在右边左乘它的逆矩阵,可以证明,对矩阵每一项取倒数并除以N,就是它的逆矩阵。

X_k = \sum_{0 \le n < N} x_n \omega^{nk}
x_n = \frac{1}{N} \sum_{0 \le k < N} X_k \omega^{-kn}

NTT原理

考虑上述 N次单位根群 \lbrace 1, \omega, \omega^2, \omega^3, \cdots \rbrace 到某个数论乘法群的同构。

我们令 P = 2^K * Q + 1,则寻找它的原根 g,有 2^K 阶循环群 \lbrace 1, g^Q, g^{2Q}, g^{3Q}, \cdots \rbrace,易知他们同构。

由于FFT涉及到的是浮点数运算,可能会增大误差,所以考虑使用整数。毕竟整数乘法更加精确。

在这里,利用 g^Q 替代 \omega 即可。

一般用NTT的题目中,会取 p = 998244353,因为 p = 119 * 2^{23} + 1。此时 g = 3

X_k = \sum_{0 \le n < N} x_n g^{Qnk} \ (mod\ P)
x_n = \frac{1}{N} \sum_{0 \le k < N} X_k g^{-Qkn} \ (mod\ P)

变换动机

多项式乘法

多项式 A(x) = \sum a_n x^n B(x) = \sum b_n x^n,求 A(x)B(x)

朴素的多项式乘法,计算 c_n = \sum a_k b_{n-k} ,时间复杂度约为 O(n^2)

由变换计算出的点值表示法,如果点是相同的,那么

(A \cdot B)(x_i) = A(x_i) \cdot B(x_i)

可以在 O(n) 的时间内得到,还原 (A \cdot B)为系数表示就完成了多项式乘法。

在实现FFT时,通常可以通过分治达到 O(n logn) 的时间复杂度,所以通过 FFT 的多项式乘法的总时间复杂度为 O(n logn)。完成了一次优化。

卷积

定义对于两个函数的运算
f(x) \circledast g(x) = \int_0^x f(u) * g(x-u) du
则上述 c_n 符合了卷积的定义。

考虑连续傅立叶变换 F(\omega) = \mathfrak{F}[f(t)] = \int_{-\infty}^{\infty}f(t)e^{-i\omega t}dt 下象函数和象原函数有这样的关系,

F(\omega) * G(\omega) = \mathfrak{F}[f(t) \circledast g(t)]

暗示DFT是具有循环卷积性质的。

由于卷积可以表达成为多项式乘法的项系数,所以可以利用上述变换达到目的。

变换细节

分治求点值

分治地求当 x = \omega^n 时多项式的值。

\begin{aligned} F(x) &= a_0 + a_1 x + a_2 x^2 + a_3 x^3 + a_4 x^4 + a_5 x^5 + a_6 x^6 + a_7 x^7 \\ &= (a_0 + a_2 x^2 + a_4 x^4 + a_6 x^6) + x (a_1 + a_3 x^2 + a_5 x^4 + a_7 x^6) \\ &= G(x^2) + x * H(x^2) \end{aligned}

把当前单位复根的平方分别以DFT的方式带入G函数和H函数求值

DFT(F(x))_k = DFT(G(x^2))_k + \omega^k * DFT(H(x^2))_k

将递归消成迭代

从分治的角度来说,会逐步拆分成如下序列:

\lbrace x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7 \rbrace
\lbrace x_0, x_2, x_4, x_6 \rbrace, \lbrace x_1, x_3, x_5, x_7 \rbrace
\lbrace x_0, x_4 \rbrace, \lbrace x_2, x_6 \rbrace, \lbrace x_1, x_5 \rbrace, \lbrace x_3, x_7 \rbrace
\lbrace x_0 \rbrace, \lbrace x_4 \rbrace, \lbrace x_2 \rbrace, \lbrace x_6 \rbrace, \lbrace x_1 \rbrace, \lbrace x_5 \rbrace, \lbrace x_3 \rbrace, \lbrace x_7 \rbrace

将下标化为二进制,会发现拆分后的序列下标,恰好为长度为3的二进制数的反转。容易证明这个结论的普遍性。

模板

/*******************************************
* 快速傅立叶变换:时间 O(nlogn)
* 多项式系数表示与点值表示转换
* - Fast Fourier Transform 虚数
* - Number Theory Transform 数论数
*******************************************/

const size_t max_len = 17;
#define _rev(i) (rev[i>>1]>>1)|((i&1)<<len-1)

#ifdef NTT // 整数运算
const int mod = 998244353, g = 3;
typedef long long lld;
struct mint { 整数取模运算+-*/ };
typedef mint num_t;

inline num_t create(int m)
{
   int k = (mod-1)/m;
   if (k < 0) k += mod-1;
   return pow(g, k);
}

#else // 浮点数运算
const double PI = acos(-1.0);
typedef complex<double> num_t;

inline num_t create(int m)
{
   return num_t(cos(2*PI/m), sin(2*PI/m));
}

#endif

class Fourier
{
public:
   int len, N;

private:
   num_t wmk[1 << max_len];
   int rev[1 << max_len];

   void dft(num_t a[], int DFT)
   {
      for (int i = 0; i < N; i++)
         if (i < rev[i])
            swap(a[i], a[rev[i]]);

      for (int m = 2; m <= N; m <<= 1)
      {
         int m2 = m >> 1;
         num_t wm = create(DFT * m);
         wmk[0] = 1;
         for (int j = 1; j < m2; j++)
            wmk[j] = wmk[j-1] * wm;
         num_t t, u;

         for (int k = 0; k < N; k += m)
         {
            for (int j = 0; j < m2; j++)
            {
               t = wmk[j] * a[k+j+m2];
               u = a[k+j];
               a[k+j] = u + t;
               a[k+j+m2] = u - t;
           }
         }
      }

      if (DFT == -1)
         for (int i = 0; i < N; i++)
            a[i] = a[i] / N;
   }

public:
   void init(int _len)
   {
      len = _len, N = 1 << _len;
      for (int i = 0; i < N; i++)
         rev[i] = _rev(i);
   }

   void DFT(num_t a[]) { dft(a, 1); }
   void IDFT(num_t a[]) { dft(a, -1); }
};

发表评论

电子邮件地址不会被公开。 必填项已用*标注