「PA2015」Rozstaw szyn

给定一棵大小为 $n$ 的树,树有 $m$ 个叶子,每个叶子都有一个初始的权值。要求给其余 $n - m$ 个节点赋值,满足任意相邻两个节点权值差绝对值的和最小。

$n \leq 5\times10^5$。

一开始感觉是个傻逼题,但是想了几个做法都叉掉了,就去看题解了。

首先我们考虑一种特殊情况,就是当给定的树是菊花图的情况,那么这个时候显然我们根节点应该取的值为所有叶子节点权值的中位数,同时注意到,如果存在偶数个叶子节点的话,那么我们取的值在两个中位数之间都是合法的。

考虑把这个特殊情况扩展出来,我们定义一个节点合法的取值范围为当这个点的取值在这个取值范围内时,对答案的贡献是最小的。对于上面所说的情况,把每个儿子的权值换做取值范围来考虑,那么如何求得当前节点的合法取值范围呢,一个结论就是这个节点的取值范围是所有儿子的端点排序后的两个中位数。简单证明一下就是,很明显我们的取值范围一定是相邻两个端点,要求这个区间内所有数的代价相等且最小,相等时很显然的,因为这两个端点左右的完整区间数量一定是相等的,所以贡献的代价也是相同的。同时发现这个代价是单峰的,我们向左向右移动都会使代价增加,那么这一段就是贡献最小的区间。

这个时候考虑证明为什么从儿子的取值范围推出父亲的取值范围是正确的,考虑对于一个儿子来说,他如果向右移动取值范围,可以减少和其父亲产生的代价,但是因为其取值范围是在正中间,势必会造成影响其儿子对他的贡献增加至少一个相同的值,所以调整一定不会更优。

复杂度就是排序的复杂度 $\mathcal{O(n\log n)}$。

#include <bits/stdc++.h>

using namespace std;

#define ll             long long
#define db            double
#define up(i,j,n)        for (int i = j; i <= n; i++)
#define down(i,j,n)    for (int i = j; i >= n; i--)
#define cadd(a,b)        a = add (a, b)
#define cpop(a,b)        a = pop (a, b)
#define cmul(a,b)        a = mul (a, b)
#define pr            pair<int, int>
#define fi            first
#define se            second
#define SZ(x)        (int)x.size()
#define bin(i)        (1 << (i))
#define Auto(i,node)    for (int i = LINK[node]; i; i = e[i].next)

template<typename T> inline bool cmax(T & x, T y){return y > x ? x = y, true : false;}
template<typename T> inline bool cmin(T & x, T y){return y < x ? x = y, true : false;}
template<typename T> inline T dmax(T x, T y){return x > y ? x : y;}
template<typename T> inline T dmin(T x, T y){return x < y ? x : y;}

inline int read(){
    char ch = getchar(); int x = 0, f = 1;
    while (ch > '9' || ch < '0') {if (ch == '-') f = -1; ch = getchar();}
    while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
    return x * f;
}

const int MAXN = 5e5 + 5;

int N, M, LINK[MAXN], len, cl[MAXN], cr[MAXN];
ll ans;

struct edge {
    int y, next;
}e[MAXN << 1];

inline void ins(int x, int y){
    e[++len].next = LINK[x]; LINK[x] = len;
    e[len].y = y;
}

inline void Ins(int x, int y){
    ins(x, y);
    ins(y, x);
}

int num[MAXN << 1], top;

void DFS(int node, int fa){
    if (node <= M) return;
    Auto (i, node) if (e[i].y != fa) DFS(e[i].y, node);
    top = 0;
    Auto (i, node) if (e[i].y != fa) num[++top] = cl[e[i].y],
    num[++top] = cr[e[i].y];
    sort(num + 1, num + top + 1);
    cl[node] = num[top >> 1]; cr[node] = num[(top >> 1) + 1];
    Auto (i, node) if (e[i].y != fa && 
        (cr[e[i].y] <= cl[node] || cr[node] <= cl[e[i].y])) 
            ans += dmin(abs(cl[node] - cl[e[i].y]), 
                abs(cl[node] - cr[e[i].y]));
}

int main(){
#ifdef cydiater
    freopen("input.in", "r", stdin);
#endif
    N = read(); M = read();
    up (i, 2, N) {
        int x = read(), y = read();
        Ins(x, y);
    }
    up (i, 1, M) cl[i] = cr[i] = read();
    ans = 0;
    if (N > M) DFS(N, 0);
    else {
        up (node, 1, N) Auto (i, node) 
            ans += abs(cl[node] - cl[e[i].y]);
        ans /= 2;
    }
    printf("%lld\n", ans);
    return 0;
}