bzoj 3611: [Heoi2014]大工程 (虚树+树形DP)

    xiaoxiao2021-03-25  61

    3611: [Heoi2014]大工程

    Time Limit: 60 Sec   Memory Limit: 512 MB Submit: 1218   Solved: 530 [ Submit][ Status][ Discuss]

    Description

    国家有一个大工程,要给一个非常大的交通网络里建一些新的通道。  我们这个国家位置非常特殊,可以看成是一个单位边权的树,城市位于顶点上。  在 2 个国家 a,b 之间建一条新通道需要的代价为树上 a,b 的最短路径。  现在国家有很多个计划,每个计划都是这样,我们选中了 k 个点,然后在它们两两之间 新建 C(k,2)条 新通道。 现在对于每个计划,我们想知道:  1.这些新通道的代价和  2.这些新通道中代价最小的是多少  3.这些新通道中代价最大的是多少

    Input

    第一行 n 表示点数。

     接下来 n-1 行,每行两个数 a,b 表示 a 和 b 之间有一条边。 点从 1 开始标号。 接下来一行 q 表示计划数。 对每个计划有 2 行,第一行 k 表示这个计划选中了几个点。  第二行用空格隔开的 k 个互不相同的数表示选了哪 k 个点。

    Output

    输出 q 行,每行三个数分别表示代价和,最小代价,最大代价。 

    Sample Input

    10 2 1 3 2 4 1 5 2 6 4 7 5 8 6 9 7 10 9 5 2 5 4 2 10 4 2 5 2 2 6 1 2 6 1

    Sample Output

    3 3 3 6 6 6 1 1 1 2 2 2 2 2 2

    HINT

    n<=1000000 

    q<=50000并且保证所有k之和<=2*n 

    Source

    鸣谢佚名上传

    [ Submit][ Status][ Discuss]

    题解:虚树+树形DP

    构建虚树应该没什么好说的了,关键就是树形DP。

    对于每个点维护所选点到该点的最小值,最大值,路径和,以及该点子树中所选点的个数。

    进行转移和更新答案即可。具体过程见代码

    #include<iostream> #include<cstring> #include<cstdio> #include<algorithm> #include<cmath> #define N 1000003 using namespace std; typedef long long LL; const LL inf=1000000000000LL; int n,m,k,point[N],nxt[N*2],v[N*2],c[N*2],st[N],pos[N]; int mi[21],fa[N][21],tot,top,a[N],deep[N],sz,mark[N]; LL len[N][21],sum[N],mx[N],mn[N],size[N],mnx,sumx,mxx; void add(int x,int y) { tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; } void dfs(int x,int f) { deep[x]=deep[f]+1; pos[x]=++sz; for (int i=1;i<=20;i++) { if (deep[x]-mi[i]<0) break; fa[x][i]=fa[fa[x][i-1]][i-1]; len[x][i]=len[x][i-1]+len[fa[x][i-1]][i-1]; } for (int i=point[x];i;i=nxt[i]) { if (v[i]==f) continue; fa[v[i]][0]=x; len[v[i]][0]=1; dfs(v[i],x); } } int cmp(int x,int y) { return pos[x]<pos[y]; } LL getlen(int x,int y) { if (deep[x]<deep[y]) swap(x,y); int k=deep[x]-deep[y]; LL sum=0; for (int i=0;i<=20;i++) if ((k>>i)&1) sum+=len[x][i],x=fa[x][i]; if (x==y) return sum; for (int i=20;i>=0;i--) if (fa[x][i]!=fa[y][i]) { sum+=len[x][i],sum+=len[y][i]; x=fa[x][i],y=fa[y][i]; } sum+=len[x][0],sum+=len[y][0]; return sum; } void build(int x,int y) { if (x==y) return; tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=getlen(x,y); //cout<<x<<" "<<y<<" "<<c[tot]<<endl; } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); int k=deep[x]-deep[y]; for (int i=0;i<=20;i++) if ((k>>i)&1) x=fa[x][i]; if (x==y) return x; for (int i=20;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } void dp(int x) { mn[x]=inf; mx[x]=-inf; size[x]=0; sum[x]=0; LL siz=0; bool pd=false; for (int i=point[x];i;i=nxt[i]) { dp(v[i]); size[x]+=size[v[i]]; mnx=min(mnx,mn[x]+mn[v[i]]+c[i]); mxx=max(mxx,mx[x]+mx[v[i]]+c[i]); mn[x]=min(mn[x],mn[v[i]]+c[i]); mx[x]=max(mx[x],mx[v[i]]+c[i]); if (pd) sumx+=sum[x]*size[v[i]]+sum[v[i]]*siz+siz*size[v[i]]*c[i]; sum[x]+=sum[v[i]]+size[v[i]]*c[i]; siz+=size[v[i]]; pd=true; } if (!pd&&mark[x]) mn[x]=mx[x]=0,size[x]=1; else if (mark[x]) { mnx=min(mnx,mn[x]); mn[x]=0; mxx=max(mx[x],mxx); size[x]++; sumx+=sum[x]; } //cout<<x<<" "<<mx[x]<<" "<<mn[x]<<" "<<sum[x]<<" "<<size[x]<<endl; point[x]=0; } void solve() { scanf("%d",&k); for (int i=1;i<=k;i++) scanf("%d",&a[i]),mark[a[i]]=1; sort(a+1,a+k+1,cmp); tot=0; st[++top]=1; for (int i=1;i<=k;i++) { int now=a[i]; int f=lca(st[top],now); while (true) { if (deep[f]>=deep[st[top-1]]) { build(f,st[top--]); if (f!=st[top]) st[++top]=f; break; } build(st[top-1],st[top]); top--; } if (now!=st[top]) st[++top]=now; } while (top-1) build(st[top-1],st[top]),top--; mnx=inf; mxx=-inf; sumx=0; dp(1); printf("%lld %lld %lld\n",sumx,mnx,mxx); for (int i=1;i<=k;i++) mark[a[i]]=0; } int main() { freopen("a.in","r",stdin); scanf("%d",&n); mi[0]=1; for (int i=1;i<=21;i++) mi[i]=mi[i-1]*2; for (int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); add(x,y); } dfs(1,0); scanf("%d",&m); memset(point,0,sizeof(point)); memset(nxt,0,sizeof(nxt)); for (int i=1;i<=m;i++) solve(); }

    转载请注明原文地址: https://ju.6miu.com/read-35816.html

    最新回复(0)