题目大意
 
给定一棵
  
   n
  个节点的树,每个点有一个颜色种类ci。  对于每一个点
  
   x
  ,你需要统计从x出发的所有路径的颜色种类数之和。
 
  
   1≤n≤3×105,0≤ci≤n
  
 
 
题目分治
 
首先这题虚树肯定可以做,这里不讲。  考虑使用点分治,先不考虑有多种颜色。假设我只想统计出现过某一种颜色的路径总数。  对于分治重心
  
   c
  ,在分治过程中做到点x:  
  
   ∙ 
  如果
  
   x
  到c的路径上已经有了这一种颜色,那么
  
   x
  的答案显然就要加上当前分治层的点数减去x所在子树的点数。  
  
   ∙ 
  如果
  
   x
  到c的路径上没有这种颜色,我们就要通过预处理求出所有到
  
   c
  路径(不包括c)包含该种颜色的路径条数统计出来,减去
  
   x
  所在子树中满足同样条件的路径数,加在x的答案上面。  现在考虑将其扩展到多种颜色上。  对于分治重心
  
   c
  ,在分治过程中做到点x:  
  
   ∙ 
  设其到
  
   c
  路径上颜色种类数为cnt,分治层点数减去
  
   x
  所在子树的点数是siz,答案就要加上
  
   cnt×siz
  。  
  
   ∙ 
  对于那些没有出现过的颜色种类,我们考虑先预处理不包含颜色
  
   i
  的到c路径(不包含
  
   c
  )条数fi,那么答案就要加上每种没有出现过的颜色的
  
   f
  值,算的时候由于我们是深搜实现,一开始用sum记录
  
   f
  值和,每新出现一种颜色就减掉对应的f,消失一种颜色就加上,这样就可以快速求出要加的值了。当然,我们还要减去和
  
   x
  同一棵子树的路径,这个在实现的时候,我们再深搜计算一个子树之前先深搜一边把这个子树的路径从f中删掉就好了,深搜计算完之后再深搜加回去就好了。  时间复杂度
  
   O(nlogn)
  。
 
 
代码实现
 
using namespace std;
typedef long long LL;
int read()
{
    
int x=
0,f=
1;
    char ch=getchar();
    
while (!isdigit(ch)) f=ch==
'-'?-
1:f,ch=getchar();
    
while (isdigit(ch)) 
x=
x*10+ch-
'0',ch=getchar();
    
return x*f;
}
int buf[
30];
void 
write(LL 
x)
{
    
if (
x<
0) putchar(
'-'),
x=-
x;
    
for (;
x;
x/=
10) buf[++buf[
0]]=
x;
    
if (!buf[
0]) buf[++buf[
0]]=
0;
    
for (;buf[
0];putchar(
'0'+buf[buf[
0]--]));
}
const 
int N=
300050;
const 
int E=N<<
1;
int last[N],fa[N],size[N],col[N],f[N],ext[N],que[N];
int nxt[E],tov[E];
bool vis[N];
LL ans[N];
int n,tot,head,tail,cur;
LL sum;
void insert(
int x,
int y){tov[++tot]=
y,nxt[tot]=
last[
x],
last[
x]=tot;}
int core(
int og)
{
    
int i,
x,
y,ret,rets=n,tmp;
    
for (head=
0,fa[que[tail=
1]=og]=
0;head<tail;)
        
for (size[
x=que[++head]]=
1,i=
last[
x];i;i=nxt[i])
            
if ((
y=tov[i])!=fa[
x]&&!vis[
y])
                fa[que[++tail]=
y]=
x;
    
for (head=tail;head>
1;--head) size[fa[que[head]]]+=size[que[head]];
    
for (head=
1;head<=tail;++head)
    {
        
for (tmp=size[og]-size[
x=que[head]],i=
last[
x];i;i=nxt[i])
            
if ((
y=tov[i])!=fa[
x]&&!vis[
y]) tmp=max(tmp,size[
y]);
        
if (tmp<rets) ret=
x,rets=tmp;
    }
    
return ret;
}
void dfs(
int x)
{
    size[
x]=
1;
    
for (
int i=
last[
x],
y;i;i=nxt[i])
        
if ((
y=tov[i])!=fa[
x]&&!vis[
y])
            fa[
y]=
x,dfs(
y),size[
x]+=size[
y];
}
void count(
int x,
int *f,
int sig)
{
    
if (!ext[col[
x]]++) f[col[
x]]+=sig
*size[
x],sum+=sig
*size[
x],++cur;
    
for (
int i=
last[
x],
y;i;i=nxt[i])
        
if ((
y=tov[i])!=fa[
x]&&!vis[
y]) count(
y,f,sig);
    
if (!--ext[col[
x]]) --cur;
}
void calc(
int x,
int siz,
int c)
{
    
if (!ext[col[
x]]++) sum-=f[col[
x]],++cur;
    ans[
x]+=
1ll
*siz*cur+sum,ans[c]+=cur;
    
for (
int i=
last[
x],
y;i;i=nxt[i])
        
if ((
y=tov[i])!=fa[
x]&&!vis[
y]) calc(
y,siz,c);
    
if (!--ext[col[
x]]) sum+=f[col[
x]],--cur;
}
void solve(
int x)
{
    
int c=core(
x);
    ++ans[c],size[c]=
1;
    
for (
int i=
last[c],
y;i;i=nxt[i])
        
if (!vis[
y=tov[i]]) fa[
y]=c,dfs(
y),size[c]+=size[
y];
    
for (
int i=
last[c],
y;i;i=nxt[i])
        
if (!vis[
y=tov[i]]) count(
y,f,
1);
    
for (
int i=
last[c],
y;i;i=nxt[i])
        
if (!vis[
y=tov[i]]) count(
y,f,-
1),++ext[col[c]],sum-=f[col[c]],cur=
1,calc(
y,size[c]-size[
y],c),--ext[col[c]],sum+=f[col[c]],cur=
0,count(
y,f,
1);
    
for (
int i=
last[c],
y;i;i=nxt[i])
        
if (!vis[
y=tov[i]]) count(
y,f,-
1);
    vis[c]=
1;
    
for (
int i=
last[c],
y;i;i=nxt[i])
        
if (!vis[
y=tov[i]]) solve(
y);
}
int main()
{
    freopen(
"mushroom.in",
"r",stdin),freopen(
"mushroom.out",
"w",stdout);
    n=
read();
    
for (
int i=
1;i<=n;++i) col[i]=
read();
    
for (
int i=
1,
x,
y;i<n;++i) 
x=
read(),
y=
read(),insert(
x,
y),insert(
y,
x);
    solve(
1);
    
for (
int i=
1;i<=n;++i) 
write(ans[i]),putchar(
'\n');
    fclose(stdin),fclose(stdout);
    
return 0;
}
                
                
                
        
    
                    转载请注明原文地址: https://ju.6miu.com/read-675694.html