题目描述
传送门
题解
裸的dp可以得到20pts 令f(i)表示将i点子树中所有关键点割掉的最小代价 那么若i为关键点,f(i)=i的父边权;若i不是关键点,f(i)=所有儿子的f之和 与 i的父边权取min
那么对于所有的关键点和它们的lca造出一棵虚树,连的边为树链上所有边的最小值 同样的方法dp就行了
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<ctime>
using namespace std;
#define LL long long
#define N 250005
#define sz 18
int n,m,k;
int tot,point[N],nxt[N*
2],v[N*
2],c[N*
2];
int isl[N],flag[N];
int dfs_clock,h[N],in[N],out[N],deep[N],last[N],f[N][sz+
3],s[N][sz+
3],
stack[N],top;
LL dp[N];
void add(
int x,
int y,
int z)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
}
void build(
int x,
int fa)
{
in[x]=++dfs_clock;deep[x]=deep[fa]+
1;
for (
int i=
1;i<sz;++i)
{
f[x][i]=f[f[x][i-
1]][i-
1];
s[x][i]=min(s[x][i-
1],s[f[x][i-
1]][i-
1]);
}
for (
int i=point[x];i;i=nxt[i])
if (v[i]!=fa)
{
f[v[i]][
0]=x;
s[v[i]][
0]=c[i];
build(v[i],x);
}
out[x]=dfs_clock;
}
int lca(
int x,
int y)
{
if (deep[x]<deep[y]) swap(x,y);
int cha=deep[x]-deep[y];
for (
int i=
0;i<sz;++i)
if ((cha>>i)&
1) x=f[x][i];
if (x==y)
return x;
for (
int i=sz-
1;i>=
0;--i)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][
0];
}
void treedp(
int x)
{
dp[x]=
0;
for (
int i=point[x];i;i=nxt[i])
{
last[v[i]]=c[i];
treedp(v[i]);
dp[x]+=dp[v[i]];
}
if (isl[x]==m) dp[x]=(LL)last[x];
if (x!=
1) dp[x]=min(dp[x],(LL)last[x]);
point[x]=
0;
}
int cmp(
int a,
int b)
{
return in[a]<in[b];
}
int find(
int x,
int y)
{
int cha=deep[y]-deep[x];
int Min=
100001;
for (
int i=
0;i<sz;++i)
if ((cha>>i)&
1) Min=min(Min,s[y][i]),y=f[y][i];
return Min;
}
void solve()
{
scanf(
"%d",&k);
for (
int i=
1;i<=k;++i)
{
scanf(
"%d",&h[i]);
isl[h[i]]=flag[h[i]]=m;
}
sort(h+
1,h+k+
1,cmp);h[
0]=k;
for (
int i=
2;i<=k;++i)
{
int r=lca(h[i],h[i-
1]);
if (flag[r]!=m)
{
flag[r]=m;
h[++h[
0]]=r;
}
}
if (flag[
1]!=m) flag[
1]=m,h[++h[
0]]=
1;
sort(h+
1,h+h[
0]+
1,cmp);
tot=
0;
stack[top=
1]=
1;
for (
int i=
2;i<=h[
0];++i)
{
while (in[h[i]]<in[
stack[top]]||in[h[i]]>out[
stack[top]])
--top;
int Min=find(
stack[top],h[i]);
add(
stack[top],h[i],Min);
stack[++top]=h[i];
}
treedp(
1);
printf(
"%lld\n",dp[
1]);
}
int main()
{
scanf(
"%d",&n);
for (
int i=
1;i<n;++i)
{
int x,y,z;
scanf(
"%d%d%d",&x,&y,&z);
add(x,y,z),add(y,x,z);
}
build(
1,
0);
memset(point,
0,
sizeof(point));
scanf(
"%d",&m);
while (m) solve(),--m;
return 0;
}
转载请注明原文地址: https://ju.6miu.com/read-33114.html