「Avito Code Challenge 2018」H. K Paths

给定一个大小为 $n$ 的树,要求选出 $k$ 条路径,一个方案合法当且仅当对于路径上的边,要么只属于当前路径,要么属于所有路径,对所有合法方案 $\mathrm{mod}\ 998244353$ 输出。

$n,k \leq 10^5$。

为啥这么水的题我还写题解呢?因为我发现我的复杂度比标算更优秀!(虽然常数不太行)

显然我们从路径交开始着手计数,考虑枚举两个端点,那么所有合法的方案就是这两个端点各个子树的大小做个背包,再乘上一个组合数的系数即可,两边是独立的,所以可以分别计算然后乘法原理合并。

下面分两种情况讨论:

  • 选择的两个点没有祖先关系。这种情况很好处理,套用分治 FFT 的常规套路解决背包。然后组合数算一下,最后在 DFS 序上贡献给答案即可,比较简单所以具体细节不再赘述,这一部分的复杂度是 $\mathcal{O} (n\log ^2n)$。
  • 选择的两个点有祖先关系。这种情况标算的处理是 $\mathcal{O} (n \sqrt n)$ 的,我的方法依旧是 $\mathcal{O(n\log ^2n)}$ 的。我们还是考虑分治 FFT,先令 $f_i$ 表示第 $i$ 个点选 $K$ 个点的答案,令 $g_i$ 表示 $i$ 的所有儿子的 $f$ 值。那么我们枚举祖先,考虑其某个儿子$v$,我们在上一个情况中,$x^0$ 的系数为 $1$,而在这里 $x^0$ 的系数应该是 $g_v$,但是我们直接这样修改显然是不行的。我们考虑分治 FFT 的过程中维护两个多项式 $w_0(x),w_1(x)$,$w_0(x)$ 维护和之前一样的多项式,对于 $w_1(x)$ 维护这个范围内的二阶多项式连乘恰好有一个因式满足 $x^0$ 的系数为 $g_v$,合并时 $w_0(x) \times w_1(x) \rightarrow w_1(x)$。

总复杂度 $\mathcal{O(n\log^2n)}$。

#include <bits/stdc++.h>

using namespace std;

#define ll             long long
#define db            double
#define up(i,j,n)        for (int i = j; i <= n; i++)
#define down(i,j,n)    for (int i = j; i >= n; i--)
#define cadd(a,b)        a = add (a, b)
#define cpop(a,b)        a = pop (a, b)
#define cmul(a,b)        a = mul (a, b)
#define pr            pair<int, int>
#define fi            first
#define se            second
#define SZ(x)        (int)x.size()
#define bin(i)        (1 << (i))
#define Auto(i,node)    for (int i = LINK[node]; i; i = e[i].next)

template<typename T> inline bool cmax(T & x, T y){return y > x ? x = y, true : false;}
template<typename T> inline bool cmin(T & x, T y){return y < x ? x = y, true : false;}
template<typename T> inline T dmax(T x, T y){return x > y ? x : y;}
template<typename T> inline T dmin(T x, T y){return x < y ? x : y;}

const int mod = 998244353;
const int MAXN = 4e5 + 5;
const int oo = 0x3f3f3f3f;
const int L = 3e5;

inline int add(int a, int b){a += b; return a >= mod ? a - mod : a;}
inline int pop(int a, int b){a -= b; return a < 0 ? a + mod : a;}
inline int mul(int a, int b){return (ll)a * b % mod;}

int qpow(int a, int b){
    int c = 1;
    while (b) {
        if (b & 1) cmul(c, a);
        cmul(a, a); b >>= 1;
    }
    return c;
}

int omega[MAXN], inv[MAXN], A[MAXN], B[MAXN], pos[MAXN], ans = 0;

int fac[MAXN], invfac[MAXN];

inline int C(int a, int b){
    if (a < 0 || b < 0 || a < b) return 0;
    return mul(fac[a], mul(invfac[b], invfac[a - b]));
}

int init(int N){
    int n = 1;
    while (n < N) n <<= 1;
    up (i, 0, n - 1) {
        pos[i] = pos[i >> 1] >> 1;
        if (i & 1) pos[i] |= n >> 1;
    }
    omega[0] = inv[0] = 1;
    int x = qpow(3, (mod - 1) / n), y = qpow(x, mod - 2);
    up (i, 1, n - 1) {
        omega[i] = mul(omega[i - 1], x);
        inv[i] = mul(inv[i - 1], y);
    }
    return n;
}

void FFT(int n, int *a, bool idft){
    up (i, 0, n - 1) if (pos[i] < i) swap(a[pos[i]], a[i]);
    int *w = idft ? inv : omega;
    for (int len = 2; len <= n; len <<= 1) {
        int m = len >> 1;
        for (int j = 0; j < n; j += len) up (k, 0, m - 1) {
            int z = mul(a[j + m + k], w[n / len * k]);
            a[j + m + k] = pop(a[j + k], z);
            a[j + k] = add(a[j + k], z);
        }
    }
    if (idft) {
        int x = qpow(n, mod - 2);
        up (i, 0, n - 1) cmul(a[i], x);
    }
}

int N, K, dfn[MAXN], rev[MAXN], ord, siz[MAXN], suf[MAXN];
int f[MAXN], sz[MAXN], m, w[21][MAXN], w1[21][MAXN], all[21];

int g[MAXN], _w[MAXN];

