题目大意
给定一棵
n
个节点的树,每个点有一个颜色种类ci。 对于每一个点
x
,你需要统计从x出发的所有路径的颜色种类数之和。
1≤n≤3×105,0≤ci≤n
题目分治
首先这题虚树肯定可以做,这里不讲。 考虑使用点分治,先不考虑有多种颜色。假设我只想统计出现过某一种颜色的路径总数。 对于分治重心
c
,在分治过程中做到点x:
∙
如果
x
到c的路径上已经有了这一种颜色,那么
x
的答案显然就要加上当前分治层的点数减去x所在子树的点数。
∙
如果
x
到c的路径上没有这种颜色,我们就要通过预处理求出所有到
c
路径(不包括c)包含该种颜色的路径条数统计出来,减去
x
所在子树中满足同样条件的路径数,加在x的答案上面。 现在考虑将其扩展到多种颜色上。 对于分治重心
c
,在分治过程中做到点x:
∙
设其到
c
路径上颜色种类数为cnt,分治层点数减去
x
所在子树的点数是siz,答案就要加上
cnt×siz
。
∙
对于那些没有出现过的颜色种类,我们考虑先预处理不包含颜色
i
的到c路径(不包含
c
)条数fi,那么答案就要加上每种没有出现过的颜色的
f
值,算的时候由于我们是深搜实现,一开始用sum记录
f
值和,每新出现一种颜色就减掉对应的f,消失一种颜色就加上,这样就可以快速求出要加的值了。当然,我们还要减去和
x
同一棵子树的路径,这个在实现的时候,我们再深搜计算一个子树之前先深搜一边把这个子树的路径从f中删掉就好了,深搜计算完之后再深搜加回去就好了。 时间复杂度
O(nlogn)
。
代码实现
using namespace std;
typedef long long LL;
int read()
{
int x=
0,f=
1;
char ch=getchar();
while (!isdigit(ch)) f=ch==
'-'?-
1:f,ch=getchar();
while (isdigit(ch))
x=
x*10+ch-
'0',ch=getchar();
return x*f;
}
int buf[
30];
void
write(LL
x)
{
if (
x<
0) putchar(
'-'),
x=-
x;
for (;
x;
x/=
10) buf[++buf[
0]]=
x;
if (!buf[
0]) buf[++buf[
0]]=
0;
for (;buf[
0];putchar(
'0'+buf[buf[
0]--]));
}
const
int N=
300050;
const
int E=N<<
1;
int last[N],fa[N],size[N],col[N],f[N],ext[N],que[N];
int nxt[E],tov[E];
bool vis[N];
LL ans[N];
int n,tot,head,tail,cur;
LL sum;
void insert(
int x,
int y){tov[++tot]=
y,nxt[tot]=
last[
x],
last[
x]=tot;}
int core(
int og)
{
int i,
x,
y,ret,rets=n,tmp;
for (head=
0,fa[que[tail=
1]=og]=
0;head<tail;)
for (size[
x=que[++head]]=
1,i=
last[
x];i;i=nxt[i])
if ((
y=tov[i])!=fa[
x]&&!vis[
y])
fa[que[++tail]=
y]=
x;
for (head=tail;head>
1;--head) size[fa[que[head]]]+=size[que[head]];
for (head=
1;head<=tail;++head)
{
for (tmp=size[og]-size[
x=que[head]],i=
last[
x];i;i=nxt[i])
if ((
y=tov[i])!=fa[
x]&&!vis[
y]) tmp=max(tmp,size[
y]);
if (tmp<rets) ret=
x,rets=tmp;
}
return ret;
}
void dfs(
int x)
{
size[
x]=
1;
for (
int i=
last[
x],
y;i;i=nxt[i])
if ((
y=tov[i])!=fa[
x]&&!vis[
y])
fa[
y]=
x,dfs(
y),size[
x]+=size[
y];
}
void count(
int x,
int *f,
int sig)
{
if (!ext[col[
x]]++) f[col[
x]]+=sig
*size[
x],sum+=sig
*size[
x],++cur;
for (
int i=
last[
x],
y;i;i=nxt[i])
if ((
y=tov[i])!=fa[
x]&&!vis[
y]) count(
y,f,sig);
if (!--ext[col[
x]]) --cur;
}
void calc(
int x,
int siz,
int c)
{
if (!ext[col[
x]]++) sum-=f[col[
x]],++cur;
ans[
x]+=
1ll
*siz*cur+sum,ans[c]+=cur;
for (
int i=
last[
x],
y;i;i=nxt[i])
if ((
y=tov[i])!=fa[
x]&&!vis[
y]) calc(
y,siz,c);
if (!--ext[col[
x]]) sum+=f[col[
x]],--cur;
}
void solve(
int x)
{
int c=core(
x);
++ans[c],size[c]=
1;
for (
int i=
last[c],
y;i;i=nxt[i])
if (!vis[
y=tov[i]]) fa[
y]=c,dfs(
y),size[c]+=size[
y];
for (
int i=
last[c],
y;i;i=nxt[i])
if (!vis[
y=tov[i]]) count(
y,f,
1);
for (
int i=
last[c],
y;i;i=nxt[i])
if (!vis[
y=tov[i]]) count(
y,f,-
1),++ext[col[c]],sum-=f[col[c]],cur=
1,calc(
y,size[c]-size[
y],c),--ext[col[c]],sum+=f[col[c]],cur=
0,count(
y,f,
1);
for (
int i=
last[c],
y;i;i=nxt[i])
if (!vis[
y=tov[i]]) count(
y,f,-
1);
vis[c]=
1;
for (
int i=
last[c],
y;i;i=nxt[i])
if (!vis[
y=tov[i]]) solve(
y);
}
int main()
{
freopen(
"mushroom.in",
"r",stdin),freopen(
"mushroom.out",
"w",stdout);
n=
read();
for (
int i=
1;i<=n;++i) col[i]=
read();
for (
int i=
1,
x,
y;i<n;++i)
x=
read(),
y=
read(),insert(
x,
y),insert(
y,
x);
solve(
1);
for (
int i=
1;i<=n;++i)
write(ans[i]),putchar(
'\n');
fclose(stdin),fclose(stdout);
return 0;
}
转载请注明原文地址: https://ju.6miu.com/read-675694.html