「BZOJ 4775」网管

给定一棵大小为 $n$ 的树,每个点有一个黑或白的颜色,给定初始颜色,要求支持:

  • 以 $P$ 的概率翻转某个点的颜色。
  • 求某个点到所有黑点的距离和的平方的期望值。

容易发现,如果第二问求的是到所有黑点距离的期望值,那么就很容易用动态点分治来维护,为了方便下面的推导,我们直接令 $X$ 为某个点到所有黑点的距离和的期望值。

我们要求的就是 $E(X^2)$,这个东西,首先给出一个公式:

$$E(X^2) - E^2(X) = Var(X)$$

即一个事件平方的期望与期望的平方之差为其方差。$E(X)$是可以方便用动态点分治来维护的,考虑怎么快速求出 $Var(X)$,下面再给出几个公式:

$$Var(\sum X_i) = \sum Var(X_i) + \sum\limits_{i \not = j} Cov(X_i, X_j)$$

$$Var(ax) = a^2Var(x)$$

其中 $Cov$ 表示协方差,这道题里只需要清楚当 $X_i,X_j$ 的取值是独立时,$Cov(X_i,X_j) = 0$ 即可。

基本清楚这些知识后,我们考虑所求的内容:

$$Var(\sum dis(x,i)p_i) = \sum Var(dis(x, i)p_i) = \sum dis^2(x,i)Var(p_i) = \sum dis^2(x,i)p_i(1 - p_i)$$

对于每个点,我们把式子拆开,分别维护即可。

时间复杂度 $O(n\log n + Q\log ^2n)$。

#include <bits/stdc++.h>

using namespace std;

#define ll             long long
#define db            long double
#define up(i,j,n)        for (int i = j; i <= n; i++)
#define down(i,j,n)    for (int i = j; i >= n; i--)
#define cadd(a,b)        a = add (a, b)
#define cpop(a,b)        a = pop (a, b)
#define cmul(a,b)        a = mul (a, b)
#define pr            pair<int, int>
#define fi            first
#define se            second
#define SZ(x)        (int)x.size()
#define bin(i)        (1 << (i))
#define Auto(i,node)    for (int i = LINK[node]; i; i = e[i].next)

template<typename T> inline bool cmax(T & x, T y){return y > x ? x = y, true : false;}
template<typename T> inline bool cmin(T & x, T y){return y < x ? x = y, true : false;}

const int MAXN = 1e5 + 5;
const int oo = 0x3f3f3f3f;

int N, Q, Type, Fa[MAXN];
db a[MAXN];

struct CALC {
    db cnt, yp, y2p, cnt2, yp2, y2p2;
    inline void ins(db p, int dis){
        cnt += p; yp += p * dis;
        y2p += p * dis * dis;
        p *= p;
        cnt2 += p; yp2 += p * dis;
        y2p2 += p * dis * dis;
    }
    inline void del(db p, int dis){
        cnt -= p; yp -= p * dis;
        y2p -= p * dis * dis;
        p *= p;
        cnt2 -= p; yp2 -= p * dis;
        y2p2 -= p * dis * dis;
    }
}S[MAXN], T[MAXN];

struct edge {
    int y, next;
}e[MAXN << 1];

int LINK[MAXN], len = 0, mxsiz[MAXN], siz[MAXN], all, root, fa[MAXN][21], dep[MAXN];
bool vis[MAXN];

inline void ins(int x, int y){
    e[++len].next = LINK[x]; LINK[x] = len;
    e[len].y = y;
}

inline void Ins(int x, int y){
    ins(x, y);
    ins(y, x);
}

int LCA(int x, int y){
    if (dep[x] < dep[y]) swap(x, y);
    down (i, 20, 0) if (fa[x][i] && dep[fa[x][i]] >= dep[y]) x = fa[x][i];
    if (x == y) return x;
    down (i, 20, 0) if (fa[x][i] && fa[x][i] != fa[y][i]) {
        x = fa[x][i];
        y = fa[y][i];
    }
    return fa[x][0];
}

