第一次做树形dp的题就遇到这么恶心的题,赛后看着别人的代码研究了一个下午终于弄出来了,细节不是一般的多,思路要非常非常清晰,不然就gg。下面说下思路:
我们默认1为根节点跑一次dfs,用g数组表示以该节点为根向下走返回到该节点时的最大值,明显有转移方程:g[根] += max(g[子] -边权值*2, 0);从某个点出发获得最大值肯定最后的落点不会在原来的点上,所以我们用dp来存该点最后能获得的最大值和次大值,转移方程详情看代码。最后就知道根节点也就是1的答案就是dp[1][0]。
我们现在只知道1的答案,但是我们不能用一个循环重复上面步骤,时间复杂度很高,所以最后要靠根节点和子节点的关系再跑一次dfs,把其他点的答案全部推出来。子节点最大值的情况只有两种,子节点往上走然后回到原点,再往下走到某个地方达到最大,或者是反过来往下走然后回到原点往上走到最大。往下走到最大就是我们已经记录过的数dp[子][0],往下走回到原点的最大就是g[子],往上走我们可以看出根节点和子节点的关系,详情请看代码:
#include<iostream> #include<stack> #include<cstring> #include<map> #include<string> #include<queue> #include<algorithm> #include<cstdio> #include<set> #include<cmath> using namespace std; #define maxn 100005 #define inf 0x3f3f3f3f typedef long long LL; int value[maxn], head[maxn], dp[maxn][2], edge[maxn][3], g[maxn], ans[maxn][3]; struct node{ int u, value, next; }p[maxn << 1]; const int read() { char ch = getchar(); while (ch<'0' || ch>'9') ch = getchar(); int x = ch - '0'; while ((ch = getchar()) >= '0'&&ch <= '9') x = x * 10 + ch - '0'; return x; } void dfs(int x, int fa){ g[x] = value[x]; for (int i = head[x]; i != -1; i = p[i].next){ if (p[i].u == fa) continue; dfs(p[i].u, x); g[x] += max(g[p[i].u] - (p[i].value << 1), 0); } edge[x][0] = edge[x][1] = -1; for (int i = head[x]; i != -1; i = p[i].next){ if (p[i].u == fa || (dp[p[i].u][0] - p[i].value <= 0)) continue; int now = g[x] - max(g[p[i].u] - (p[i].value << 1), 0) + max(dp[p[i].u][0] - p[i].value, 0); if (edge[x][0] == -1){ edge[x][0] = i; dp[x][0] = now; } else if (edge[x][1] == -1 || now > dp[x][1]){ edge[x][1] = i; dp[x][1] = now; if (dp[x][0] < dp[x][1]){ swap(edge[x][0], edge[x][1]); swap(dp[x][0], dp[x][1]); } } } if (edge[x][0] == -1) dp[x][0] = g[x]; if (edge[x][1] == -1) dp[x][1] = g[x]; } void solve(int x,int fa){ for (int i = head[x]; i != -1; i = p[i].next){ if (p[i].u == fa) continue; int y = p[i].u; ans[y][0] = dp[y][0] + max(g[x] - max(g[y] - (p[i].value << 1), 0) - (p[i].value << 1), 0); ans[y][1] = dp[y][1] + max(g[x] - max(g[y] - (p[i].value << 1), 0) - (p[i].value << 1), 0); edge[y][2] = i; if (edge[x][0] == i){ if (g[y] >= (p[i].value << 1)) ans[y][2] = ans[x][1] + p[i].value; else ans[y][2] = g[y] - p[i].value + ans[x][1]; } else{ if (g[y] >= (p[i].value << 1)) ans[y][2] = ans[x][0] + p[i].value; else ans[y][2] = g[y] - p[i].value + ans[x][0]; } if (ans[y][2] > ans[y][1]){ swap(ans[y][2], ans[y][1]); swap(edge[y][2], edge[y][1]); } if (ans[y][1] > ans[y][0]){ swap(ans[y][1], ans[y][0]); swap(edge[y][1], edge[y][0]); } g[y] += max(g[x] - max(g[y] - (p[i].value << 1), 0) - (p[i].value << 1), 0); solve(y, x); } } int main(){ int t; t = read(); for (int tcase = 1; tcase <= t; tcase++){ int n; n = read(); for (int i = 1; i <= n; i++){ value[i] = read(); head[i] = -1; g[i] = 0; } for (int i = 0; i < ((n - 1) << 1); i++){ int x, y, z; x = read(); y = read(); z = read(); p[i].u = y; p[i].value = z; p[i].next = head[x]; head[x] = i++; p[i].u = x; p[i].value = z; p[i].next = head[y]; head[y] = i; } dfs(1, 0); ans[1][0] = dp[1][0]; ans[1][1] = dp[1][1]; solve(1, 0); printf("Case #%d:\n", tcase); for (int i = 1; i <= n; i++){ printf("%d\n", ans[i][0]); } } }