题目
代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity of node-indexed arrays (nodes are 1-based); presumably n <= 1e5.
constexpr int N = 100000 + 10;
// Capacity of edge-indexed arrays: every undirected edge is stored twice.
constexpr int M = N * 2;

// Binary-lifting tables. p[u][i] is the 2^i-th ancestor of u (0 once the jump
// would leave the tree); dis[u][i] is the summed edge weight of that jump.
int p[N][18];
ll dis[N][18]; // long long on purpose: summed path weights may overflow int
// d[u]: depth of u; 0 means "not visited yet".
int d[N];
// a[i]: the i-th node of the input route (1-based).
int a[N];

// Adjacency list in linked forward-star form: h[u] is the first edge slot of u
// (-1 = none once initialized), and for slot s: e[s] = target, ne[s] = next
// slot of the same source, w[s] = weight. idx is the next free slot.
int h[N];
int e[M];
int ne[M];
int w[M];
int idx;

int n; // number of nodes
int k; // number of nodes on the route
// Appends the directed edge a -> b with weight c to a's adjacency list
// (pushed at the front, so edges are later iterated newest-first).
void add(int a, int b, int c)
{
    const int slot = idx++;
    e[slot] = b;
    w[slot] = c;
    ne[slot] = h[a];
    h[a] = slot;
}
// Fills the depth (d), binary-lifting ancestor (p) and jump-weight (dis)
// tables for every node reachable from u. d[u] must already be non-zero
// (main sets d[1] = 1); any node with d == 0 is treated as unvisited.
//
// Uses an explicit stack instead of recursion: on a path-shaped tree the
// recursion depth would reach n - 1 (~1e5), which can overflow the call stack.
void dfs(int u)
{
    std::vector<int> pending{u};
    while (!pending.empty())
    {
        const int v = pending.back();
        pending.pop_back();
        for (int i = h[v]; ~i; i = ne[i])
        {
            const int j = e[i];
            if (d[j])
                continue; // already visited (in a tree: v's parent)
            d[j] = d[v] + 1;
            p[j][0] = v;
            dis[j][0] = w[i];
            // v's tables were filled when v was discovered, so every ancestor
            // referenced below is already complete. Loop variable is not named
            // `k` to avoid shadowing the global route length.
            for (int lg = 1; lg <= 17; lg++)
            {
                p[j][lg] = p[p[j][lg - 1]][lg - 1];
                dis[j][lg] = dis[j][lg - 1] + dis[p[j][lg - 1]][lg - 1];
            }
            pending.push_back(j);
        }
    }
}
// Returns the weighted distance between nodes a and b (despite the name, it
// does not return their LCA). Both nodes are lifted with the binary-lifting
// tables and the weight of every jump is summed.
ll lca(int a, int b)
{
    if (d[b] > d[a])
        swap(a, b); // make a the deeper node
    ll total = 0;

    // Phase 1: raise a to the depth of b.
    for (int bit = 17; bit >= 0; --bit)
    {
        const int up = p[a][bit];
        if (d[up] < d[b])
            continue;
        total += dis[a][bit];
        a = up;
    }
    if (a == b)
        return total; // b was an ancestor of the original a

    // Phase 2: raise both while their ancestors differ; afterwards a and b
    // are distinct children of their lowest common ancestor.
    for (int bit = 17; bit >= 0; --bit)
    {
        if (p[a][bit] == p[b][bit])
            continue;
        total += dis[a][bit] + dis[b][bit];
        a = p[a][bit];
        b = p[b][bit];
    }
    return total + dis[a][0] + dis[b][0];
}
int main()
{
memset(h, -1, sizeof h);
cin >> n >> k;
for (int i = 1; i < n; i++)
{
int a, b, c;
cin >> a >> b >> c;
add(a, b, c);
add(b, a, c);
}
d[1] = 1;
dfs(1);
ll tmp = 0;
cin >> a[1];
for (int i = 2; i <= k; i++)
{
cin >> a[i];
tmp += lca(a[i - 1], a[i]);
}
for (int i = 1; i <= k; i++)
{
ll ans = tmp;
if (i == 1)
ans -= lca(a[1], a[1 + 1]);
else if (i == k)
ans -= lca(a[k - 1], a[k]);
else
{
ans -= lca(a[i - 1], a[i]);
ans -= lca(a[i], a[i + 1]);
ans += lca(a[i - 1], a[i + 1]);
}
cout << ans << ' ';
}
return 0;
}