int getdis(int x, int y) { return dep[x] + dep[y] - (dep[LCA(x, y)] << 1); }

void DFS(int node){
    up (i, 1, 20) fa[node][i] = fa[fa[node][i - 1]][i - 1];
    Auto (i, node) if (e[i].y != fa[node][0]) {
        fa[e[i].y][0] = node;
        dep[e[i].y] = dep[node] + 1;
        DFS(e[i].y);
    }
}

int getrt(int node, int la){
    mxsiz[node] = 0; siz[node] = 1;
    Auto (i, node) if (!vis[e[i].y] && e[i].y != la) {
        getrt(e[i].y, node);
        siz[node] += siz[e[i].y];
        cmax(mxsiz[node], siz[e[i].y]);
    }
    cmax(mxsiz[node], all - siz[node]);
    if (mxsiz[root] > mxsiz[node]) root = node;
}

void Fix(int node, int la){
    mxsiz[node] = 0; siz[node] = 1;
    Auto (i, node) if (!vis[e[i].y] && e[i].y != la) {
        Fix(e[i].y, node);
        siz[node] += siz[e[i].y];
    }
}

void DFS(int node, int la, int dis, CALC & ss){
    ss.ins(a[node], dis);
    Auto (i, node) if (e[i].y != la && !vis[e[i].y]) 
        DFS(e[i].y, node, dis + 1, ss);
}

void Work(){
    int node = root; vis[node] = 1;
    Fix(node, 0);
    DFS(node, 0, 0, S[node]);
    Auto (i, node) if (!vis[e[i].y]) {
        all = siz[e[i].y]; root = 0; 
        getrt(e[i].y, node);
        DFS(e[i].y, node, 1, T[root]);
        Fa[root] = node;
        Work();
    }
}

void cg(int node, db w){
    int x = node;
    w = a[node] * (1 - w) + (1 - a[node]) * w;
    while (x) {
        int dis = getdis(x, node);
        S[x].del(a[node], dis);
        S[x].ins(w, dis);
        if (Fa[x]) {
            dis = getdis(Fa[x], node);
            T[x].del(a[node], dis);
            T[x].ins(w, dis);
        }
        x = Fa[x];
    }
    a[node] = w;
}

db get(int node){
    int x = node; db ex = 0, vx = 0;
    while (x) {
        int dis = getdis(x, node);
        ex += S[x].cnt * dis;
        ex += S[x].yp;
        vx += S[x].cnt * dis * dis;
        vx += S[x].y2p;
        vx += S[x].yp * 2 * dis;
        vx -= S[x].cnt2 * dis * dis;
        vx -= S[x].y2p2;
        vx -= S[x].yp2 * 2 * dis;
        if (Fa[x]) {
            dis = getdis(Fa[x], node);
            ex -= T[x].cnt * dis;
            ex -= T[x].yp;
            vx -= T[x].cnt * dis * dis;
            vx -= T[x].y2p;
            vx -= T[x].yp * 2 * dis;
            vx += T[x].cnt2 * dis * dis;
            vx += T[x].y2p2;
            vx += T[x].yp2 * 2 * dis;    
        }
        x = Fa[x];
    }
    return vx + ex * ex;
}

int main(){
    scanf("%d%d%d", &Type, &N, &Q);
    up (i, 1, N) {
        double x; scanf("%lf", &x);
        a[i] = x;
    }
    up (i, 2, N) {
        int x, y; scanf("%d%d", &x, &y);
        Ins(x, y);
    }
    dep[1] = 1;
    DFS(1);
    all = N; mxsiz[root = 0] = N;
    getrt(1, 0);
    Work();
    while (Q--) {
        int o, x, v; scanf("%d", &o);
        if (o == 1) {
            scanf("%d%d", &x, &v);
            db w = (db)v / 100;
            cg(x, w);
        }else {
            scanf("%d", &x);
            db ans = get(x);
            printf("%.10lf\n", (double)ans);
        }
    }
    return 0;
}