🌑

小羊儿的心情天空

2019CCPC秦皇岛 H. Houraisan Kaguya

Oct 13, 2019 由 小羊

题目链接

首先,如果你的群论足够好,你应该能理解到,

f(a,b)=ordagcd(orda,ordb)f(a,b) = \frac{\operatorname{ord}a}{\gcd(\operatorname{ord}a,\operatorname{ord}b)}

所以题目实际上是求

i=1nj=1nordai×ordajgcd(ordai,ordaj)2\sum_{i=1}^n \sum_{j=1}^n \frac{\operatorname{ord}a_i \times \operatorname{ord}a_j}{\gcd(\operatorname{ord}a_i,\operatorname{ord}a_j)^2}

然后现场赛到这里我就不会了(雾)

实际上,我们可以按莫比乌斯反演的套路想到,枚举 gcd。我们考虑到 ordx\operatorname{ord}x 也一定是 p1p-1 的倍数,所以考虑这样的一个卷积

ck=gcd(i,j)=kijfifj=i,jifi jfj[gcd(i,j)=k]\begin{aligned} c_k &= \sum_{\gcd(i, j) = k} ijf_if_j \\ &=\sum_{i,j} if_i ~ jf_j [\gcd(i,j)=k] \end{aligned}

其中 fif_i 表示 ordax=i\operatorname{ord} a_x = i 的个数,那么考虑枚举 gcd\gcd 的倍数,则有

