POJ 1795 DNA Laboratory 状态压缩DP(旅行商问题)
一、题目大意
我们有N个字符串,每个长度介于1到100,现要求构建一个组合串,使得所有字符串都为组合串的子串,找到长度最小的组合串,如果有多种可能,输出字典序排序最小的组合串。
二、解题思路
我们来回忆下状压DP解决旅行商问题,DP[S][v]代表已经走过的点为S,并从v开始走完剩余节点的最小距离。
其实你仔细思考,发现过滤掉那些 互为子串的字符串,之后剪掉首尾相接的公共部分,其实最终的组合串其实就是这些字符串互相拼接,而本题目其实就是要找出最佳的拼接顺序,就和旅行商问题异曲同工。
那么本题差不多,DP[S][v]代表已经包含字符串的集合为S,且结尾的字符串为 v时,去连接剩余字符串的最小长度(同时也记录连接的下一个字符串)。
首先预处理:1、把互为子串的字符串只留下长的那个。2、对于多个字符串相等只留下其中一个。
然后我们可知 dp[全集][ 0.. n] 都为0。
之后依次循环大小为 n-1,n-2...1的集合,执行如下递推式
1、如果 dp[S | 1 << u][u] + len[u] - v的尾和u的首的公共部分长度 < dp[S][v]
则 dp[S][v] = dp[S | 1 << u][u] + len[u] - v的尾和u的首的公共部分长度
2、如果 dp[S | 1 << u][u] + len[u] - v的尾和u的首的公共部分长度 == dp[S][v]
则对比较两种情况的字符串,如果以u开头字典序更小,dp[S][v] = dp[S | 1 << u][u] + len[u] - v的尾和u的首的公共部分长度。
我们可以在dp数组保存两个变量,下一次连接的字符串和拼接剩余所有字符串的长度。
这样比较字符字典序的时候,一开始设置 nxt = u,used = S,然后使 used = used | 1<< u,找到 dp[used][nxt][0]就是下一个字符串。下一个字符串的有效起始下标就是编号为 nxt 的字符串的尾 和 编号为 dp[used][nxt][0]字符串的首的公共长度。之后更新 nxt = dp[used][nxt][0]即可。
这样就可以通过某个位置,计算出全部的链路,进行字典序比较。
DP计算结束后,需要记录答案的起始字符串下标 ans,和答案长度,然后从0到n定义变量 i 循环。
如果 dp[1 << i] +len[i] < ansLen
则 ansLen=dp[1 << i] +len[i], ans = i;
如果 dp[1 << i] +len[i] == ansLen 且 i 编号的字符串的字典序 比 ans编号的字符串小
则 ansLen=dp[1 << i] +len[i], ans = i;
(过滤掉了互为子串的情况,则只比较开头的字符串即可)
我们定义数组 merg[i][j]代表字符串 i 的尾 和 字符串 j 的首的公共部分长度。
然后输出答案的时候,可以定义 nxt = ans,定义used = 0,然后定义下标 i,执行 ansLen次循环,每次循环输出一个字符。
for k in [0,ansLen) {
if i >= len[nxt] {
used = used | 1<< nxt, i = merg[ nxt ][ dp[ used ][ nxt ][ 0 ] ]; nxt = dp[ used ][ nxt ][ 0 ];
}
输出 nxt下标的字符串的第 i 个字符,不要换行。
i++
}
最后需要输出两个 \n\n,不然 presentation error。
备注:本题目我挂了很多次,后来不使用 %s输入字符,不使用 strcmp 和 strlen,而是使用 %c输入,自己计数,最终过了,不清楚是不是因为不能用 %s、strcmp 和 strlen。
三、代码
#include <iostream>
using namespace std;
const int MAX_N = 15, MAX_LEN = 110, INF = 0x3f3f3f3f;
char tmpStr[MAX_N][MAX_LEN], dat[MAX_N][MAX_LEN];
int len[MAX_N], tmp[MAX_N], dp[1 << MAX_N][MAX_N][2], merg[MAX_N][MAX_N], n, num, ans, ansLen, all;
bool need[MAX_N];
void putAns()
{
cout << "Scenario #" << num << ":" << endl;
int used = 0;
int nxt = ans;
int j = 0;
for (int k = 0; k < ansLen; k++)
{
if (j >= len[nxt])
{
used = used | 1 << nxt;
j = merg[nxt][dp[used][nxt][0]];
nxt = dp[used][nxt][0];
}
printf("%c", dat[nxt][j]);
j++;
}
printf("\n");
printf("\n");
}
bool compareStr(int prv, int a, int b, int used, int _len)
{
int i = merg[prv][a], j = merg[prv][b], nxt1 = a, nxt2 = b, used1 = used, used2 = used;
char c1, c2;
for (int k = 0; k < _len; k++)
{
if (i >= len[nxt1])
{
used1 = used1 | 1 << nxt1;
i = merg[nxt1][dp[used1][nxt1][0]];
nxt1 = dp[used1][nxt1][0];
}
if (j >= len[nxt2])
{
used2 = used2 | 1 << nxt2;
j = merg[nxt2][dp[used2][nxt2][0]];
nxt2 = dp[used2][nxt2][0];
}
if (dat[nxt1][i] != dat[nxt2][j])
{
return dat[nxt1][i] < dat[nxt2][j];
}
i++;
j++;
}
return false;
}
void handleStr(int used, int v, int u)
{
if (dp[used | 1 << u][u][1] + len[u] - merg[v][u] > dp[used][v][1])
{
return;
}
if (dp[used | 1 << u][u][1] + len[u] - merg[v][u] == dp[used][v][1] && !compareStr(v, u, dp[used][v][0], used, dp[used][v][1]))
{
return;
}
dp[used][v][0] = u;
dp[used][v][1] = dp[used | 1 << u][u][1] + len[u] - merg[v][u];
}
void handle(int size)
{
int used = 0;
for (int i = 0; i < size; i++)
{
used = used | 1 << i;
}
while (used < all)
{
for (int v = 0; v < n; v++)
{
for (int u = 0; u < n; u++)
{
if ((used >> v & 1) && !(used >> u & 1))
{
handleStr(used, v, u);
}
}
}
int x = used & -used;
int y = used & ~(used + x);
used = used + x + (y / x / 2);
}
}
bool compareAns(int i)
{
for (int k = 0; k < min(len[i], len[ans]); k++)
{
if (dat[i][k] != dat[ans][k])
{
return dat[i][k] < dat[ans][k];
}
}
return false;
}
void doDp()
{
for (int i = 0; i < (1 << MAX_N); i++)
{
for (int j = 0; j < MAX_N; j++)
{
dp[i][j][0] = INF;
dp[i][j][1] = INF;
}
}
all = (1 << n) - 1;
for (int i = 0; i < n; i++)
{
dp[all][i][0] = 0;
dp[all][i][1] = 0;
}
for (int i = n - 1; i > 0; i--)
{
handle(i);
}
ansLen = INF;
for (int i = 0; i < n; i++)
{
if (dp[1 << i][i][1] + len[i] < ansLen)
{
ans = i;
ansLen = dp[1 << i][i][1] + len[i];
}
else if (dp[1 << i][i][1] + len[i] == ansLen && compareAns(i))
{
ans = i;
ansLen = dp[1 << i][i][1] + len[i];
}
}
}
void mergeStr()
{
for (int i = 0; i < MAX_N; i++)
{
for (int j = 0; j < MAX_N; j++)
{
merg[i][j] = 0;
}
}
for (int v = 0; v < n; v++)
{
for (int u = 0; u < n; u++)
{
for (int st = len[u] - 1; st >= 0 && len[u] - st <= len[v]; st--)
{
for (int k = 0; st + k < len[u]; k++)
{
if (dat[v][k] != dat[u][st + k])
{
break;
}
if (st + k + 1 == len[u])
{
merg[u][v] = max(merg[u][v], len[u] - st);
}
}
}
}
}
}
void filterInclude()
{
for (int v = 0; v < n; v++)
{
tmp[v] = len[v];
for (int k = 0; k < len[v]; k++)
{
tmpStr[v][k] = dat[v][k];
}
}
int v = 0;
for (int i = 0; i < n; i++)
{
if (!need[i])
{
continue;
}
for (int k = 0; k < tmp[i]; k++)
{
dat[v][k] = tmpStr[i][k];
}
len[v] = tmp[i];
v++;
}
n = v;
}
void findInclude()
{
fill(need, need + MAX_N, true);
for (int v = 0; v < n; v++)
{
for (int u = 0; u < n; u++)
{
if (!need[v] || !need[u] || v == u || len[v] > len[u])
{
continue;
}
for (int st = 0; st < len[u] && st + len[v] <= len[u]; st++)
{
for (int k = 0; k < len[v]; k++)
{
if (dat[v][k] != dat[u][st + k])
{
break;
}
if (k + 1 == len[v])
{
need[v] = false;
}
}
}
}
}
}
void input()
{
char c;
scanf("%d\n", &n);
for (int i = 0; i < n; i++)
{
len[i] = 0;
while (true)
{
scanf("%c", &c);
if (c == '\n')
{
break;
}
else
{
dat[i][len[i]] = c;
len[i] = len[i] + 1;
}
}
}
}
int main()
{
int T = 0;
scanf("%d", &T);
for (num = 1; num <= T; num++)
{
input();
findInclude();
filterInclude();
mergeStr();
doDp();
putAns();
}
return 0;
}