算法起源

在算法竞赛中,我们经常会碰到一些与异或运算(即 C++ 中 ^) 相关的题目,其套路在于按位考虑,即将十进制数拆为二进制数后进行计算

但对于二进制位的直接维护并不是一件方便的事。故我们将字典树修改为 $\text{01-Trie}$,方便对二进制位进行维护与查询。


算法讲解

$\text{01-Trie}$ 的应用主要有两种,一是维护异或极值,二是维护异或和。

维护异或极值

为方便讲述,我们先引入一道例题:Luogu P4551 最长异或路径


题目描述

给你一颗 $n$ 个点的树,边有边权,求出两个点 $u$ 与 $v$ 使得 $u$ 到 $v$ 路径上的边权异或和最大。

$1 \leq n \leq 10^5$,边权的范围为 $\lbrack 0 , 2^{31} \rbrack$。


分析

(由于异或没有较为统一的符号表示,本文中统一使用 $\oplus$ 表示异或运算)。

设 $X(u , v)$ 表示 $u$ 到 $v$ 路径上的边权异或和。

众所周知,异或具有可抵消的特点,即 $a \oplus a = 0$。

而 $X(root , i) , i \in \lbrack 1 , n \rbrack$ 可以用一遍 $\text{dfs}$ 在 $O(n)$ 的时间复杂度内与处理出来。

之后将 $X(root , i) , i \in \lbrack 1 , n \rbrack$ 全部插入到 $\text{01-Trie}$ 中,插入代码如下:

int ch[M][2] , cnt = 1; //ch[i][0/1] 表示i号节点的两个儿子的编号,ch[i][0] 指下一位为0,ch[i][1] 指下一位为1
//cnt 指节点数目,由于 Trie 树初始时有一个无实际意义的根节点,故初始化为1

void insert(int x){
int now = 1;
for(int i = 30; i >= 0; i --){
int t = ((x >> i) & 1); //取出二进制上第i位的数
if(ch[now][t]) now = ch[now][t]; //如果有的话,就直接走过去即可
else {
ch[now][t] = ++ cnt; //否则就新建一个
now = ch[now][t];
}
}
}

现在的问题是如何查询,考虑一下按位异或的定义

  • 对于两个数的所有二进制位从后向前进行比较,若相同则为 $0$,不同则为 $1$。

且二进制数的相对大小取决于高位,即高位大的一定大。故我们只需要根据高位进行贪心,让位于高位的两个二进制数尽量不同即可。查询代码如下:

int query(int x){
int now = 1 , ans = 0;
for(int i = 30; i >= 0; i --){
int t = ((x >> i) & 1);
if(ch[now][t ^ 1]) now = ch[now][t ^ 1] , ans = (1 << i); //如果两个二进制数可以不同,则设为不同,且记录答案
else now = ch[now][t]; //否则只能相同了
}
return ans;
}

Code[Luogu P4551]

#include <algorithm>
#include <iostream>
#include <cstring>
#include <iomanip>
#include <cstdlib>
#include <climits>
#include <cstdio>
#include <bitset>
#include <string>
#include <vector>
#include <cctype>
#include <stack>
#include <queue>
#include <deque>
#include <cmath>
#include <map>
#include <set>

#define I inline
#define fi first
#define se second
#define pb push_back
#define MP make_pair
#define PII pair<int , int>
#define PSL pair<string , long long>
#define PLL pair<long long , long long>
#define all(x) (x).begin() , (x).end()
#define copy(a , b) memcpy(a , b , sizeof(b))
#define clean(a , b) memset(a , b , sizeof(a))
#define rep(i , l , r) for (int i = (l); i <= (r); i ++)
#define per(i , r , l) for (int i = (r); i >= (l); i --)
#define PE(i , x) for(int i = head[x]; i; i = edge[i].last)
#define DEBUG(x) std :: cerr << #x << '=' << x << std :: endl

typedef long long ll;
typedef unsigned long long ull;

template <class T>
inline void read(T &x) {
char c = getchar(), f = 0; x = 0;
while (!isdigit(c)) f = (c == '-'), c = getchar();
while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
x = f ? -x : x;
}

using namespace std;

const int N = 10000 + 5;
const int M = 100000 + 5;
const int HA = 998244353;

struct Edge{
int to , last;
ll dis;
}edge[M << 1];

int edge_num;
int head[M << 1];

I void add_edge(int from , int to , ll dis){
edge[++ edge_num] = (Edge){to , head[from] , dis}; head[from] = edge_num;
edge[++ edge_num] = (Edge){from , head[to] , dis}; head[to] = edge_num;
}

int fa[M] , dis[M];

void dfs(int x , int f){
fa[x] = f;
PE(i , x){
int to = edge[i].to;
if(to == f) continue;
dis[to] = (dis[x] ^ edge[i].dis);
dfs(to , x);
}
}

