Problem Description
Bi Luo is a magic boy, he also has a migic tree, the tree has
N
nodes , in each node , there is a treasure, it's value is
V[i]
, and for each edge, there is a cost
C[i]
, which means every time you pass the edge
i
, you need to pay
C[i]
.
You may attention that every
V[i]
can be taken only once, but for some
C[i]
, you may cost severial times.
Now, Bi Luo define
ans[i]
as the most value can Bi Luo gets if Bi Luo starts at node
i
.
Bi Luo is also an excited boy, now he wants to know every
ans[i]
, can you help him?
Input
First line is a positive integer
T(T≤104)
, represents there are
T
test cases.
Four each test:
The first line contain an integer
N
(N≤105)
.
The next line contains
N
integers
V[i]
, which means the treasure’s value of node
i(1≤V[i]≤104)
.
For the next
N−1
lines, each contains three integers
u,v,c
, which means node
u
and node
v
are connected by an edge, it's cost is
c(1≤c≤104)
.
You can assume that the sum of
N
will not exceed
106
.
Output
For the i-th test case , first output Case #i: in a single line , then output
N
lines , for the i-th line , output
ans[i]
in a single line.
Sample Input
1
5
4 1 7 7 7
1 2 6
1 3 1
2 4 8
3 5 2
Sample Output
Case #1:
15
10
14
9
15
一道颇为恶心的树形dp,要考虑的细节挺多的,想了好久死了一堆脑细胞。
首先,理解题意以后,我们可以知道,如果要计算从某一个点出发的最优值,可以进行一次dfs
以1作为根节点为例,图中我们可以先走1-2在回来2-1在去1-3再到3-5,这样可以得到1 2 3 5四个点的价值减去中间经过的边
可以看出一点,从1出发经过很多的分支,最后一条分支是不需要再回来的。
所以进行树形dp,需要记录三个值,从这个节点向下每次都回来的最优值g[x],
从这个节点向下最后一次不会来的最优值和次优值dp[x][0]和dp[x][1],顺便记录不会来的是哪一条边f
这样,第一次以1为根的dp可以求出dp[1][0]为1节点的答案,接下来通过这个推导下面相邻节点的答案。
对于x的某个孩子来说,有两种情况,
一种是它向下走不回来,显然这在这在之前的dp中可以得到。
另一种是向上走,这就和x节点的最优值选了那条边有关,如果是选择当前边,那么加上之前的次优值,
不然就加上之前的最优值,当然要先和0比一下,因为可能不走过去。
这样就能推出全部的答案了。
#include<set>
#include<map>
#include<ctime>
#include<cmath>
#include<stack>
#include<queue>
#include<bitset>
#include<cstdio>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<functional>
#define rep(i,j,k) for (int i = j; i <= k; i++)
#define per(i,j,k) for (int i = j; i >= k; i--)
#define lson x << 1, l, mid
#define rson x << 1 | 1, mid + 1, r
#define fi first
#define se second
#define mp(i,j) make_pair(i,j)
#define pii pair<int,int>
using namespace std;
typedef long long LL;
const int low(int x) { return x&-x; }
const double eps = 1e-8;
const int INF = 0x7FFFFFFF;
const int mod = 1e9 + 7;
const int N = 2e5 + 10;
const int read()
{
char ch = getchar();
while (ch<'0' || ch>'9') ch = getchar();
int x = ch - '0';
while ((ch = getchar()) >= '0'&&ch <= '9') x = x * 10 + ch - '0';
return x;
}
int T, n, a[N];
int x, y, z, cas = 0;
int ft[N], nt[N], v[N], u[N], sz;
int g[N], dp[N][2], ans[N][3], f[N][3];
void dfs(int x, int fa)
{
f[x][0] = f[x][1] = -1; g[x] = a[x];
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa) continue;
dfs(u[i], x);
g[x] += max(g[u[i]] - 2 * v[i], 0);
}
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa || dp[u[i]][0] <= v[i]) continue;
int now = g[x] - max(g[u[i]] - 2 * v[i], 0) + max(dp[u[i]][0] - v[i], 0);
if (f[x][0] == -1) f[x][0] = i, dp[x][0] = now;
else if (f[x][1] == -1 || dp[x][1] < now)
{
f[x][1] = i; dp[x][1] = now;
if (dp[x][0] < dp[x][1]) swap(dp[x][0], dp[x][1]), swap(f[x][0], f[x][1]);
}
}
if (f[x][1] == -1) dp[x][1] = g[x];
if (f[x][0] == -1) dp[x][0] = g[x];
}
void get(int x, int fa)
{
for (int i = ft[x]; i != -1; i = nt[i])
{
if (u[i] == fa) continue;
int y = u[i];
ans[y][0] = dp[y][0] + max(g[x] - max(g[y] - 2 * v[i], 0) - 2 * v[i], 0);
ans[y][1] = dp[y][1] + max(g[x] - max(g[y] - 2 * v[i], 0) - 2 * v[i], 0);
f[y][2] = i;
if (g[y] >= 2 * v[i])
{
if (f[x][0] != i) ans[y][2] = ans[x][0] + v[i];
else ans[y][2] = ans[x][1] + v[i];
}
else
{
if (f[x][0] != i) ans[y][2] = g[y] - v[i] + ans[x][0];
else ans[y][2] = ans[x][1] + g[y] - v[i];
}
if (ans[y][1] <= ans[y][2]) swap(ans[y][1], ans[y][2]), swap(f[y][1], f[y][2]);
if (ans[y][0] <= ans[y][1]) swap(ans[y][0], ans[y][1]), swap(f[y][0], f[y][1]);
g[y] += max(g[x] - max(g[y] - 2 * v[i], 0) - 2 * v[i], 0);
get(u[i], x);
}
}
int main()
{
scanf("%d", &T);
while (T--)
{
n = read(); sz = 0;
rep(i, 1, n) a[i] = read(), ft[i] = -1;
rep(i, 1, n - 1)
{
scanf("%d%d%d", &x, &y, &z);
u[sz] = y; nt[sz] = ft[x]; v[sz] = z; ft[x] = sz++;
u[sz] = x; nt[sz] = ft[y]; v[sz] = z; ft[y] = sz++;
}
dfs(1, 0);
ans[1][1] = dp[1][1];
ans[1][0] = dp[1][0];
get(1, 0);
printf("Case #%d:\n", ++cas);
rep(i, 1, n) printf("%d\n", ans[i][0]);
}
return 0;
}
转载请注明原文地址: https://ju.6miu.com/read-1299326.html