题意:给定一个字符串和一个字符ch,求该字符串中包含字符ch的不同子串的个数
思路:训练的时候想到是用后缀数组,但是不停地tle,最后还是没有ac,事后总结了下相关的性质
(1)一个字符串的所有子串必定是某个后缀的前缀, 如s = “acabd”,后缀0包含的子串是“a”, “ac”, “aca”, “acab”, “acabd”,后缀1包含的子串是“c”, “ca”, “cab”, “cabd”,后缀2包含的子串是“a”, “ab”, “abd”,后缀3包含的子串是“b”, “bd”,后缀4包含的子串是“d”;
(2)上述所有后缀的前缀是有重复的,例如后缀0的“a”和后缀2的“a”。把后缀按字典序排好后,排名为i的后缀(即从sa[i]开始的后缀)与排名为i-1的后缀的最长公共前缀长度为height[i],这些前缀已经被统计过,于是可以得到一条公式:排名为i的后缀贡献的新子串个数 = 该后缀长度 - height[i];
(3)排名为i的后缀长度 = 字符串长度 - sa[i];
(4)算后缀数组时,一般在字符串末尾加上'\0'一起参与排序,也就是算出来后sa[0] = len,rank[len] = 0(len为字符串长度);
链接:http://acm.hdu.edu.cn/showproblem.php?pid=5769
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Capacity of every buffer: strings of up to maxn - 1 characters plus the
// terminating '\0', which doubles as the suffix-array sentinel.
const int maxn = 100005;

char s[maxn], s0[10];                      // s: input string; s0: token whose first char is the query character
int sa[maxn], t[maxn], t2[maxn], c[maxn];  // sa: suffix array; t/t2: rank buffers; c: radix-sort counters
int ran[maxn], h[maxn];                    // ran[i]: rank of suffix i; h[r]: LCP of suffixes sa[r-1] and sa[r]
int pos[maxn];                             // pos[i]: smallest j >= i with s[j] == query char, or -1 if none

// Builds the suffix array of s[0..n) into sa[] by prefix doubling with a
// radix sort per round (O(n log n) overall).
//
// n: number of characters to sort, INCLUDING the trailing sentinel; s[n-1]
//    must be unique and smaller than every other character (the '\0'
//    terminator satisfies this).
// m: alphabet size; must be greater than every character code in s[0..n)
//    (characters are interpreted as unsigned char).
//
// The sentinel is what keeps the `+ k` look-ahead below in bounds: two
// suffixes can only tie on their first k characters if neither of them
// reaches the sentinel within those k characters.
void build_sa(int n, int m) {
    int i, *x = t, *y = t2;

    // Round 0: counting sort of the suffixes by their first character.
    for (i = 0; i < m; i++) c[i] = 0;
    // Go through unsigned char so bytes >= 0x80 never produce a negative index.
    for (i = 0; i < n; i++) c[x[i] = static_cast<unsigned char>(s[i])]++;
    for (i = 1; i < m; i++) c[i] += c[i - 1];
    for (i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;

    for (int k = 1; k <= n; k <<= 1) {
        // Order by the second key (rank of the suffix k positions later).
        // Suffixes too short to have a second half come first.
        int p = 0;
        for (i = n - k; i < n; i++) y[p++] = i;
        for (i = 0; i < n; i++)
            if (sa[i] >= k) y[p++] = sa[i] - k;

        // Stable counting sort by the first key, keeping the second-key order.
        for (i = 0; i < m; i++) c[i] = 0;
        for (i = 0; i < n; i++) c[x[y[i]]]++;
        // Start at 1: the prefix sum has no c[-1] to fold in.
        for (i = 1; i < m; i++) c[i] += c[i - 1];
        for (i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];

        // Recompute ranks for prefixes of length 2k; y now holds the old ranks.
        swap(x, y);
        p = 1;
        x[sa[0]] = 0;
        for (i = 1; i < n; i++)
            x[sa[i]] = (y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k]) ? p - 1 : p++;

        if (p >= n) break;  // all ranks distinct: the order is final
        m = p;              // the next round sorts rank values in [0, p)
    }
}

// Fills ran[] (rank of every suffix) and h[] (LCP between each suffix and
// its predecessor in suffix-array order).
//
// n: length of the string WITHOUT the sentinel. sa[] must hold n + 1 entries
//    produced by build_sa(n + 1, ...), so that sa[0] == n is the sentinel
//    suffix and every real suffix has rank >= 1. h[0] is never written.
void get_height(int n) {
    int k = 0;
    // Inclusive bound on purpose: sa[] also contains the sentinel suffix.
    for (int i = 0; i <= n; i++) ran[sa[i]] = i;
    for (int i = 0; i < n; i++) {
        if (k) k--;  // the LCP shrinks by at most one when moving to the next suffix
        // Predecessor of suffix i in sorted order; the index must be ran[i],
        // not ran[k].
        int j = sa[ran[i] - 1];
        // Terminates at the latest when one side reaches the unique '\0'.
        while (s[i + k] == s[j + k]) k++;
        h[ran[i]] = k;
    }
}

// Counts the distinct substrings of s[0..len) that contain ch.
// Requires sa[] and h[] to be up to date for s (build_sa + get_height).
// Overwrites pos[0..len).
static long long count_substrings_with(char ch, int len) {
    // pos[i] = index of the first occurrence of ch at or after i, else -1.
    int next = -1;
    for (int i = len - 1; i >= 0; --i) {
        if (s[i] == ch) next = i;
        pos[i] = next;
    }

    long long total = 0;
    // Rank 0 is the sentinel suffix (sa[0] == len). It contributes nothing,
    // and pos[len] is not maintained, so it must be skipped.
    for (int r = 1; r <= len; ++r) {
        const int start = sa[r];
        if (pos[start] == -1) continue;  // no ch at or after start
        // A substring s[start..e] is new iff e >= start + h[r] (longer than
        // the prefix shared with the previous suffix) and contains ch iff
        // e >= pos[start]; e ranges up to len - 1.
        total += len - max(start + h[r], pos[start]);
    }
    return total;
}

// Prints one result line. MSVCRT-based toolchains need the I64 length
// prefix for a 64-bit integer; everything else uses the standard %lld.
static void print_answer(int cas, long long ans) {
#ifdef _WIN32
    printf("Case #%d: %I64d\n", cas, ans);
#else
    printf("Case #%d: %lld\n", cas, ans);
#endif
}

// Reads the number of test cases, then for each case a query character and
// a string, and prints the number of distinct substrings of the string that
// contain the character.
int main() {
    int cases;
    if (scanf("%d", &cases) != 1) return 0;
    for (int cas = 1; cas <= cases; ++cas) {
        // Field widths mirror the buffer sizes (s0[10], s[maxn]) so that
        // oversized tokens cannot overflow them.
        if (scanf("%9s", s0) != 1 || scanf("%100004s", s) != 1) break;
        const int len = static_cast<int>(strlen(s));
        build_sa(len + 1, 256);  // + 1: the '\0' terminator acts as the sentinel
        get_height(len);
        print_answer(cas, count_substrings_with(s0[0], len));
    }
    return 0;
}