采药人的药田是一个树状结构,每条路径上都种植着同种药材。 采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。 采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。他想知道他一共可以选择多少种不同的路径。
第1行包含一个整数N。 接下来N-1行,每行包含三个整数a_i、b_i和t_i,表示这条路上药材的类型。
输出符合采药人要求的路径数目。
对于100%的数据,N ≤ 100,000。
题解:点分治
把0边权改成-1,那么如果一条路径的总权值为0,那么这条路径阴阳平衡。对于中间是否有中转站的问题,最基本的思路就是,判断路径上是否有两个前缀和相同的点。然后对于有和没有的路径分开统计,然后计算答案。
但是处理的细节非常多。例如如果路径是一条链,那么终点和起点都是0,还需要判断中间是否还有权值为0的点,而且我在处理时没有加入起点,所以对于链这种路径,计算子树时不需要消去影响。但是用子树消除影响的时候又需要加入子树的根。总之细节很恶心。。。。
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #define N 200003 #define inf 1000000000 #define LL long long using namespace std; int tot,point[N],nxt[N],v[N],c[N],sz,root,cnt,d[N],ti,pt; int n,size[N],vis[N],cl[N],f[N],x1[N]; LL ans,hs[N],hn[N]; struct data{ int x,k,pos; }a[N]; void add(int x,int y,int z) { tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z; tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; c[tot]=z; } void getroot(int x,int fa) { size[x]=1; f[x]=0; for (int i=point[x];i;i=nxt[i]) { if (v[i]==fa||vis[v[i]]) continue; getroot(v[i],x); size[x]+=size[v[i]]; f[x]=max(f[x],size[v[i]]); } f[x]=max(f[x],sz-size[x]); if (f[x]<f[root]) root=x; } void solve(int x,int fa) { bool pd=false; ++cnt; a[cnt].x=d[x]; a[cnt].k=cl[d[x]+100000]; a[cnt].pos=x; cl[d[x]+100000]++; for (int i=point[x];i;i=nxt[i]) { if (vis[v[i]]||v[i]==fa) continue; d[v[i]]=d[x]+c[i]; solve(v[i],x); } cl[d[x]+100000]--; } void calc(int x,int now,int opt) { d[x]=now; cnt=0; pt=x; if (opt==-1) cl[100000]++; solve(x,0); cl[100000]=0; for (int i=1;i<=cnt;i++) hs[a[i].x+100000]=hn[a[i].x+100000]=0,hs[100000-a[i].x]=hn[100000-a[i].x]=0; int k=0; //cout<<x<<endl; //for (int i=1;i<=cnt;i++) cout<<a[i].x<<" "<<a[i].k<<endl; for (int i=1;i<=cnt;i++) { if (a[i].pos==x&&opt==1) continue; if (a[i].k>=2&&a[i].x==0&&opt==1) ans+=opt; if (a[i].k) { hs[a[i].x+100000]++; if (hs[a[i].x+100000]==1) x1[++k]=a[i].x; } else hn[a[i].x+100000]++; } LL sum=0; k=unique(x1+1,x1+k+1)-x1-1; for (int i=1;i<=k;i++) { int t=100000-x1[i]; if (x1[i]) ans+=(LL)(hs[x1[i]+100000]*hn[t])*opt,sum+=hs[x1[i]+100000]*hs[t]; else ans+=(LL)(hs[x1[i]+100000]*hn[t]+hs[t]*(hs[t]-1)/2)*opt; } ans+=opt*sum/2; //cout<<x<<" "<<ans<<"!"<<endl; } void dfs(int x) { calc(x,0,1); vis[x]=1; for (int i=point[x];i;i=nxt[i]) { if (vis[v[i]]) continue; calc(v[i],c[i],-1); sz=size[v[i]]; root=0; getroot(v[i],x); dfs(root); } } int main() { freopen("a.in","r",stdin); freopen("my.out","w",stdout); scanf("%d",&n); for (int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); if (!z) z=-1; add(x,y,z); } f[0]=inf; root=0; sz=n; getroot(1,0); dfs(root); printf("%lld\n",ans); }
