Codeforces 1276D Tree Elimination 题解
思路
首先可以发现,对序列计数就是对擦除标记的方式计数.
对于一个点 ,考察 邻域中的两个点 ,我们记 表示边 先于边 出现, 表示边 后于边 出现.
进行一个 DP.设 表示考虑了 子树, 没有被擦除 / 被 的儿子擦除 / 被 擦除 / 被 的儿子擦除,其中 代表点 的父亲.
分 类转移:
- 考虑转移 情况.这时 的儿子必定都被擦除,且儿子 的擦除时间一定不后于 这条边被考虑.这是因为考虑 这条边时,为了保全 不被擦除, 一定会被擦.那么有转移
- 考虑转移 情况.我们可以从儿子中钦定一个 的 作为擦掉 的点.组合其他点的方案时,如果存在 的儿子 不被擦除或者在考虑边 之后被擦,那么 会被 擦而不是被 擦,这样的状态不合法.由于在考虑 时 被擦了,那么 的儿子显然就不能被 擦了.那么有转移
- 考虑转移 情况.由于 被 擦了,那么所有 的儿子 都不能擦 ,所有 的儿子 都不能被 擦.有转移
- 考虑转移 情况.于 情况类似,枚举 的儿子钦定擦除即可.有转移
暴力转移是 的,无法通过本题.瓶颈在于 , 两种状态的转移.注意到如果按照加入顺序访问出边,后面那两个 可以通过预处理前后缀积弄掉.时间复杂度 .
代码
#include <cstdio>
#include <utility>
#include <vector>
inline int rd() {
int x = 0, f = 1, c = getchar();
while (((c - '0') | ('9' - c)) < 0)
f = c != '-', c = getchar();
while (((c - '0') | ('9' - c)) > 0)
x = x * 10 + c - '0', c = getchar();
return f ? x : -x;
}
using pii = std::pair<int, int>;
using i64 = long long;
const i64 P = 998244353;
const int N = 2e5;
int n;
std::vector<pii> g[N + 5];
int fid[N + 5];
void pre(int u, int fa) {
for (auto [v, id] : g[u]) {
if (v == fa) continue;
fid[v] = id, pre(v, u);
}
}
i64 f[N + 5][4];
void dfs(int u, int fa) {
for (auto [v, _] : g[u]) if (v != fa) dfs(v, u);
static int V[N + 5]; int tp = 0;
for (auto [v, _] : g[u]) if (v != fa) V[++tp] = v;
static i64 pre[N + 5], suf[N + 5]; pre[0] = suf[tp + 1] = 1;
for (int i = 1; i <= tp; i++) {
int v = V[i];
pre[i] = (f[v][1] + f[v][2]) % P;
suf[i] = (f[v][0] + f[v][1] + f[v][3]) % P;
}
for (int i = 1; i <= tp; i++) (pre[i] *= pre[i - 1]) %= P;
for (int i = tp; i >= 1; i--) (suf[i] *= suf[i + 1]) %= P;
f[u][0] = 1, f[u][2] = 1;
for (int i = 1; i <= tp; i++) {
int v = V[i];
(f[u][0] *= (f[v][1] + f[v][2]) % P) %= P;
if (fid[v] < fid[u]) {
(f[u][2] *= (f[v][1] + f[v][2]) % P) %= P;
(f[u][1] += (f[v][0] + f[v][3]) % P * pre[i - 1] % P * suf[i + 1] % P) %= P;
} else {
(f[u][2] *= (f[v][0] + f[v][1] + f[v][3]) % P) %= P;
(f[u][3] += (f[v][0] + f[v][3]) % P * pre[i - 1] % P * suf[i + 1] % P) %= P;
}
}
}
int main() {
n = rd();
for (int i = 1; i <= n - 1; i++) {
int u = rd(), v = rd();
g[u].emplace_back(v, i);
g[v].emplace_back(u, i);
}
fid[1] = n, pre(1, 0), dfs(1, 0);
printf("%lld\n", (f[1][0] + f[1][1]) % P);
return 0;
}
参考
qazswedx, DP 选讲
xht, Codeforces Round #606 (Div. 1, based on Technocup 2020 Elimination Round 4) 题解