Description Give a tree with n vertices,each edge has a length(positive integer less than 1001). Define dist(u,v)=The min distance between node u and v. Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k. Write a program that will count how many pairs which are valid for a given tree. Input The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l. The last test case is followed by two zeros. Output For each test case output the answer on a single line. 题目大意 给定一棵树,求出满足距离不大于k的点对的个数。 解题思想 最开始的想法就是做n次DFS然后累计答案,但是显然超时。 于是乎,我们想到如果能维护一个数据结构能够快速的知道小于k-disy的x的个数(x,y是枚举的点对,disx表示x到x,y最近祖先的距离,要满足disx+disy<=k,就是disx<=k-disy),马上就想到treap,treap中维护的数值显然是到最近公共祖先的距离,但是每次进行合并时又涉及到+w,比较麻烦,所以维护到根的距离比较方便(累计答案的时候注意加两倍的最近祖先到根的距离)。合并其实很简单,把小树往大树里放就可以了,启发式合并其实就是暴力的一个一个放。也许有人会问为什么不会超时,很明显,最坏的情况是两个一样大的树进行合并,所以对于每个节点最多就合并log(n)次,询问也是log(n)的,所以总复杂度是O(nlog^2(n))。操作时要注意先询问答案后合并,否则会算重。
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int maxn=10005,maxm=20005; struct jz{ int x,s,w,ran,l,r; }a[maxn*14]; int lnk[maxn],nxt[maxm],son[maxm],w[maxm],ro[maxn],tot,n,K,m,ans; bool vis[maxn]; void add(int x,int y,int z){nxt[++tot]=lnk[x];lnk[x]=tot;son[tot]=y;w[tot]=z;} void Putdata(int k){a[k].s=a[a[k].l].s+a[a[k].r].s+a[k].w;} void rturn(int &k){ int t=a[k].l;a[k].l=a[t].r;a[t].r=k; a[t].s=a[k].s;Putdata(k);k=t; } void lturn(int &k){ int t=a[k].r;a[k].r=a[t].l;a[t].l=k; a[t].s=a[k].s;Putdata(k);k=t; } void Insert(int &k,int x){ if (k==0){k=++m;a[k].s=a[k].w=1;a[k].ran=rand();a[k].x=x;return;} a[k].s++; if (x==a[k].x) a[k].w++;else if (x<a[k].x){ Insert(a[k].l,x); if (a[a[k].l].ran<a[k].ran) rturn(k); }else{ Insert(a[k].r,x); if (a[a[k].r].ran<a[k].ran) lturn(k); } } void Join(int &k1,int k2){ if (k2==0) return; for (int i=1;i<=a[k2].w;i++) Insert(k1,a[k2].x); Join(k1,a[k2].l);Join(k1,a[k2].r); } int Asksum(int k,int x){ if (k==0) return 0; if (x==a[k].x) return a[k].w+a[a[k].l].s;else if (x<a[k].x) return Asksum(a[k].l,x);else return a[k].w+a[a[k].l].s+Asksum(a[k].r,x); } int Count(int k1,int k2,int x){ if (k2==0) return 0; return a[k2].w*Asksum(k1,x-a[k2].x)+Count(k1,a[k2].l,x)+Count(k1,a[k2].r,x); } void DFS(int x,int dep){ vis[x]=1; for (int j=lnk[x];j;j=nxt[j])if (!vis[son[j]]){ DFS(son[j],dep+w[j]); if (a[ro[x]].s<a[ro[son[j]]].s) swap(ro[son[j]],ro[x]); ans+=Count(ro[x],ro[son[j]],K+2*dep);Join(ro[x],ro[son[j]]); } ans+=Asksum(ro[x],K+dep); Insert(ro[x],dep); } int main(){ freopen("exam.in","r",stdin); freopen("exam.out","w",stdout); while (1){ scanf("%d%d",&n,&K); if (n==0&&K==0) return 0; memset(a,0,sizeof(a)); memset(ro,0,sizeof(ro)); memset(vis,0,sizeof(vis)); memset(lnk,0,sizeof(lnk)); tot=ans=m=0;int x,y,z; for (int i=1;i<n;i++){scanf("%d%d%d",&x,&y,&z);add(x,y,z);add(y,x,z);} DFS(1,0); printf("%d\n",ans); } }