Description
A君准备在Z国进行一次旅行,Z国中有n个城市,城市从1到n进行编号,其中1号城市为Z国首都。Z国的旅行交通网由n-1条单向道路构成,并且从任何一个城市出发都可以通过旅行网到达首都。 一条旅行交通网中的旅行线路,可以用线路上所经过的城市来描述,如{v1,v2,v3,……,vm},它表示一条经过了m个城市的旅行路线,且城市vi到城市vi+1有一条单向道路相连。 两个城市是相似的,当且仅当他们所连接的道路数相同。 若两条路线{u1,u2,……,up}与{v1,v2,……,vq},若p=q且∀1 ≤ i ≤ p,城市 u i 与 v i 是相似的,则 A君认为这两条旅行路线也是相似的。 现在A君想知道共有多少种不同的旅行路线,相似的若干条旅行路线只算一种。
第一行一个整数n表示Z国城市个数 接下来n-1行每行两个整数x,y,表示一条从x到y的单向道路
Output
仅一行一个整数表示答案
3 2 1 3 1
Sample Output
3
Data Constraint
20%的数据:n ≤ 100 另有40%的数据:每个城市所连接的道路不超过20条 100%的数据:1≤n≤10^5
分析
简单的来说就是给一颗字典树,然后让你求上面有多少个不同的子串(不一定从根节点开始)。 那么只要在这棵字典树上建一棵广义的后缀自动机,然后把每个节点对应的字符串数量加起来即可。
代码
struct NOTE
{
int to,
next;
}e[N];
int cnt;
int next[N];
int n;
int d[N];
int fa[N *
2];
int max[N *
2];
std::
map <
int,
int> ch[N *
2];
int read()
{
int x =
0,f =
1;
char ch = getchar();
while (ch <
'0' || ch >
'9')
{
if (ch ==
'-')
f = -
1;
ch = getchar();
}
while (ch >=
'0' && ch <=
'9')
{
x =
x *
10 + ch -
'0';
ch = getchar();
}
return x * f;
}
void add(
int x,
int y)
{
e[++cnt].to =
y;
e[cnt].
next =
next[
x];
next[
x] = cnt;
}
int size;
int ins(
int last,
int x)
{
if (ch[
last][
x])
{
int p =
last;
int np = ch[
last][
x];
if (max[np] == max[p] +
1)
last = np;
else
{
int q = ++size;
max[
q] = max[p] +
1;
ch[
q] = ch[np];
fa[
q] = fa[np];
fa[np] =
last =
q;
for (;ch[p][
x] == np; p = fa[p])
ch[p][
x] =
q;
}
return last;
}
int p,
q,np,nq;
p =
last;
last = np = ++size;
max[np] = max[p] +
1;
for (; !ch[p][
x] && p; p = fa[p])
ch[p][
x] = np;
if (!p)
fa[np] =
1;
else
{
q = ch[p][
x];
if (max[
q] == max[p] +
1)
fa[np] =
q;
else
{
nq = ++size;
max[nq] = max[p] +
1;
ch[nq] = ch[
q];
fa[nq] = fa[
q];
fa[
q] = fa[np] = nq;
for (; ch[p][
x] ==
q; p = fa[p])
ch[p][
x] = nq;
}
}
return last;
}
void dfs(
int x,
int p)
{
int tmp = ins(p,d[
x]);
for (
int i =
next[
x]; i; i = e[i].
next)
dfs(e[i].to,tmp);
}
int tmp[N];
int tot;
void ls()
{
for (
int i =
1; i <= n; i++)
tmp[++tot] = d[i];
std::
sort(tmp +
1,tmp + tot +
1);
tot = std::unique(tmp +
1,tmp + tot +
1) - tmp -
1;
for (
int i =
1; i <= n; i++)
d[i] = std::lower_bound(tmp +
1,tmp + tot +
1,d[i]) - tmp -
1;
}
int main()
{
freopen(
"route.in",
"r",stdin);
freopen(
"route.out",
"w",stdout);
n =
read();
for (
int i =
1; i < n; i++)
{
int x =
read();
int y =
read();
d[
x]++;
d[
y]++;
add(
y,
x);
}
ls();
size =
1;
dfs(
1,
1);
long long ans =
0;
for (
int i =
1; i <= size; i++)
ans += max[i] - max[fa[i]];
printf(
"%lld\n",ans);
}
转载请注明原文地址: https://ju.6miu.com/read-665467.html