ck=xcx[kx]=i,jifi jfj[kgcd(i,j)]=i,jifi jfj[ki][kj]=iifi[ki]jjfj[kj]\begin{aligned} {c'}_k &= \sum_{x}c_x[k|x] \\ &= \sum_{i,j} if_i ~ jf_j [k|\gcd(i,j)] \\ &= \sum_{i,j} if_i ~ jf_j [k|i][k|j] \\ &= \sum_{i} if_i [k|i] \sum_{j} jf_j [k|j] \end{aligned}

也就是需要作一个类似于

ck=xcx[kx]{c'}_k = \sum_{x} c_x[k|x]

的正变换和逆变换。考虑到 FWT 中那个类似于按位与卷积的东西,设计一个类似于高位前缀和的东西,就可以做了。

另外求 ordx\operatorname{ord}x 也有一点注意。朴素的想法是先将 xx 拉入模 pp 非零元素乘法群的一个 piei{p_i}^{e_i} 阶子群中,也就是先设计 xi=xp1pieix_i = x^{\frac{p-1}{ {p_i}^{e_i} }},然后求 xix_i 在该子群中的周期。所以可以设计一个时间复杂度为 O(c(p1)log(p1))O(c(p-1) \log(p-1)) 的算法。在此过程中,计算的瓶颈在于 xix_i 的计算,耗费了绝大部分的时间复杂度,那么我们可以设计一个分治的算法,将因子集合划分为两部分,前半部分将后半部分素因子子群的周期升阶到 11,后半部分同样操作前半部分,然后继续分治。类似于多项式多点求值的一个想法,时间复杂度降为 O(logc(p1)log(p1))O(\log c(p-1) \log(p-1))

目前这份代码是CF上跑的最快的,756ms,用普通求阶时间为 1575ms。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
typedef long double flt;
const int MAXN = 2e5+5, S = 10;
int n, w, fc[20], dcn;
lld mod, fp[20], zs[20];

inline lld mul(lld x, lld y, lld z) { lld t = flt(x) * y / z; lld ans = (x * y - t * z) % z; if (ans < 0) ans += z; return ans; }
inline lld mul(lld x, lld y) { lld t = flt(x) * y / mod; lld ans = (x * y - t * mod) % mod; if (ans < 0) ans += mod; return ans; }
inline void addeq(lld &x, lld y) { x = x+y - (x+y>=mod?mod:0); }
inline void muleq(lld &x, lld y) { lld t = flt(x) * y / mod; lld ans = (x * y - t * mod) % mod; if (ans < 0) ans += mod; x = ans; }
inline lld fpow(lld x, lld n, lld mod) { x %= mod; if (n == 1) return x; lld ret = 1; for (; n; n>>=1, x=mul(x,x,mod)) if (n&1) ret=mul(ret,x,mod); return ret; }
inline lld fpow(lld x, lld n) { x %= mod; if (n == 1) return x; lld ret = 1; for (; n; n>>=1, x=mul(x,x,mod)) if (n&1) ret=mul(ret,x,mod); return ret; }

namespace Decompose {
    int tol; lld factor[1000];

    bool millerRabin(lld n, lld base) {
        lld n2 = n-1, s = __builtin_ctzll(n2); n2 >>= s;
        lld t = fpow(base, n2, n);
        if (t == 1 || t == n-1) return true;
        for (s--; s >= 0; s--)
            if ((t=mul(t,t,n))==n-1)
                return true;
        return false;
    }

    bool isPrime(lld n) {
        static lld bases[12] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37 };
        if (n <= 2) return false;
        for (int i = 0; i < 12 && bases[i] < n; i++)
            if (!millerRabin(n, bases[i])) return false;
        return true;
    }

    inline lld f(lld x, lld mod) {
        lld ans = 1; addeq(ans, mul(x, x, mod)); return ans;
    }

    void PollardRho(lld n) {
        if (n == 1) {
            return;
        } else if (isPrime(n)) {
            factor[tol++] = n;
        } else if (n & 1) {
            for (int i = 1; ; i++) {
                lld x = i, y = f(x, n), q = __gcd(y-x, n);
                while (q == 1) {
                    x = f(x, n), y = f(f(y, n), n);
                    q = __gcd((y-x+n)%n, n) % n;
                }
                if (q != 0 && q != n) {
                    PollardRho(q), PollardRho(n/q);
                    return;
                }
            }
        } else {
            while (!(n & 1)) factor[tol++] = 2, n >>= 1;
            PollardRho(n);
        }
    }

    inline void findfac(lld n) {
        tol = w = 0; PollardRho(n); map<lld,int> tj;
        for (int i = 0; i < tol; i++) tj[factor[i]]++;
        for (auto ss : tj) {
            fc[w] = ss.second, fp[w] = ss.first;
            zs[w] = 1; for (int i = 0; i < fc[w]; i++) zs[w] *= fp[w];
            w++;
        }
    }
}

using Decompose::findfac;

lld getOrd(lld a) {
    lld res = mod-1;
    for (int i = 0; i < w; i++) {
        lld qwq = fpow(a, res/=zs[i]);
        while (qwq != 1) qwq = fpow(qwq, fp[i]), res *= fp[i];
    }
    return res;
} // Time complexity: O(c(n)log(n))

struct Node { int l, r; lld val; };

lld getOrdFast(lld a) {
    static Node Q[100000]; int front = 0, rear = 0;
    lld res = 1; Q[rear++] = Node { 0, w-1, a };
    while (front < rear) {
        Node f = Q[front++];
        if (f.l == f.r) {
            for (lld bs = f.val; bs != 1; bs = fpow(bs, fp[f.l]))
                res *= fp[f.l];
        } else if (f.l < f.r) {
            int mid = (f.l+f.r)>>1;
            lld prod1 = 1, prod2 = 1;
            for (int i = f.l; i <= mid; i++) prod2 *= zs[i];
            for (int i = mid+1; i <= f.r; i++) prod1 *= zs[i];
            Q[rear++] = Node { f.l, mid, fpow(f.val, prod1) };
            Q[rear++] = Node { mid+1, f.r, fpow(f.val, prod2) };
        }
    }
    return res;
} // Time complexity: O(log(c(n))log(n))

lld di[MAXN], invdi2[MAXN]; int qp[20];
lld a[MAXN], g[MAXN];

void genFac(int wi, lld fac, int ids) {
    if (wi == w) {
        di[ids] = fac;
        invdi2[ids] = fac==1?1:mod-mod/fac;
        muleq(invdi2[ids], invdi2[ids]);
    } else {
        lld pr = 1;
        for (int i = 0; i <= fc[wi]; i++, pr *= fp[wi])
            genFac(wi+1, fac*pr, ids+i*qp[wi]);
    }
}

void addResult(lld ord)
{
    int id = 0; lld n = ord;

    for (int i = 0; i < w; i++)
    {
        int r = 0;
        while (n % fp[i] == 0) n /= fp[i], r++;
        id += r * qp[i];
    }

    addeq(g[id], ord);
}

void solve()
{
    findfac(mod - 1);

    qp[0] = 1;
    for (int i = 1; i < w; i++)
        qp[i] = qp[i-1] * (fc[i-1] + 1);
    dcn = qp[w-1] * (fc[w-1] + 1);
    genFac(0, 1, 0);

    for (int i = 1; i <= n; i++)
        addResult(getOrdFast(a[i]));
    for (int i = 0; i < w; i++)
        for (int j = dcn-1; j >= 0; j--)
            if (j / qp[i] % (fc[i]+1) != fc[i])
                addeq(g[j], g[j+qp[i]]);
    for (int i = 0; i < dcn; i++)
        muleq(g[i], g[i]);
    for (int i = 0; i < w; i++)
        for (int j = 0; j < dcn; j++)
            if (j / qp[i] % (fc[i]+1) != fc[i])
                addeq(g[j], mod-g[j+qp[i]]);
    lld ans = 0;
    for (int i = 0; i < dcn; i++)
        addeq(ans, mul(g[i], invdi2[i]));
    printf("%lld\n", ans);
}

int main()
{
    srand(time(NULL));
    scanf("%d %lld", &n, &mod);
    for (int i = 1; i <= n; i++)
        scanf("%lld", &a[i]);
    solve();
    return 0;
}