JZOJ 5052. 【GDOI2017模拟二试4.12】旅游路线

    xiaoxiao2021-03-31  31

    Description

    A君准备在Z国进行一次旅行,Z国中有n个城市,城市从1到n进行编号,其中1号城市为Z国首都。Z国的旅行交通网由n-1条单向道路构成,并且从任何一个城市出发都可以通过旅行网到达首都。 一条旅行交通网中的旅行线路,可以用线路上所经过的城市来描述,如{v1,v2,v3,……,vm},它表示一条经过了m个城市的旅行路线,且城市vi到城市vi+1有一条单向道路相连。 两个城市是相似的,当且仅当他们所连接的道路数相同。 若两条路线{u1,u2,……,up}与{v1,v2,……,vq},若p=q且∀1 ≤ i ≤ p,城市 u i 与 v i 是相似的,则 A君认为这两条旅行路线也是相似的。 现在A君想知道共有多少种不同的旅行路线,相似的若干条旅行路线只算一种。

    Input

    第一行一个整数n表示Z国城市个数 接下来n-1行每行两个整数x,y,表示一条从x到y的单向道路

    Output

    仅一行一个整数表示答案

    Sample Input

    3 2 1 3 1

    Sample Output

    3

    Data Constraint

    20%的数据:n ≤ 100 另有40%的数据:每个城市所连接的道路不超过20条 100%的数据:1≤n≤10^5

    分析

    简单的来说就是给一颗字典树,然后让你求上面有多少个不同的子串(不一定从根节点开始)。 那么只要在这棵字典树上建一棵广义的后缀自动机,然后把每个节点对应的字符串数量加起来即可。

    代码

    #include <bits/stdc++.h> #define N 100005 struct NOTE { int to,next; }e[N]; int cnt; int next[N]; int n; int d[N]; int fa[N * 2]; int max[N * 2]; std::map <int,int> ch[N * 2]; int read() { int x = 0,f = 1; char ch = getchar(); while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); } while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); } return x * f; } void add(int x,int y) { e[++cnt].to = y; e[cnt].next = next[x]; next[x] = cnt; } int size; int ins(int last,int x) { if (ch[last][x]) { int p = last; int np = ch[last][x]; if (max[np] == max[p] + 1) last = np; else { int q = ++size; max[q] = max[p] + 1; ch[q] = ch[np]; fa[q] = fa[np]; fa[np] = last = q; for (;ch[p][x] == np; p = fa[p]) ch[p][x] = q; } return last; } int p,q,np,nq; p = last; last = np = ++size; max[np] = max[p] + 1; for (; !ch[p][x] && p; p = fa[p]) ch[p][x] = np; if (!p) fa[np] = 1; else { q = ch[p][x]; if (max[q] == max[p] + 1) fa[np] = q; else { nq = ++size; max[nq] = max[p] + 1; ch[nq] = ch[q]; fa[nq] = fa[q]; fa[q] = fa[np] = nq; for (; ch[p][x] == q; p = fa[p]) ch[p][x] = nq; } } return last; } void dfs(int x,int p) { int tmp = ins(p,d[x]); for (int i = next[x]; i; i = e[i].next) dfs(e[i].to,tmp); } int tmp[N]; int tot; void ls() { for (int i = 1; i <= n; i++) tmp[++tot] = d[i]; std::sort(tmp + 1,tmp + tot + 1); tot = std::unique(tmp + 1,tmp + tot + 1) - tmp - 1; for (int i = 1; i <= n; i++) d[i] = std::lower_bound(tmp + 1,tmp + tot + 1,d[i]) - tmp - 1; } int main() { freopen("route.in","r",stdin); freopen("route.out","w",stdout); n = read(); for (int i = 1; i < n; i++) { int x = read(); int y = read(); d[x]++; d[y]++; add(y,x); } ls(); size = 1; dfs(1,1); long long ans = 0; for (int i = 1; i <= size; i++) ans += max[i] - max[fa[i]]; printf("%lld\n",ans); }
    转载请注明原文地址: https://ju.6miu.com/read-665467.html

    最新回复(0)