树链剖分(重链剖分)
树链剖分的核心思想就是将一棵树剖分成一条一条的链
因为树不好处理 但链比较好处理
为了学会它 我们先要学会树上dfs(深度优先搜索) 然后就没了(雾)
Because 树链剖分需要用到两个dfs
哦对了 我们还要了解以下的知识点
1.子树大小
就是一个节点的子树的节点个数(包括它自己)
2.重儿子
一个节点的所有儿子中子树大小最大的儿子
3.轻儿子
除了重儿子之外的儿子(包括根节点)
4.重边
两个相邻的重儿子连成的边
5.重链
重边连成的链(开头是轻儿子)
OK,了解完了
启动!
dfs1:
目标是求出以下数据
1: 节点的父节点
2: 节点的深度 也就是它到根节点的距离
3: 节点的重儿子
4: 节点的子树大小
思路十分的简单 直接上代码
void dfs1(int x)
{
deep[x]=deep[fa[x]]+1;
si[x]=1;//它自己也是
for(int i=0;i<a[x].size();i++)
{
if(a[x][i]==fa[x]) continue;
fa[a[x][i]]=x;
dfs1(a[x][i]);
si[x]+=si[a[x][i]];
if(si[a[x][i]]>si[son[x]]) son[x]=a[x][i];
}
}
dfs2:
目标是求出以下数据
1: 节点的所在的重链的顶端节点
2: 节点的新编号
3: 节点的新编号的值
其中 新编号的原则是先走重儿子 再走轻儿子(理由下次一定)
比如
懂了吗? 我相信你肯定懂了
思路也十分的简单 直接上代码
void dfs2(int x,int topf)
{
id[x]=++cnt;
top[x]=topf;
b[cnt]=w[x];
if(!son[x]) return;
dfs2(son[x],topf);
for(int i=0;i<a[x].size();i++)
{
if(a[x][i]==fa[x]||a[x][i]==son[x])
{
continue;
}
dfs2(a[x][i],a[x][i]);
}
}
我们发现先走重儿子的话 就可以让重链上的点的新编号是连续的 也就形成了一条链
至此 树链剖分就结束了
那你一定很想问 这玩意有什么用呢?
题目
我们结合题目来看(其实是模板)
前置芝士:线段树
让我们处理4种操作
我们先来看操作3和操作4
通过观察图片可以看出
一个节点的子树的编号肯定是连续的
设子树的根是
左端点是的编号是 右端点的编号是
所以3操作就是将到加上
4操作就是求出到的值的和
即区间修改 ,区间查询
我们当然会想到线段树啦
直接拿下
之后我们来看1,2操作
先来看一看图
比如我们要从 7 到 4
我们可以先判断它俩的top是不是一个东西 发现不是
发现4的深度比7小 这怎么可以呢?
我们交换一下 变成从 4 到 7
然后把4跳到它的top的父节点
也就是0(1的父节点是0)
并加上 1 到 4 的值 可以用线段树处理
之后在循环上述操作
直到两者的top相等
上代码
int getsum1(int x,int y)
{
int sum=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
{
swap(x,y);
}
sum=sum+query(1,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(deep[x]>deep[y])
{
swap(x,y);
}
sum+=query(1,1,n,id[x],id[y]);
return sum;
}
修改也差不多 直接上代码
void addsum1(int x,int y,int z)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
{
swap(x,y);
}
add(1,1,n,id[top[x]],id[x],z);
x=fa[top[x]];
}
if(deep[x]>deep[y])
{
swap(x,y);
}
add(1,1,n,id[x],id[y],z);
}
完整版:
#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,cnt,m,r,mod;
int b[110000],top[110000],w[110000],deep[110000],id[110000],fa[110000],si[110000],son[110000],tr[410000],tag[410000];
vector<int>a[110000];
void dfs1(int x)
{
deep[x]=deep[fa[x]]+1;
si[x]=1;//它自己也是
for(int i=0;i<a[x].size();i++)
{
if(a[x][i]==fa[x]) continue;
fa[a[x][i]]=x;
dfs1(a[x][i]);
si[x]+=si[a[x][i]];
if(si[a[x][i]]>si[son[x]]) son[x]=a[x][i];
}
}
void dfs2(int x,int topf)
{
id[x]=++cnt;
top[x]=topf;
b[cnt]=w[x];
if(!son[x]) return;
dfs2(son[x],topf);
for(int i=0;i<a[x].size();i++)
{
if(a[x][i]==fa[x]||a[x][i]==son[x])
{
continue;
}
dfs2(a[x][i],a[x][i]);
}
}
void build(int k,int l,int r)
{
if(l==r)
{
tr[k]=b[l];
return;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
tr[k]=tr[k<<1]+tr[k<<1|1];
}
void push_down(int k,int l,int r)
{
if(tag[k])
{
int mid=(l+r)>>1;
tag[k<<1]+=tag[k];
tag[k<<1|1]+=tag[k];
tr[k<<1]+=(mid-l+1)*tag[k];
tr[k<<1|1]+=(r-mid)*tag[k];
tag[k]=0;
}
}
void add(int k,int l,int r,int q,int p,int d)
{
if(q<=l&&p>=r)
{
tag[k]+=d;
tr[k]+=(r-l+1)*d;
return;
}
push_down(k,l,r);
int mid=(l+r)>>1;
if(mid>=q)
{
add(k<<1,l,mid,q,p,d);
}
if(mid<p)
{
add(k<<1|1,mid+1,r,q,p,d);
}
tr[k]=(tr[k<<1]+tr[k<<1|1]);
}
int query(int k,int l,int r,int q,int p)
{
if(q<=l&&p>=r)
{
return tr[k];
}
push_down(k,l,r);
int mid=(l+r)>>1,ssum=0;
if(mid>=q)
{
ssum+=query(k<<1,l,mid,q,p);
}
if(mid<p)
{
ssum+=query(k<<1|1,mid+1,r,q,p);
}
return ssum;
}
int getsum1(int x,int y)
{
int sum=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
{
swap(x,y);
}
sum=sum+query(1,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(deep[x]>deep[y])
{
swap(x,y);
}
sum+=query(1,1,n,id[x],id[y]);
return sum;
}
int getsum2(int x)
{
return query(1,1,n,id[x],id[x]+si[x]-1);
}
void addsum1(int x,int y,int z)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
{
swap(x,y);
}
add(1,1,n,id[top[x]],id[x],z);
x=fa[top[x]];
}
if(deep[x]>deep[y])
{
swap(x,y);
}
add(1,1,n,id[x],id[y],z);
}
void addsum2(int x,int y)
{
add(1,1,n,id[x],id[x]+si[x]-1,y);
}
signed main()
{
scanf("%lld%lld%lld%lld",&n,&m,&r,&mod);
for(int i=1;i<=n;i++)
{
scanf("%lld",&w[i]);
}
for(int i=1;i<n;i++)
{
int x,y;
scanf("%lld%lld",&x,&y);
a[x].push_back(y);
a[y].push_back(x);
}
dfs1(r);
dfs2(r,r);
build(1,1,n);
while(m--)
{
int op;
scanf("%lld",&op);
if(op==1)
{
int x,y,z;
scanf("%lld%lld%lld",&x,&y,&z);
addsum1(x,y,z);
}
if(op==2)
{
int x,y;
scanf("%lld%lld",&x,&y);
printf("%lld\n",getsum1(x,y)%mod);
}
if(op==3)
{
int x,y;
scanf("%lld%lld",&x,&y);
addsum2(x,y);
}
if(op==4)
{
int x;
scanf("%lld",&x);
printf("%lld\n",getsum2(x)%mod);
}
}
return 0;
}