第一行一个整数n,代表岛屿数量。
接下来n-1行,每行三个整数u,v,w,代表u号岛屿和v号岛屿由一条代价为c的桥梁直接相连,保证1<=u,v<=n且1<=c<=100000。
第n+1行,一个整数m,代表敌方机器能使用的次数。
接下来m行,每行一个整数ki,代表第i次后,有ki个岛屿资源丰富,接下来k个整数h1,h2,…hk,表示资源丰富岛屿的编号。
输出有m行,分别代表每次任务的最小代价。
对于100%的数据,2<=n<=250000,m>=1,sigma(ki)<=500000,1<=ki<=n-1
Stage2 day2
[ Submit][ Status][ Discuss]题解:虚树+树形DP
构建过程如下:
按照原树的dfs序递增顺序遍历选择的节点. 每次遍历节点都把这个节点插到树上.
首先虚树一定要有一个根.一般强制1号点(或者原树的根节点)为虚树的根.
维护一个栈,它表示在我们已经(用之前的那些点)构建完毕的虚树上,以最后一个插入的点为端点的DFS链.
设最后插入的点为p(就是栈顶的点),当前遍历到的点为x.我们想把x插入到我们已经构建的树上去.
求出lca(p,x),记为lca.有两种情况:
1.p和x分立在lca的两棵子树下.
2.lca是p.
对于第二种情况,直接在栈中插入节点x即可,不要连接任何边.
对于第一种情况,要仔细分析.
我们是按照dfs序号遍历的,所以dfn(x)>dfn(p)>dfn(lca).
这说明什么呢? 说明一件很重要的事:我们已经把lca所引领的子树中,p所在的子树全部遍历完了!
这样,我们就直接构建lca引领的,p所在的那个子树. 我们在退栈的时候构建子树.
p所在的子树如果还有其它部分,它一定在之前就构建好了(所有退栈的点都已经被正确地连入树中了),就剩那条链.
如何正确地把p到lca那部分连进去呢?
设栈顶的节点为p,栈顶第二个节点为q.
重复以下操作:
如果dfn(q)>dfn(lca),可以直接连边q->p,然后退一次栈.
如果dfn(q)=dfn(lca),说明q=lca,直接连边lca->p,此时子树已经构建完毕.
如果dfn(q)<dfn(lca),说明lca被p与q夹在中间,此时连边lca->q,退一次栈,再把lca压入栈.此时子树构建完毕.
最后,为了维护dfs链,要把x压入栈.
剩下的就是树形DP的锅了。f[i]表示子树中的所有关键点(题目中给出的点)全部割断的最小代价,如果这个点是给出的点,那么他上面的边一定会被割断,并且子树中的其他点都不必再割,如果不是关键点就对上面的边和所有儿子的代价和取min
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define N 500003 using namespace std; typedef long long LL; const LL inf=10000000000000LL; int tot,nxt[N],point[N],v[N],c[N],deep[N],fa[N][20],mi[20]; int n,m,k,top,st[N],a[N],pos[N],sz; LL f[N],mn[N]; void add(int x,int y,int z) { tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z; tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; c[tot]=z; } void dfs(int x,int f) { pos[x]=++sz; deep[x]=deep[f]+1; for (int i=1;i<=19;i++) { if (deep[x]-mi[i]<0) break; fa[x][i]=fa[fa[x][i-1]][i-1]; } for (int i=point[x];i;i=nxt[i]) { if (v[i]==f) continue; mn[v[i]]=min(mn[x],(LL)c[i]); fa[v[i]][0]=x; dfs(v[i],x); } } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); int k=deep[x]-deep[y]; for (int i=0;i<=19;i++) if ((k>>i)&1) x=fa[x][i]; if (x==y) return x; for (int i=19;i>=0;i--) if (fa[x][i]!=fa[y][i]) { x=fa[x][i]; y=fa[y][i]; } return fa[x][0]; } void build(int x,int y) { if (x==y) return; tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; } void dp(int x) { f[x]=mn[x]; LL sum=0; for (int i=point[x];i;i=nxt[i]){ dp(v[i]); sum+=f[v[i]]; } point[x]=0; if (sum) f[x]=min(f[x],sum); } int cmp(int x,int y) { return pos[x]<pos[y]; } void solve() { scanf("%d",&k); for (int i=1;i<=k;i++) scanf("%d",&a[i]); sort(a+1,a+k+1,cmp); int cnt=0; a[++cnt]=a[1]; for (int i=2;i<=k;i++) if (lca(a[cnt],a[i])!=a[cnt]) a[++cnt]=a[i]; k=cnt; top=0; st[++top]=1; tot=0; for (int i=1;i<=k;i++) { int now=a[i]; int f=lca(st[top],now); while (true) { if (deep[f]>=deep[st[top-1]]) { build(f,st[top--]); if (f!=st[top]) st[++top]=f; break; } build(st[top-1],st[top]); top--; } if (st[top]!=now) st[++top]=now; } while (top-1) build(st[top-1],st[top]),top--; dp(1); printf("%lld\n",f[1]); } int main() { freopen("a.in","r",stdin); scanf("%d",&n); for (int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); add(x,y,z); } mi[0]=1; for (int i=1;i<=19;i++) mi[i]=mi[i-1]*2; mn[1]=inf; dfs(1,0); scanf("%d",&m); memset(point,0,sizeof(point)); memset(nxt,0,sizeof(nxt)); for (int i=1;i<=m;i++) solve(); } 版本2:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define N 500003 using namespace std; typedef long long LL; const LL inf=10000000000000LL; int tot,nxt[N],point[N],v[N],deep[N],fa[N][20],mi[20]; int n,m,k,top,st[N],a[N],pos[N],sz; LL f[N],mn[N],c[N],len[N][20]; void add(int x,int y,int z) { tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z; tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; c[tot]=z; } void dfs(int x,int f) { pos[x]=++sz; deep[x]=deep[f]+1; for (int i=1;i<=19;i++) { if (deep[x]-mi[i]<0) break; fa[x][i]=fa[fa[x][i-1]][i-1]; len[x][i]=min(len[x][i-1],len[fa[x][i-1]][i-1]); } for (int i=point[x];i;i=nxt[i]) { if (v[i]==f) continue; mn[v[i]]=min(mn[x],(LL)c[i]); fa[v[i]][0]=x; len[v[i]][0]=c[i]; dfs(v[i],x); } } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); int k=deep[x]-deep[y]; for (int i=0;i<=19;i++) if ((k>>i)&1) x=fa[x][i]; if (x==y) return x; for (int i=19;i>=0;i--) if (fa[x][i]!=fa[y][i]) { x=fa[x][i]; y=fa[y][i]; } return fa[x][0]; } LL getlen(int x,int y) { if (deep[x]<deep[y]) swap(x,y); int k=deep[x]-deep[y]; LL l=inf; for (int i=0;i<=19;i++) if ((k>>i)&1) l=min(l,len[x][i]),x=fa[x][i]; if (x==y) return l; for (int i=19;i>=0;i--) if (fa[x][i]!=fa[y][i]) { l=min(l,len[x][i]),x=fa[x][i]; l=min(l,len[y][i]),y=fa[y][i]; } l=min(l,len[x][0]); l=min(l,len[y][0]); return l; } void build(int x,int y) { if (x==y) return; LL t=getlen(x,y); tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=t; // cout<<x<<" "<<y<<" "<<t<<endl; } void dp(int x,LL mn) { f[x]=mn; LL sum=0; for (int i=point[x];i;i=nxt[i]){ dp(v[i],c[i]); sum+=f[v[i]]; } point[x]=0; if (sum) f[x]=min(f[x],sum); } int cmp(int x,int y) { return pos[x]<pos[y]; } void solve() { scanf("%d",&k); for (int i=1;i<=k;i++) scanf("%d",&a[i]); sort(a+1,a+k+1,cmp); int cnt=0; a[++cnt]=a[1]; for (int i=2;i<=k;i++) if (lca(a[cnt],a[i])!=a[cnt]) a[++cnt]=a[i]; k=cnt; top=0; st[++top]=1; tot=0; for (int i=1;i<=k;i++) { int now=a[i]; int f=lca(st[top],now); while (true) { if (deep[f]>=deep[st[top-1]]) { build(f,st[top--]); if (f!=st[top]) st[++top]=f; break; } build(st[top-1],st[top]); top--; } if (st[top]!=now) st[++top]=now; } while (top-1) build(st[top-1],st[top]),top--; dp(1,inf); printf("%lld\n",f[1]); } int main() { freopen("a.in","r",stdin); scanf("%d",&n); for (int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); add(x,y,z); } mi[0]=1; for (int i=1;i<=19;i++) mi[i]=mi[i-1]*2; memset(len,127,sizeof(len)); mn[1]=inf; dfs(1,0); scanf("%d",&m); memset(point,0,sizeof(point)); memset(nxt,0,sizeof(nxt)); for (int i=1;i<=m;i++) solve(); }