HDU5909 Tree Cutting(FWT)
HDU5909 Tree Cutting
题目大意
有一棵有 n n n个点的树 T T T,每个节点有一个权值 v i v_i vi。定义一棵树的权值为其所有节点的权值的异或和。对于每个数 k ∈ [ 0 , m ) k\in[0,m) k∈[0,m),求这棵树 T T T的权值为 k k k的子树的个数。
T T T的子树是 T T T的子图,也是一棵树。
输出答案对 1 0 9 + 7 10^9+7 109+7取模后的值。
有 T T T组数据。
数据范围
1 ≤ T ≤ 10 , 1 ≤ n ≤ 1000 , 1 ≤ m ≤ 2 10 , 0 ≤ v i < m 1\leq T\leq 10,1\leq n\leq 1000,1\leq m\leq 2^{10},0\leq v_i<m 1≤T≤10,1≤n≤1000,1≤m≤210,0≤vi<m
题解
考虑DP。设 f i , j f_{i,j} fi,j表示在包含 i i i的子树中,权值为 j j j的子树。
那么,遍历这棵树,对于 u u u的每一个儿子 v v v,有转移式
f u , k = ∑ i ⊕ j = k f u , i × f v , j f_{u,k}=\sum\limits_{i\oplus j=k}f_{u,i}\times f_{v,j} fu,k=i⊕j=k∑fu,i×fv,j
注意每次转移到时候,儿子 v v v的DP值 f v , 0 f_{v,0} fv,0要加1,代表儿子 v v v及以下的部分不选。
这就可以用 F W T FWT FWT的异或卷积来做了。
最后统计答案的时候,记得要把每个 f i , 0 f_{i,0} fi,0上多加点1减去。那么,权值为 k k k的子树的个数为
a n s k = ∑ i = 1 n f k ans_k=\sum\limits_{i=1}^nf_k ansk=i=1∑nfk
时间复杂度为 O ( T n m log m ) O(Tnm\log m) O(Tnmlogm)。
code
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
int t,n,m,tot,d[2005],l[2005],r[2005];
int ans[1<<10],f[1005][1<<10];
const int mod=1000000007,ny2=500000004;
int in(){
int re=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9'){
re=re*10+ch-'0';
ch=getchar();
}
return re;
}
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void fwt(int *w,int fl){
for(int s=2;s<=m;s<<=1){
int mid=s>>1;
for(int v=0;v<m;v+=s){
for(int i=0;i<mid;i++){
int t1=w[v+i],t2=w[v+mid+i];
w[v+i]=1ll*(t1+t2)%mod;w[v+mid+i]=1ll*(t1-t2+mod)%mod;
if(fl==-1){
w[v+i]=1ll*w[v+i]*ny2%mod;
w[v+mid+i]=1ll*w[v+mid+i]*ny2%mod;
}
}
}
}
}
void dfs(int u,int fa){
fwt(f[u],1);
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
dfs(d[i],u);
for(int j=0;j<m;j++) f[u][j]=1ll*f[u][j]*f[d[i]][j]%mod;
}
fwt(f[u],-1);
f[u][0]=(f[u][0]+1)%mod;
fwt(f[u],1);
}
int main()
{
t=in();
while(t--){
n=in();m=in();
for(int i=1;i<=n;i++){
++f[i][in()];
}
tot=0;
for(int i=1,x,y;i<n;i++){
x=in();y=in();
add(x,y);add(y,x);
}
dfs(1,0);
for(int i=1;i<=n;i++){
fwt(f[i],-1);
f[i][0]=(f[i][0]-1+mod)%mod;
}
for(int i=1;i<=n;i++){
for(int j=0;j<m;j++){
ans[j]=(ans[j]+f[i][j])%mod;
}
}
for(int i=0;i<m;i++){
printf("%d",ans[i]);
if(i<m-1) printf(" ");
}
printf("\n");
for(int i=1;i<=n;i++) r[i]=0;
for(int i=0;i<m;i++) ans[i]=0;
for(int i=1;i<=n;i++){
for(int j=0;j<m;j++){
f[i][j]=0;
}
}
}
return 0;
}