CF1748E Yet Another Array Counting Problem
CF1748E Yet Another Array Counting Problem
题目大意
对于一个长度为 n n n的序列 x x x,其在区间 [ l , r ] [l,r] [l,r]的最左端最大值位置为满足 l ≤ i ≤ r l\leq i\leq r l≤i≤r且 x i = max j = l r x j x_i=\max\limits_{j=l}^rx_j xi=j=lmaxrxj的最小的整数 i i i。
给定两个整数 n , m n,m n,m和一个长度为 n n n的序列 a a a,求满足下列条件的序列 b b b的数量:
- 序列 b b b的长度为 n n n,且对于任意整数 i ( 1 ≤ i ≤ n ) i(1\leq i\leq n) i(1≤i≤n)都有 1 ≤ b i ≤ m 1\leq b_i\leq m 1≤bi≤m成立
- 对于任意整数 l , r ( 1 ≤ l ≤ r ≤ n ) l,r(1\leq l\leq r\leq n) l,r(1≤l≤r≤n), a , b a,b a,b在区间的 [ l , r ] [l,r] [l,r]的最左端最大值位置相同
输出满足条件的序列 b b b的数量,对 1 0 9 + 7 10^9+7 109+7取模。
有 t t t组数据。
数据范围
1
≤
t
≤
1
0
3
1\leq t\leq 10^3
1≤t≤103
2
≤
n
,
m
≤
1
0
5
,
∑
n
×
m
≤
1
0
6
2\leq n,m\leq 10^5,\sum n\times m\leq 10^6
2≤n,m≤105,∑n×m≤106
对于所有
i
(
1
≤
i
≤
n
)
i(1\leq i\leq n)
i(1≤i≤n),满足
1
≤
a
i
≤
m
1\leq a_i\leq m
1≤ai≤m
题解
首先我们可以想到,对于一个区间 [ l , r ] [l,r] [l,r]和它的最左端最大值位置 x x x,我们可以把这个区间分成 [ l , x − 1 ] [l,x-1] [l,x−1]和 [ x + 1 , r ] [x+1,r] [x+1,r]两个区间来处理。
设 f ( l , r , x , v ) f(l,r,x,v) f(l,r,x,v)表示区间 [ l , r ] [l,r] [l,r]的最左端最大值位置为 x x x且其值小于等于 v v v时这个区间有多少种放法。那么有
f ( l , r , x , v ) = f ( l , r , x , v − 1 ) + f ( l , x − 1 , v l x , v − 1 ) × f ( x + 1 , r , v r x , v ) f(l,r,x,v)=f(l,r,x,v-1)+f(l,x-1,vl_x,v-1)\times f(x+1,r,vr_x,v) f(l,r,x,v)=f(l,r,x,v−1)+f(l,x−1,vlx,v−1)×f(x+1,r,vrx,v)
因为一个最左端最大值位置只对应一段区间,且一段区间只有一个最左端最大值位置,所以我们可以用 v l x vl_x vlx表示以 x x x为最左端最大值位置的区间 [ l , r ] [l,r] [l,r]中 [ l , x − 1 ] [l,x-1] [l,x−1]的最左端最大值位置, v r x vr_x vrx表示 [ x + 1 , r ] [x+1,r] [x+1,r]的最左端最大值位置。
因为一开始只有一段区间 [ 1 , n ] [1,n] [1,n],而每次加入一个最左端最大值位置只会多算两个区间,所以总共只会有不超过 2 n + 1 2n+1 2n+1个区间,可以 O ( n log n ) O(n\log n) O(nlogn)将所有 v l x , v r x vl_x,vr_x vlx,vrx计算出来。
然后就是计算 f ( l , r , x , v ) f(l,r,x,v) f(l,r,x,v)了。用记忆化搜索,每个位置只会被计算一次,那么这样的时间复杂度为 O ( n m ) O(nm) O(nm)。
总时间复杂度为 O ( ∑ n × m + n log n ) O(\sum n\times m+n\log n) O(∑n×m+nlogn)。
code
#include<bits/stdc++.h>
#define lc k<<1
#define rc k<<1|1
using namespace std;
int T,n,m,now,bz,a[200005],v1[200005],v2[200005],tr[1000005],th[1000005];
long long mod=1000000007;
vector<long long>f[200005];
void build(int k,int l,int r){
if(l==r){
tr[k]=a[l];th[k]=l;
return;
}
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
if(tr[lc]>=tr[rc]){
tr[k]=tr[lc];th[k]=th[lc];
}
else{
tr[k]=tr[rc];th[k]=th[rc];
}
}
void find(int k,int l,int r,int x,int y){
if(l>=x&&r<=y){
if(now<tr[k]){
now=tr[k];bz=th[k];
}
return;
}
int mid=l+r>>1;
if(x<=mid) find(lc,l,mid,x,y);
if(y>mid) find(rc,mid+1,r,x,y);
}
int pt(int l,int r){
if(l>r) return 0;
now=0;bz=0;
find(1,1,n,l,r);
int x=bz;
v1[x]=pt(l,x-1);v2[x]=pt(x+1,r);
return x;
}
long long gt(int l,int r,int x,int v){
if(l>r||v==0) return 0;
if(l==r) return v;
if(f[x][0]) return f[x][v];
for(int i=1;i<=m;i++){
long long re=1;
if(v1[x]&&v1[x]<x) re=re*gt(l,x-1,v1[x],i-1)%mod;
if(v2[x]&&v2[x]>x) re=re*gt(x+1,r,v2[x],i)%mod;
if(i>1) re=(re+f[x][i-1])%mod;
f[x].push_back(re);
}
f[x][0]=1;
return f[x][v];
}
int main()
{
scanf("%d",&T);
while(T--){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
f[i].push_back(0);
}
build(1,1,n);
int mx=pt(1,n);
printf("%lld\n",gt(1,n,mx,m));
for(int i=1;i<=n;i++){
f[i].clear();
}
}
return 0;
}