namespace Trie{
#define lson(x) ch[x][0]
#define rson(x) ch[x][1]

int ch[M << 5][2] , cnt = 1;

void insert(int x){
int now = 1;
per(i , 30 , 0){
int t = ((x >> i) & 1);
if(ch[now][t]) now = ch[now][t];
else ch[now][t] = ++ cnt , now = ch[now][t];
}
}

int query(int x){
int now = 1 , ans = 0;
per(i , 30 , 0){
int t = ((x >> i) & 1);
if(ch[now][t ^ 1]) now = ch[now][t ^ 1] , ans += (1 << i);
else now = ch[now][t];
}
return ans;
}

#undef lson
#undef rson
}

using namespace Trie;

int main() {
#ifdef LOCAL
freopen("try.in" , "r" , stdin);
freopen("try1.out" , "w" , stdout);
#endif
int n; read(n);
rep(i , 1 , n - 1){
int u , v; ll w;
read(u) , read(v) , read(w);
add_edge(u , v , w);
}
dis[1] = 0;
dfs(1 , 0);
rep(i , 1 , n){
insert(dis[i]);
}
ll res = 0;
rep(i , 1 , n){
res = max(res , 1ll * query(dis[i]));
}
printf("%d\n" , res);

return 0;
}

维护异或和

$\text{01-Trie}$ 用于维护数字异或和时可支持插入,删除和全局加一(即让维护的所有数值 $+1$)。维护时需要按值从低位到高位建树

且在维护异或和时,我们仅需要知道某一位上 $0$ $1$ 数目的奇偶性即可,也就是说,对于某一位上的 $1$ 而言,当且仅当这一位上数字 $1$ 的个数为奇数个时,这一位上的数字才是 $1$。并不需要知道我们维护了哪些数字。

举个例子:

其中最高位上数字 $1$ 的个数为 $3$ ,是个奇数,故最后结果中最高位为 $1$,而后面三位中 $1$ 的个数分别为 $2 , 2 , 0$,均不为奇数,故最后结果中后面三位均为 $0$。


单节点维护

对于每一个节点,我们需要维护以下三个信息:

  • ch[x][0] & ch[x][1] 即左右儿子,ch[x][0] 代表下一位是 $0$ ,ch[x][1] 代表下一位是 $1$。

  • w[x] 指节点 $x$ 到其父亲节点这条边的权值。每当插入一个数字 $n$ 时,$n$ 二进制拆分后在 $\text{01-Trie}$ 上的权值都会 $+1$。

  • xorv[x] 指以 $x$ 为根节点的子树维护的异或和。

代码也很好写:

