题目描述
传送门
题解
令size(i)表示i子树里有多少个关键点 令sum(i)表示i子树中所有关键点到i的距离和 令Max(i)表示i子树中所有关键点到它的最长链,_Max(i)次长链,Min(i)最短链,_Min(i)次短链 这些都非常好维护,第二问和第三问也很好计算,用最和次拼一下就行了 对于第一问的话,在dp的时候维护一下当前size和sum的乘积就行了 将所有的关键点和它们的lca建出一棵虚树,边权为两点之间的距离 然后按照上面的dp就行了 dp的时候要格外注意子树的根是否是关键点以及儿子的个数
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
#define LL long long
#define N 1000005
#define sz 20
int n,q,k,dfs_clock,top;
int tot,point[N],nxt[N*
2],v[N*
2],c[N*
2];
int pt[N],key[N],flag[N],
stack[N],h[N],in[N],out[N],f[N][sz+
3],size[N];
LL sum[N],Max[N],_Max[N],Min[N],_Min[N],ans1,ans2,ans3;
const LL inf=
1e18;
void add(
int x,
int y,
int z)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
}
void build(
int x,
int fa)
{
h[x]=h[fa]+
1;in[x]=++dfs_clock;
for (
int i=
1;i<sz;++i) f[x][i]=f[f[x][i-
1]][i-
1];
for (
int i=point[x];i;i=nxt[i])
if (v[i]!=fa)
{
f[v[i]][
0]=x;
build(v[i],x);
}
out[x]=++dfs_clock;
}
int cmp(
int a,
int b)
{
return in[a]<in[b];
}
int lca(
int x,
int y)
{
if (h[x]<h[y]) swap(x,y);
int cha=h[x]-h[y];
for (
int i=
0;i<sz;++i)
if ((cha>>i)&
1) x=f[x][i];
if (x==y)
return x;
for (
int i=sz-
1;i>=
0;--i)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][
0];
}
void treedp(
int x)
{
size[x]=
0;
sum[x]=
0;
Max[x]=_Max[x]=
0;
Min[x]=_Min[x]=inf;
if (key[x]==q) size[x]=
1,Min[x]=
0;
int cnt=
0;
for (
int i=point[x];i;i=nxt[i])
{
++cnt;
treedp(v[i]);
ans1+=(LL)size[v[i]]*sum[x]+(LL)size[x]*(sum[v[i]]+(LL)c[i]*(LL)size[v[i]]);
size[x]+=size[v[i]];
sum[x]+=sum[v[i]]+(LL)c[i]*(LL)size[v[i]];
if (Max[v[i]]+(LL)c[i]>Max[x])
{
_Max[x]=Max[x];
Max[x]=Max[v[i]]+(LL)c[i];
}
else _Max[x]=max(_Max[x],Max[v[i]]+(LL)c[i]);
if (Min[v[i]]+(LL)c[i]<Min[x])
{
_Min[x]=Min[x];
Min[x]=Min[v[i]]+(LL)c[i];
}
else _Min[x]=min(_Min[x],Min[v[i]]+(LL)c[i]);
}
if (key[x]==q||cnt>
1)
{
ans2=min(ans2,Min[x]+_Min[x]);
ans3=max(ans3,Max[x]+_Max[x]);
}
point[x]=
0;
}
int main()
{
scanf(
"%d",&n);
for (
int i=
1;i<n;++i)
{
int x,y;
scanf(
"%d%d",&x,&y);
add(x,y,
1),add(y,x,
1);
}
build(
1,
0);
memset(point,
0,
sizeof(point));
scanf(
"%d",&q);
while (q)
{
scanf(
"%d",&k);
for (
int i=
1;i<=k;++i)
{
scanf(
"%d",&pt[i]);
key[pt[i]]=flag[pt[i]]=q;
}
sort(pt+
1,pt+k+
1,cmp);pt[
0]=k;
for (
int i=
2;i<=k;++i)
{
int r=lca(pt[i-
1],pt[i]);
if (flag[r]!=q)
{
flag[r]=q;
pt[++pt[
0]]=r;
}
}
if (flag[
1]!=q) flag[
1]=q,pt[++pt[
0]]=
1;
sort(pt+
1,pt+pt[
0]+
1,cmp);
tot=
0;
stack[top=
1]=
1;
for (
int i=
2;i<=pt[
0];++i)
{
while (in[pt[i]]<in[
stack[top]]||in[pt[i]]>out[
stack[top]])
--top;
add(
stack[top],pt[i],h[pt[i]]-h[
stack[top]]);
stack[++top]=pt[i];
}
ans1=
0;ans2=inf;ans3=
0;
treedp(
1);
printf(
"%lld %lld %lld\n",ans1,ans2,ans3);
--q;
}
}
转载请注明原文地址: https://ju.6miu.com/read-33042.html