快速沃尔什变换

新的卷积

卷积的东西都很好玩呀。有时间一定好好学泛函.jpg

前面提到了,离散傅里叶变换的多项式卷积是对于每个 k 获得 i+j=k\sum a_i \times a_j,而在数论中比较多的狄利克雷卷积是对于每个 k 获得 i \times j = k\sum a_i \times a_j。在实际题目中,我们可能会遇到这样的一类卷积:

对于每个 k,获得 i \otimes j = k\sum a_i \times a_j。其中 \otimes 是任意按位逻辑运算,如 \&, |, \oplus 等。

很显然,我们如果按照最朴素的枚举方法:先枚举 i,再获得可行的 k,复杂度是 O(n^2) 朝上走的。如果 n \ge 10^5 不就完蛋了吗。按照这个架势,我们应该是要找复杂度为 O(n\log n)O(n \log \log n)O(n \log^2 n) 的算法。

很幸运的是,我们可以通过分治来完成这样的一件事情。当然了,由于位运算的值域是 2^n 级别的,我们也要像 FFT 那样将数列长度补齐。

快速沃尔什变换

注意到,位运算是有位独立性的。每一位之间并不会互相影响。那么,我们每次只考虑某一位?

我们令 A_{xy} = (A_{0y}, A_{1y})B_{xy} = (B_{0y}, B_{1y}),行向量分块,其中 y 是任意01序列,那么我们想要
C_{xy} = A_{xy} * B_{xy} = \left(\sum_{a \otimes b = 0} A_{ay} * B_{by}\ , \sum_{a \otimes b = 1} A_{ay} * B_{by} \right)
由于逻辑运算是确定的,我们可以很容易的计算出 \otimes 的真值表,并且后续又变成了规模更小的卷积,就是一个分治的过程啦。

假如构造了一个变换 \mathfrak{T}(A),满足 \mathfrak{T} (C) = \mathfrak{T} (A) \times \mathfrak{T} (B),并且可以很短时间内求得 \mathfrak{T}(A)\mathfrak{T}^{-1}(A),就可以像FFT一样卷积了。

Fast Walsh-Hadamard Transform,这个神奇的变换,我们分别称变换与逆变换为 FWT(A)UFWT(A),那么

对于与运算 \& 来说,有

FWT(A) = \left(\ FWT(A_0) + FWT(A_1)\ ,\ FWT(A_1)\ \right)
UFWT(A) = \left(\ UFWT(A_0) – UFWT(A_1)\ ,\ UFWT(A_1)\ \right)

并且

FWT(A)[i] = \sum_{i \& j = i} A[j]

void FWT(lld X[], int l, int r, int dft = 1)
{
    if (l == r) return;
    int m = (l+r)>>1;
    FWT(X, l, m, dft);
    FWT(X, m+1, r, dft);
    for (int i = 0; i <= m-l; i++)
        X[l+i] += X[m+1+i] * dft;
}

对于或运算 | 来说,有

FWT(A) = \left(\ FWT(A_0)\ ,\ FWT(A_0) + FWT(A_1)\ \right)
UFWT(A) = \left(\ UFWT(A_0)\ ,\ UFWT(A_1) – UFWT(A_0)\ \right)

并且

FWT(A)[i] = \sum_{i | j = i} A[j]

void FWT(lld X[], int l, int r, int dft = 1)
{
    if (l == r) return;
    int m = (l+r)>>1;
    FWT(X, l, m, dft);
    FWT(X, m+1, r, dft);
    for (int i = 0; i <= m-l; i++)
        X[m+1+i] += X[l+i] * dft;
}

异或运算 \oplus
FWT(A) = \left(\ FWT(A_0) + FWT(A_1)\ ,\ FWT(A_0) – FWT(A_1)\ \right)
UFWT(A) = \left(\ UFWT(\frac{A_0 + A_1}2)\ ,\ UFWT(\frac{A_0-A_1}{2})\ \right)

暂时没有找到比较明显的规律

void FWT(lld X[], int l, int r, int dft)
{
    if (l == r) return;
    int m = (l+r)>>1;
    FWT(X, l, m, dft);
    FWT(X, m+1, r, dft);

    for (int i = 0; i <= m-l; i++)
    {
        lld x = X[l+i], y = X[m+1+i];
        X[l+i] = dft == 1 ? x+y : (x+y)/2;
        X[m+1+i] = dft == 1 ? x-y : (x-y)/2;
    }
}

例子

链接:https://ac.nowcoder.com/acm/contest/295/H
来源:牛客网

Niuniu likes playing games. He has n piles of stones. The i-th pile has ai stones. He wants to play with his good friend, UinUin. Niuniu can choose some piles out of the n piles. They will play with the chosen piles of stones. UinUin takes the first move. They take turns removing at least one stone from one chosen pile. The player who removes the last stone from the chosen piles wins the game. Niuniu wants to choose the maximum number of piles so that he can make sure he wins the game. Can you help Niuniu choose the piles? 

给出 n 个数字,记为 a_0, a_1, a_2, …, a_{n-1},且 a_i \le 5\times10^5,求 max|A|, \oplus_{a\in A} a = 0

不妨记 \oplus_{i=0}^{n-1}a_i = sum,那么我们构造一个最小的堆,使得其异或和为 sum 即可。

我们只要记录可以被构造出来的堆值,就可以在第一次这个堆被构造出来时停止。由于 x \oplus x = 0,所以我们直接用已经构造出来的集合和原始给出的集合进行逻辑运算卷积。因为不需要记录方案数,所以直接置1即可。

B_{k}^n = \left[\left(\sum_{i \oplus j = k} A_{i} B_{j}^{n-1}\right) > 0 \right],相当于记录了,有 B 的堆和 A 原来堆,异或后能构造的值的集合。显然这个数字不会超过 2^{19},所以做19次卷积即可。

#include <cstring>
#include <cstdio>
typedef long long lld;
const int MAX = 524287;
lld a[MAX+1], b[MAX+1] = {1};

void FWT(lld X[], int l, int r, int dft)
{
    if (l == r) return;
    int m = (l+r)>>1;
    FWT(X, l, m, dft);
    FWT(X, m+1, r, dft);

    for (int i = 0; i <= m-l; i++)
    {
        lld x = X[l+i], y = X[m+1+i];
        X[l+i] = dft == 1 ? x+y : (x+y)/2;
        X[m+1+i] = dft == 1 ? x-y : (x-y)/2;
    }
}

int main()
{
    int n, tmp, sum = 0; scanf("%d", &n);
    for (int i = 0; i < n; i++)
    {
        scanf("%d", &tmp);
        a[tmp]++;
        sum ^= tmp;
    }

    FWT(a, 0, MAX, 1);
    int max_ans = 0;
    while (!b[sum])
    {
        if (max_ans++ == 19) break;
        FWT(b, 0, MAX, 1);
        for (int j = 0; j <= MAX; j++)
            b[j] = a[j] * b[j];
        FWT(b, 0, MAX, -1);
        for (int j = 0; j <= MAX; j++)
            b[j] = !!b[j];
    }

    printf("%d\n", n - max_ans);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注