题意
给出一棵树,每条边上都有一个长度不超过10的字符串。给出m个询问x y ch,求x到y的路径有多少个字符串的前缀是ch。
n,m<=100000
分析
将每个字符串用一个map离散化,然后在树上建可持久化线段树即可。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<map>
#define N 100005
#define LL long long
using namespace std;
int cnt,dep[N],n,m,last[N],a[
20],sz,tot,root[N],fa[N][
20];
struct edge{
int to,next,a1,a[
15];}e[N*
2];
struct tree{
int l,r,s;}t[N*
300];
map <LL,int> hash;
void addedge(
int u,
int v,
int a1)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;e[cnt].a1=a1;
for (
int i=
0;i<a1;i++) e[cnt].a[i]=a[i];
e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;e[cnt].a1=a1;
for (
int i=
0;i<a1;i++) e[cnt].a[i]=a[i];
}
void ins(
int &d,
int p,
int l,
int r,
int x)
{
d=++sz;
t[d].l=t[p].l;t[d].r=t[p].r;t[d].s=t[p].s+
1;
if (l==r)
return;
int mid=(l+r)/
2;
if (x<=mid) ins(t[d].l,t[p].l,l,mid,x);
else ins(t[d].r,t[p].r,mid+
1,r,x);
}
int query(
int d,
int p,
int q,
int l,
int r,
int x)
{
if (l==r)
return t[d].s+t[p].s-
2*t[q].s;
int mid=(l+r)/
2;
if (x<=mid)
return query(t[d].l,t[p].l,t[q].l,l,mid,x);
else return query(t[d].r,t[p].r,t[q].r,mid+
1,r,x);
}
void dfs(
int x)
{
dep[x]=dep[fa[x][
0]]+
1;
for (
int i=
1;i<=
16;i++) fa[x][i]=fa[fa[x][i-
1]][i-
1];
for (
int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa[x][
0])
continue;
fa[e[i].to][
0]=x;
int a1=e[i].a1;
ins(root[e[i].to],root[x],
1,
1000000,e[i].a[
0]);
for (
int j=
1;j<a1;j++) ins(root[e[i].to],root[e[i].to],
1,
1000000,e[i].a[j]);
dfs(e[i].to);
}
}
int get_lca(
int x,
int y)
{
if (dep[x]<dep[y]) swap(x,y);
for (
int i=
16;i>=
0;i--)
if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if (x==y)
return x;
for (
int i=
16;i>=
0;i--)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][
0];
}
int main()
{
freopen(
"strings.in",
"r",stdin);freopen(
"strings.out",
"w",stdout);
scanf(
"%d",&n);
for (
int i=
1;i<n;i++)
{
int x,y;
char ch[
15];LL w=
0;
scanf(
"%d%d%s",&x,&y,ch);
int len=
strlen(ch);
for (
int j=
0;j<len;j++)
{
w=w*
27+ch[j]-
'a'+
1;
int p=hash[w];
if (!p) hash[w]=p=++tot;
a[j]=p;
}
addedge(x,y,len);
}
dfs(
1);
scanf(
"%d",&m);
for (
int i=
1;i<=m;i++)
{
int x,y;
char ch[
15];
scanf(
"%d%d%s",&x,&y,ch);
int lca=get_lca(x,y),len=
strlen(ch);LL w=
0;
for (
int j=
0;j<len;j++) w=w*
27+ch[j]-
'a'+
1;
int p=hash[w];
if (!p)
{
printf(
"0\n");
continue;
}
else printf(
"%d\n",query(root[x],root[y],root[lca],
1,
1000000,p));
}
return 0;
}
转载请注明原文地址: https://ju.6miu.com/read-12703.html