洛谷 P1433 吃奶酪
题目传送门
前言
虽然是一道非常基础的状压
d
p
dp
dp(在洛谷上甚至是道驴蹄),但是我调了两个晚自习,最后发现是竟然是状态设计有问题。
所以在此篇题解中,我不但会说出正确做法,还会说出原本的代码错在哪里。
以警醒自己状态设计正确的重要性
错误思路
设
d
p
s
dp_s
dps 表示从
(
0
,
0
)
(0, 0)
(0,0) 走到当前状态
s
s
s 所用的最小路程。
那么就可以得出如下代码(每一段的具体解释见注释):
for (int si = 0; si <= ns; ++si) { // ns 表示状态总数, 就是 (1 << n) - 1
for (int i = 0; i < n; ++i) { // 枚举当前所在节点
if (!(si & (1 << i))) continue;
int sj = si ^ (1 << i); // 上一个状态
if (!sj) { // 如果上一个状态是全 0, 那就说明是从起点转移过来的
dp[si] = min(dp[si], dp[sj] + dis(0, i + 1));
continue;
}
// 若不为全 0, 那就从上一个状态中枚举前一个位置
for (int j = 0; j < n; ++j)
if (sj & (1 << j))
dp[si] = min(dp[si], dp[sj] + dis(i + 1, j + 1));
}
}
你若把代码补全就会发现只能得 66 p t s 66 \ pts 66 pts。
那为什么错呢?
在原本代码中,我们会枚举我们所在的上一个节点(即枚举 j j j)。但是我们似乎忘了,在达到状态 s j sj sj 的最小路径时,它最后到达的一个节点 【不一定】 是我们所枚举的!
假设【当前状态与位置】、【上一个状态】如下(从右往左数):
当前状态
:
0001
1101
,
i
=
4
;
上一个状态
:
0001
0101
当前状态: \ 0001 \ 1101, i = 4; \\ 上一个状态: \ 0001 \ 0101 \\
当前状态: 0001 1101,i=4;上一个状态: 0001 0101
假如我们所枚举的上一个节点
j
=
5
j = 5
j=5,我们由
d
p
s
j
+
d
i
s
(
i
=
4
,
j
=
5
)
dp_{sj} + dis(i = 4, j = 5)
dpsj+dis(i=4,j=5) 转移过来。
但是可能到达状态
0001
0101
0001 \ 0101
0001 0101 的结尾节点是
j
′
=
3
j' = 3
j′=3,此时我们再这样转移显然就是错的。
正确思路
在我们的错误思路中,发现我们的状态还关乎于所处的最后一个节点。所以,增加一维状态:设 d p i , s i dp_{i, si} dpi,si 表示在已经走过 s i si si 的状态下,所处的最后一个节点是 i i i。
那么正确的代码就呼之欲出了:
#include <bits/stdc++.h>
#define mkpr make_pair
#define fir first
#define sec second
using namespace std;
typedef long long ll;
const int maxn = 15 + 7;
const int maxs = (1 << 15) + 7;
const int inf = 0x3f3f3f3f;
int n, ns;
double x[maxn], y[maxn];
double dp[maxn][maxs];
double dis(int a, int b) {
return sqrt((x[a] - x[b]) * (x[a] - x[b]) +
(y[a] - y[b]) * (y[a] - y[b]));
}
void print(int x) {
for (int i = 0; i < n; ++i)
if (x & (1 << i)) putchar('1');
else putchar('0');
}
int main() {
scanf("%d", &n), ns = (1 << n) - 1;
for (int i = 1; i <= n; ++i)
scanf("%lf%lf", x + i, y + i);
for (int i = 0; i <= n; ++i)
for (int si = 0; si <= ns; ++si) dp[i][si] = 2e9;
dp[0][0] = 0;
for (int si = 0; si <= ns; ++si) {
for (int i = 0; i < n; ++i) {
if (!(si & (1 << i))) continue;
int sj = si ^ (1 << i);
if (!sj) {
dp[i][si] = min(dp[i][si], dis(0, i + 1));
continue;
}
for (int j = 0; j < n; ++j)
if (sj & (1 << j))
dp[i][si] = min(dp[i][si], dp[j][sj] + dis(i + 1, j + 1));
}
}
// 结尾可能在任意一个节点, 因此要取最小值
double ans = 2e9;
for (int i = 0; i <= n; ++i)
ans = min(ans, dp[i][ns]);
printf("%.2lf\n", ans);
return 0;
}