void update(int x){
w[x] = xorv[x] = 0;
if(ch[x][0]){
w[x] += w[ch[x][0]];
xorv[x] ^= (xorv[ch[x][0]] << 1);
}
if(ch[x][1]){
w[x] += w[ch[x][0]];
xorv[x] ^= (xorv[ch[x][0]] << 1) | (w[ch[x][1] & 1);
}
}

插入 & 删除

插入与删除操作也很容易,首先,我们维护一个最大深度 $H$,强行把每个叶子节点到根的距离设置为 $H$,以方便之后的进位处理。

之后按位插入并修改叶子节点的 w[x] 即可。

代码如下:

const int H = 20;

int mknode(){ //新建一个节点
cnt ++ , ch[cnt][0] = ch[cnt][1] = w[cnt] = xorv[cnt] = 0;
return cnt;
}

void insert(int &x , int val , int dep){
if(!x) x = mknode(); //如果当前没有这个节点,就新建一个
if(dep > H) return (void)(w[x] ++); //如果到了叶子节点,就修改 w[x]
insert(ch[x][val & 1] , val >> 1 , dep + 1); //否则递归处理 val 的每一位
update(x); //之后合并节点两个子树的信息
}

void Delete(int x , int val , int dep){
if(dep > H) return (void)(w[x] --); //同上
Delete(ch[x][val & 1] , val >> 1 , dep + 1);
update(x);
}

合并

$\text{01-Trie}$ 的合并类似于线段树合并,分三种情况讨论一下即可

int merge(int a , int b){
if(!a || !b) return a + b; //若 a 与 b 的任意一颗子树为空,则直接返回即可

xorv[a] ^= xorv[b];
w[a] += w[b];

ch[a][0] = merge(ch[a][0] , ch[b][0]); //分别合并左右子树
ch[a][1] = merge(ch[a][1] , ch[b][1]);

return a;
}

全局+1

思考一下二进制意义下的 $+1$ 是如何操作的?

我们只需要从低位到高位找出第一个出现的 $0$ ,将其变为 $1$ ,之后把这个位置后面的 $1$ 都变为 $0$ 即可。

举个几个例子:

而对应到 $\text{01-Trie}$ 上就是交换左右儿子,之后顺着交换后的左儿子向下递归操作即可。

void add1(int x){
swap(ch[x][0] , ch[x][1]);
if(ch[x][0]) add1(ch[x][0]);
update(x);
}

例题:[联合省选2020 A卷] 树

下面我们以Luogu P6623 [联合省选2020 A卷] 树为例题,实地练习一下 $\text{01-Trie}$ 维护异或和。

题目描述

给你一颗 $n$ 个的点的有根树,节点从 $1$ 开始编号且以 $1$ 为根,每个点有点权,记为 $v[i]$。

设 $x$ 号结点的子树内(包含 $x$ 自身)的所有结点编号为 $c_1, c_2, \dots, c_k$,定义 $x$ 的价值为:

定义节点 $x$ 的价值为:

其中 $d(x, y)$ 表示树上 $x$ 号结点与 $y$ 号结点间唯一简单路径所包含的边数,$d(x, x) = 0$。

请你求出 $\sum_{i=1}^n val(i)$ 的结果。


分析

省选赛场上胡了个换根,然后胡没了。

先考虑一下最无脑的 $10$ 分暴力怎么写:

for(int i = 1; i <= n; i ++){
int cnt = 0 , x = i;
do{
val[x] ^= (cnt + v[i]);
cnt ++;
x = fa[x];
}while(x);
}
ll ans = 0;
for(int i = 1; i <= n; i ++) ans += val[i];
printf("%lld\n" , ans);

很明显,这段代码就是在枚举每个点并计算其对其祖先的贡献。

考虑使用 $\text{01-Trie}$ 来加速这个过程。

我们对每个节点均建立一颗 $\text{01-Trie}$。初始时只存了这个节点的权值。

每一次操作,我们自底而上的合并每个子节点,之后全局 $+1$,统计答案即可。


Code[Luogu P6623]

#include <algorithm>
#include <iostream>
#include <cstring>
#include <iomanip>
#include <cstdlib>
#include <climits>
#include <cstdio>
#include <bitset>
#include <string>
#include <vector>
#include <cctype>
#include <stack>
#include <queue>
#include <deque>
#include <cmath>
#include <map>
#include <set>

#define I inline
#define fi first
#define se second
#define pb push_back
#define MP make_pair
#define PII pair<int , int>
#define PSL pair<string , long long>
#define PLL pair<long long , long long>
#define all(x) (x).begin() , (x).end()
#define copy(a , b) memcpy(a , b , sizeof(b))
#define clean(a , b) memset(a , b , sizeof(a))
#define rep(i , l , r) for (int i = (l); i <= (r); i ++)
#define per(i , r , l) for (int i = (r); i >= (l); i --)
#define PE(i , x) for(int i = head[x]; i; i = edge[i].last)
#define DEBUG(x) std :: cerr << #x << '=' << x << std :: endl

typedef long long ll;
typedef unsigned long long ull;

template <class T>
inline void read(T &x) {
char c = getchar(), f = 0; x = 0;
while (!isdigit(c)) f = (c == '-'), c = getchar();
while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
x = f ? -x : x;
}

using namespace std;

const int N = 10000 + 5;
const int M = 530000 + 5;
const int HA = 998244353;

int n , val[M];

struct Edge{
int to , last;
}edge[M << 1];

int edge_num;
int head[M << 1];

I void add_edge(int from , int to){
edge[++ edge_num] = (Edge){to , head[from]}; head[from] = edge_num;
edge[++ edge_num] = (Edge){from , head[to]}; head[to] = edge_num;
}

namespace Trie{
#define lson(x) ch[x][0]
#define rson(x) ch[x][1]

const int H = 21;

int ch[M * H][2] , w[M * H] , xorv[M * H] , cnt = 0;

int mknode(){
cnt ++ , ch[cnt][0] = ch[cnt][1] = w[cnt] = xorv[cnt] = 0;
return cnt;
}

void update(int x){
w[x] = xorv[x] = 0;
if(lson(x)){
w[x] += w[lson(x)];
xorv[x] ^= (xorv[lson(x)] << 1);
}
if(rson(x)){
w[x] += w[rson(x)];
xorv[x] ^= ((xorv[rson(x)] << 1) | (w[rson(x)] & 1));
}
w[x] &= 1;
}

void insert(int &x , int val , int dep){
if(!x) x = mknode();
if(dep > H) return (void)(w[x] ++);
insert(ch[x][val & 1] , val >> 1 , dep + 1);
update(x);
}

void erase(int x , int val , int dep){
if(dep > 20) return (void)(w[x] --);
erase(ch[x][val & 1] , val >> 1 , dep + 1);
update(x);
}

void add1(int x){
swap(lson(x) , rson(x));
if(lson(x)) add1(lson(x));
update(x);
}

int merge(int a , int b){
if(!a || !b) return a + b;
w[a] = w[a] + w[b];
xorv[a] ^= xorv[b];
lson(a) = merge(lson(a) , lson(b));
rson(a) = merge(rson(a) , rson(b));
return a;
}

#undef lson
#undef rson
}

using namespace Trie;

int rt[M];
ll ans = 0;

void dfs(int x , int fa){
PE(i , x){
int to = edge[i].to;
if(to == fa) continue;
dfs(to , x);
rt[x] = merge(rt[x] , rt[to]);
}
add1(rt[x]);
insert(rt[x] , val[x] , 0);
ans += xorv[rt[x]];
}

int main() {
#ifdef LOCAL
freopen("try.in" , "r" , stdin);
freopen("try1.out" , "w" , stdout);
#endif
read(n);
rep(i , 1 , n) read(val[i]);
rep(i , 2 , n){
int x; read(x);
add_edge(x , i);
}
dfs(1 , 0);
printf("%lld\n" , ans);

return 0;
}

THE END