pr Work(int le, int ri, int dep){
    if (le == ri) {
        int cl = all[dep] + 1, cr = cl + 1;
        w[dep][cl] = 1; w[dep][cr] = sz[le];
        all[dep] = cr;
        return make_pair(cl, cr);
    }
    int mi = (le + ri) >> 1, cl = all[dep] + 1, cr = cl + ri - le + 1;
    all[dep] = cr;
    pr p = Work(le, mi, dep + 1), q = Work(mi + 1, ri, dep + 1);
    int n = init(ri - le + 2);
    memset(A, 0, sizeof(int) * n);
    memset(B, 0, sizeof(int) * n);
    up (i, p.fi, p.se) A[i - p.fi] = w[dep + 1][i];
    up (i, q.fi, q.se) B[i - q.fi] = w[dep + 1][i];
    FFT(n, A, 0); FFT(n, B, 0);
    up (i, 0, n - 1) cmul(A[i], B[i]);
    FFT(n, A, 1);
    up (i, cl, cr) w[dep][i] = A[i - cl];
    return make_pair(cl, cr);
}

pr re_Work(int le, int ri, int dep){
    if (le == ri) {
        int cl = all[dep] + 1, cr = cl + 1;
        w[dep][cl] = 1; w[dep][cr] = sz[le];
        w1[dep][cl] = _w[le]; w1[dep][cr] = 0;
        all[dep] = cr;
        return make_pair(cl, cr);
    }
    int mi = (le + ri) >> 1, cl = all[dep] + 1, cr = cl + ri - le + 1;
    all[dep] = cr;
    pr p = re_Work(le, mi, dep + 1), q = re_Work(mi + 1, ri, dep + 1);
    int n = init(ri - le + 2);
    up (i, cl, cr) w[dep][i] = w1[dep][i] = 0;
    int *ss[2] = {w[dep + 1], w1[dep + 1]};
    int *tt[2] = {w[dep], w1[dep]};
    up (x, 0, 1) up (y, 0, 1) if (x + y < 2) {
        memset(A, 0, sizeof(int) * n);
        memset(B, 0, sizeof(int) * n);
        up (i, p.fi, p.se) A[i - p.fi] = ss[x][i];
        up (i, q.fi, q.se) B[i - q.fi] = ss[y][i];
        FFT(n, A, 0);
        FFT(n, B, 0);
        up (i, 0, n - 1) cmul(A[i], B[i]);
        FFT(n, A, 1);
        up (i, cl, cr) cadd(tt[x + y][i], A[i - cl]);
    }
    return make_pair(cl, cr);
}

int calc(){
    if (m == 0) return 1;
    up (i, 0, 20) all[i] = 0;
    pr ww = Work(1, m, 0);
    int cl = ww.fi, cr = ww.se, sum = 0;
    cmin(cr, cl + K);
    up (i, cl, cr) cadd(sum, mul(mul(w[0][i], fac[i - cl]), 
        C(K, K - (i - cl))));
    return sum;
}

struct edge {
    int y, next;
}e[MAXN << 1];

int LINK[MAXN], len = 0;

inline void ins(int x, int y){
    e[++len].next = LINK[x]; LINK[x] = len;
    e[len].y = y;
}

inline void Ins(int x, int y){
    ins(x, y);
    ins(y, x);
}

void DFS(int node, int fa){
    dfn[node] = ++ord; rev[ord] = node;
    siz[node] = 1;
    Auto (i, node) if (e[i].y != fa) {
        DFS(e[i].y, node);
        siz[node] += siz[e[i].y];
    }
    m = 0;
    Auto (i, node) if (e[i].y != fa) 
        sz[++m] = siz[e[i].y];
    f[node] = calc();
}

void reDFS(int node, int fa){
    int son = 0;
    g[node] = f[node];
    Auto (i, node) if (e[i].y != fa) {
        reDFS(e[i].y, node);
        cadd(g[node], g[e[i].y]);
        son++;
    }
    m = 0;
    Auto (i, node) if (e[i].y != fa) { 
        sz[++m] = siz[e[i].y];
        _w[m] = g[e[i].y];
    }
    if (fa) {
        sz[++m] = N - siz[node], son++;
        _w[m] = 0;
    }
    if (m) {
        up (i, 0, 20) all[i] = 0;
        pr ww = re_Work(1, m, 0);
        int cl = ww.fi, cr = ww.se;
        cmin(cr, cl + K);
        up (i, cl, cr) cadd(ans, mul(w1[0][i], 
        mul(C(K, K - (i - cl)), fac[i - cl])));
    }
}

int main(){
#ifdef cydiater
    freopen("input.in", "r", stdin);
#endif
    scanf("%d%d", &N, &K);
    if (K == 1) return 0 * printf("%lld\n", ((ll)N * (N - 1) / 2) % mod);
    fac[0] = invfac[0] = invfac[1] = 1;
    up (i, 2, L) invfac[i] = mul(mod - mod / i, invfac[mod % i]);
    up (i, 1, L) fac[i] = mul(i, fac[i - 1]);
    up (i, 1, L) invfac[i] = mul(invfac[i], invfac[i - 1]);
    up (i, 2, N) {
        int x, y; scanf("%d%d", &x, &y);
        Ins(x, y);
    }
    DFS(1, 0);
    down (i, N, 1) {
        cadd(ans, mul(f[rev[i]], suf[i + siz[rev[i]]]));
        suf[i] = add(suf[i + 1], f[rev[i]]);
    }
    reDFS(1, 0);
    printf("%d\n", ans);
    return 0;
}