QOJ9700 Ying’s Cup(拉格朗日插值优化卷积,背包,二项式反演)
题意
原题链接
简要题意:
给你一棵
n
n
n 个点的树,你需要将
1
∼
n
1 \sim n
1∼n 的排列填到树上的节点中。定义一个点为 局部最小值 满足与它相邻的点上的数都比它大。你需要对
k
=
1
,
2
,
.
.
.
,
n
k = 1,2,...,n
k=1,2,...,n 都输出恰好有
k
k
k 个局部最小值时排列有多少种填法。
1 ≤ n ≤ 500 1 \leq n \leq 500 1≤n≤500。
分析:
首先看到 恰好,可以考虑二项式反演。如何求出在 钦定 了
k
k
k 个局部最小值的情况下的方案数呢?发现如果钦定了一个点
x
x
x 是局部最小值,那么它和相邻点的大小关系就定了,不妨用一条
x
x
x 指向相邻点的有向边表示。
但是这样还是不太好做。发现如果有向边能形成一棵树是比较容易计算方案的。如果对于一个钦定的点能只保留指向儿子的有向边,那么这样就会形成若干有向树,每棵树的方案相互独立乘起来即可。
那么一棵有向树的方案怎么算呢?有一个经典结论:对于一个 n n n 个点的有向森林,填充排列满足任意一个点上的数小于它的儿子的方案数为 n ! × ∏ u 1 s z u n! \times \prod\limits_{u}\frac{1}{sz_u} n!×u∏szu1。
那么钦定点指向父亲的边怎么办?考虑容斥成无向(删去这条边)减去反向边即可。
这样就有了一个 d p dp dp。设 f i , j , k f_{i, j, k} fi,j,k 表示考虑了以 i i i 为根的子树,当前钦定了 j j j 个局部最小值, i i i 所在的有向树大小为 k k k,并且 i i i 没有被钦定的方案数。 g i , j , k g_{i, j, k} gi,j,k 则是表示 i i i 已经被钦定。那么转移就是 二维背包,复杂度 O ( n 4 ) O(n^4) O(n4)。注意如果 i i i 的父亲不指向 i i i,那么需要将 1 k \frac{1}{k} k1 乘到答案中。
注意到对于钦定局部最小值这一维的转移 只涉及到加法卷积,并且最后只需要知道根节点的 d p dp dp 数组,考虑拉插优化:首先将第二三维交换顺序,然后将 f i , k f_{i, k} fi,k 看做一个长度为子树大小的多项式, f i , k , j f_{i, k, j} fi,k,j 是 x j x^j xj 的系数。那么对 j j j 这一维的转移就是两个多项式的卷积。 f r o o t , k f_{root, k} froot,k 肯定是一个 n n n 次多项式,因此枚举 x = 1 ∼ n + 1 x = 1\sim n + 1 x=1∼n+1 求出 n + 1 n + 1 n+1 个点拉插还原系数即可。
枚举 x x x 后每次 d p dp dp 复杂度 O ( n 2 ) O(n^2) O(n2),求点值的复杂度为 O ( n 3 ) O(n^3) O(n3)。最后拉插还原系数可以用 短除法 做到 O ( n 2 ) O(n^2) O(n2),总复杂度 O ( n 3 ) O(n^3) O(n3)。
CODE:
// 首先可以二项式反演把恰好改成钦定
// 然后定一个最小值就能把一个点和它周围点的大小关系定下来
// 可以把大小关系用一条有向边表示
// 发现儿子指向父亲是不好搞的,容斥成没有这条边 - 父亲指向儿子
// 于是就可以得到一个 O(n^4) 的二维背包
// 把dp数组看作系数,插值优化
// 对于某一维只有加法卷积的转移, 可以拉差优化掉这一维
#include<bits/stdc++.h>
#define pb emplace_back
using namespace std;
typedef long long LL;
const int mod = 998244353;
const int N = 510;
int n;
int X[N], Y[N]; // n次多项式, 维护 n + 1 个点
int f[N][N], g[N][N];
int tf[N], tg[N];
int fac[N], inv[N], Inv[N];
int c[N][N], tmp[N], H[N], F[N], G[N];
int sz[N];
vector< int > E[N];
inline int sign(int x) {return (x & 1) ? mod - 1 : 1;}
inline int Pow(int x, int y) {
int res = 1, k = x;
while(y) {
if(y & 1) res = 1LL * res * k % mod;
y >>= 1;
k = 1LL * k * k % mod;
}
return res;
}
void dfs(int x, int fa, int d) { // d 为底数
sz[x] = 1;
f[x][1] = 1; g[x][1] = d;
for(auto v : E[x]) {
if(v == fa) continue;
dfs(v, x, d);
for(int i = 1; i <= sz[x]; i ++ ) {
for(int j = 1; j <= sz[v]; j ++ ) {
tf[i] = (tf[i] + 1LL * f[x][i] * f[v][j] % mod * Inv[j] % mod) % mod;
tf[i] = (tf[i] + 1LL * f[x][i] * g[v][j] % mod * Inv[j] % mod) % mod;
tf[i + j] = (tf[i + j] + 1LL * f[x][i] * g[v][j] % mod * Inv[j] % mod * (mod - 1) % mod) % mod;
tg[i + j] = (tg[i + j] + 1LL * g[x][i] * f[v][j] % mod * Inv[j] % mod) % mod;
}
}
sz[x] += sz[v];
for(int i = 1; i <= sz[x]; i ++ ) {
f[x][i] = tf[i], g[x][i] = tg[i];
tf[i] = 0; tg[i] = 0;
}
}
}
void lag() { // 把 n + 1 个点的系数插出来
for(int i = 0; i <= n; i ++ ) F[i] = 0;
for(int i = 0; i <= n + 1; i ++ ) tmp[i] = 0;
tmp[0] = 1;
for(int i = 1; i <= n + 1; i ++ ) {
for(int j = n + 1; j >= 0; j -- ) {
tmp[j + 1] = (tmp[j + 1] + tmp[j]) % mod;
tmp[j] = (1LL * tmp[j] * X[i] % mod * (mod - 1) % mod) % mod;
}
}
for(int i = 1; i <= n + 1; i ++ ) {
int a = 1;
a = 1LL * inv[i - 1] * inv[n + 1 - i] % mod * sign(n + 1 - i) % mod * Y[i] % mod;
int lst = tmp[n + 1];
for(int j = n; j >= 0; j -- ) {
F[j] = (F[j] + 1LL * lst * a % mod) % mod;
lst = (tmp[j] + 1LL * lst * X[i] % mod) % mod;
}
}
}
int main() {
for(int i = 0; i < N; i ++ ) {
for(int j = 0; j <= i; j ++ ) {
if(!j) c[i][j] = 1;
else c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
}
}
scanf("%d", &n);
fac[0] = 1; for(int i = 1; i <= n; i ++ ) fac[i] = 1LL * fac[i - 1] * i % mod;
inv[n] = Pow(fac[n], mod - 2LL);
for(int i = n - 1; i >= 0; i -- ) inv[i] = 1LL * inv[i + 1] * (i + 1) % mod;
Inv[0] = 1; for(int i = 1; i <= n; i ++ ) Inv[i] = Pow(1LL * i, mod - 2LL);
for(int i = 1; i < n; i ++ ) {
int u, v; scanf("%d%d", &u, &v);
E[u].pb(v); E[v].pb(u);
}
for(int x = 1; x <= n + 1; x ++ ) {
dfs(1, 0, 1LL * x);
for(int i = 1; i <= n; i ++ ) {
X[x] = x;
Y[x] = (Y[x] + 1LL * (f[1][i] + g[1][i]) * Inv[i] % mod) % mod;
}
}
lag();
for(int i = 0; i <= n; i ++ ) F[i] = 1LL * F[i] * fac[n] % mod;
for(int i = 0; i <= n; i ++ ) {
for(int j = i; j <= n; j ++ ) {
G[i] = (G[i] + 1LL * c[j][i] * sign(j - i) % mod * F[j] % mod) % mod;
}
}
for(int i = 1; i <= n; i ++ ) printf("%d\n", G[i]);
return 0;
}