题目描述
传送门 题意:给出一棵树,树边上有一个字符,问每个点的子树中最长的合法路径长度。所谓的合法路径长度就是给路径中的所有字符重新组合后可以是回文串。
题解
一道细节比较多的dsu on the tree 首先某一个子树中的路径可以分为经过根的和不经过根的,不经过根的可以通过儿子更新过来,所以对于每一棵子树只统计经过根的路径。 经过根的路径又可以分成两种,第一种是到根就结束的,也就是说是根到它子树里某一个点的一条链,这个可以在dfs的时候统计。 第二种是路径的两个端点分别在根的两个儿子的子树里。令s(i)表示1到i的路径的状态,状态是一个22位2进制数,每一位表示这个字符在路径上出现次数(偶数/奇数)。f(i)表示状态为i的当前可以统计的点的最大深度。并且显然如果能是回文串的话所有的字符数量为奇数个的最多只有一个。这样对于每一个点,其重儿子的子树已经被统计过了,那对于每一个轻儿子的子树,先遍历一遍,对于每一个遍历到的点,首先枚举可以是回文串的状态(22+1种),然后通过这个状态计算出可以与这个点配对的状态,然后计算深度即可。 时间复杂度
O(nlogn∗23)
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
#define N 500005
int n,Max,inf;
int tot,point[N],nxt[N],v[N];
char c[N];
int size[N],son[N],h[N],s[N],f[N*
20],ans[N];
void add(
int x,
int y)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
}
void getson(
int x,
int fa)
{
size[x]=
1;h[x]=h[fa]+
1;
if (x!=
1) s[x]=s[fa]^(
1<<(c[x]-
'a'));
for (
int i=point[x];i;i=nxt[i])
{
getson(v[i],x);
size[x]+=size[v[i]];
if (size[v[i]]>size[son[x]]) son[x]=v[i];
}
}
void calc(
int rt,
int x)
{
int state=s[x];
Max=max(Max,f[state]+h[x]-
2*h[rt]);
if ((s[x]^s[rt])==
0) Max=max(Max,h[x]-h[rt]);
for (
int i=
0;i<
22;++i)
{
state=(
1<<i)^s[x];
Max=max(Max,f[state]+h[x]-
2*h[rt]);
if ((s[x]^s[rt])==(
1<<i)) Max=max(Max,h[x]-h[rt]);
}
for (
int i=point[x];i;i=nxt[i])
calc(rt,v[i]);
}
void change(
int x,
int opt)
{
if (opt) f[s[x]]=max(f[s[x]],h[x]);
else f[s[x]]=inf;
for (
int i=point[x];i;i=nxt[i])
change(v[i],opt);
}
void dfs(
int x,
int k)
{
for (
int i=point[x];i;i=nxt[i])
if (v[i]!=son[x])
dfs(v[i],
0);
if (son[x]) dfs(son[x],
1);
Max=
0;
int state=s[x];
Max=max(Max,f[state]-h[x]);
for (
int i=
0;i<
22;++i)
{
state=(
1<<i)^s[x];
Max=max(Max,f[state]-h[x]);
}
for (
int i=point[x];i;i=nxt[i])
if (v[i]!=son[x])
{
calc(x,v[i]);
change(v[i],
1);
}
ans[x]=Max;
if (!k)
{
for (
int i=point[x];i;i=nxt[i])
change(v[i],
0);
f[s[x]]=inf;
}
else f[s[x]]=max(f[s[x]],h[x]);
}
void recalc(
int x)
{
for (
int i=point[x];i;i=nxt[i])
{
recalc(v[i]);
ans[x]=max(ans[x],ans[v[i]]);
}
}
int main()
{
scanf(
"%d\n",&n);
for (
int i=
2;i<=n;++i)
{
int fa;
scanf(
"%d %c\n",&fa,&c[i]);
add(fa,i);
}
getson(
1,
0);
memset(f,
128,
sizeof(f));inf=f[
0];
dfs(
1,
0);
recalc(
1);
for (
int i=
1;i<=n;++i)
printf(
"%d%c",ans[i],
" \n"[i==n]);
}
转载请注明原文地址: https://ju.6miu.com/read-600143.html