算法简介

$\text{Dsu On Tree}$,即树上启发式合并。可用来解决不带修的子树查询问题

其核心思想在于运用重链剖分中重儿子的性质进行优化,以达到简便计算的效果。


算法流程

为方便讲述,我们先引入一道例题:Codeforces 600E Lomsat gelral


题目描述

给定一棵 $n(0 \leq n \leq 10^5)$ 个节点的树,每个节点有一个编号 $d$ 和颜色 $v$。

请求出以每个节点为根的子树中最多的颜色的节点编号和。

形象化的说,设子树 $S$ 中最多的颜色的节点集合为 $V$,即求出:


分析

先考虑如何写 $O(n^2)$ 的暴力。

显然,我们可以遍历每个节点并统计其子树中的答案,然后消除其贡献以保证不对下一次统计产生影响。时间复杂度为 $O(n^2)$,无法通过本题。

这时,$\text{Dsu On Tree}$ 就可以派上用场了。

我们先进行一次 $\text{dfs}$ ,求出每个点的子树大小,并选出其重儿子。

之后对于每个节点,进行如下操作:

  • 统计出所有轻儿子的 $ans$,并消除贡献。
  • 统计其重儿子的 $ans$,不消除贡献。
  • 暴力将所有轻儿子的 $ans$ 加入到该节点的 $ans$ 中
  • 消除所有轻儿子对于该节点的影响。

用代码写出来是这个样子的

void calc(int x , int fa , int opt){    //opt 指是否需要删除的标记,为0表示需要,1表示不需要
PE(i , x){
int to = edge[i].to;
if(to == fa || to == son[x]) continue;
else{
calc(to , x , 0); //计算轻儿子的答案,并消除影响
}
}
if(son[x]) calc(son[x] , x , 1); //计算重儿子的答案,并不消除影响

add(x , fa , 1); //加入所有轻儿子的答案
ans[x] = Now_ans; //更新答案
if(opt == 0) add(x , fa , -1); //如果需要删除贡献就删除
}

这看起来是 $O(n^2)$ 的,但实际上它是 $O(n \log n)$ 的。


复杂度证明

也许你要问了:这都是暴力统计,为什么是 $O(n \log n)$ 的呐?

下面我们就来证明这个问题。

首先,我们需要知道一个性质:在一棵树上,一个节点到根节点的路径上的边数不会超过 $\log n$

然而这个性质并不能直接解决问题。我们考虑一个点会在什么时候被访问到

显然,这有两种情况,分别是:

  1. 在暴力统计轻儿子时候被访问到,次数为 $\log n$。
  2. 在统计重儿子时候被访问到,次数为 $1$。

综上所述,时间复杂度为 $O(n \log n)$。


Code[Codeforces 600E]

#include <algorithm>
#include <iostream>
#include <cstring>
#include <iomanip>
#include <cstdlib>
#include <sstream>
#include <cstdio>
#include <string>
#include <vector>
#include <cctype>
#include <stack>
#include <queue>
#include <deque>
#include <cmath>
#include <map>
#include <set>

#define I inline
#define db double
#define ll long long
#define pb push_back
#define MP make_pair
#define ldb long double
#define ull unsigned long long
#define PII pair<int , int>
#define PIL pair<int , long long>
#define PSL pair<string , long long>
#define PLL pair<long long , long long>
#define all(x) (x).begin() , (x).end()
#define copy(a , b) memcpy(a , b , sizeof(b))
#define clean(a , b) memset(a , b , sizeof(a))
#define rep(i , l , r) for (int i = (l); i <= (r); i ++)
#define per(i , r , l) for (int i = (r); i >= (l); i --)
#define PE(i , x) for(int i = head[x]; i; i = edge[i].last)
#define DEBUG(x) std :: cerr << #x << '=' << x << std :: endl

using namespace std;

const int N = 10001;
const int M = 100001;
const int HA = 998244353;
const int INF = 2047483647;
const long long INFL = 9023372036854775807;

int n;

struct Edge{
int to , last;
}edge[M << 1];

int edge_num;
int head[M << 1];

I void add_edge(int from , int to){
edge[++ edge_num] = (Edge){to , head[from]}; head[from] = edge_num;
edge[++ edge_num] = (Edge){from , head[to]}; head[to] = edge_num;
}

int size[M] , son[M];

int dfs(int x , int fa){ //统计重儿子与子树大小
size[x] = 1;
int maxn = -1;
PE(i , x){
int to = edge[i].to;
if(to == fa) continue;
size[x] += dfs1(to , x);
if(size[to] > maxn){
maxn = size[x];
son[x] = to;
}
}
return size[x];
}

ll maxn , sum , SON , col[M] , cnt[M]; //SON 指代当前点的重儿子

void add(int x , int fa , int val){ //将轻儿子的答案加入贡献中
cnt[col[x]] += val;
if(cnt[col[x]] > maxn) maxn = cnt[col[x]] , sum = col[x];
else if(cnt[col[x]] == maxn) sum += col[x];
PE(i , x){
int to = edge[i].to;
if(to == fa || to == SON) continue;
add(to , x , val);
}
}

ll ans[M];

void calc(int x , int fa , int opt){
PE(i , x){ //统计所有轻儿子的答案,并消除影响
int to = edge[i].to;
if(to == fa) continue;
else if(to != son[x]){
calc(to , x , 0);
}
}
if(son[x]) calc(son[x] , x , 1) , SON = son[x]; //统计重儿子的答案,不消除影响

add(x , fa , 1); SON = 0; //将所有轻儿子的答案加入到子树中
ans[x] = sum;
if(opt == 0) add(x , fa , -1) , maxn = 0; //如果需要则消除所有轻儿子影响
}

int main() {
#ifdef LOCAL
freopen("try.in" , "r" , stdin);
freopen("try1.out" , "w" , stdout);
#endif
scanf("%d" , &n);
rep(i , 1 , n){
scanf("%d" , &col[i]);
}
int u , v;
rep(i , 1 , n - 1){
scanf("%d%d" , &u , &v);
add_edge(u , v);
}
dfs(1 , 0);
calc(1 , 0 , 1);
rep(i , 1 , n){
printf("%lld " , ans[i]);
}

return 0;
}

THE END