大体思路
1.求出每个元素在树中的深度
2.用st表预处理的方法处理出f[i][j],f[i][j]表示元素i上方第2^j行对应的祖先是谁
3.将较深的点向上挪,直到两结点的深度相同
4.深度相同后,祖先可能就在上方,再走几步就到了,于是两个点同时向上移
具体的方法和代码贴在下面 ↓
具体
1.求出每个元素在树中的深度
//求每个节点在树中的深度
void dfs(
int pos,
int pre)
//pre是pos的父节点
{
for(
int i=
0;i<v[pos].size;i++)
//枚举pos的子节点
{
register int t=
v[pos][i];
if(t==pre)
continue;
//防止死循环
f[t][
0]=pos;dep[t]=dep[pos]+
1;
dfs(t,pos);//再从子节点向后枚举
}
}
2.用st表预处理的方法处理出f[i][j]
//求f数组(st表预处理)
for(
int i=
1;(
1<<i)<=n;i++
)
for(
int j=
1;j<=n;j++
)
f[j][i]=f[f[j][i-
1]][i-
1];
//f[i][j]表示元素i上方第2^j行对应的祖先是谁
3.先比较a,b两点哪个较深,将较深的点赋值给a
//把a节点变为a,b中较深的一个节点
int query(
int a,
int b)
{
if(dep[a]<
dep[b])swap(a,b);
}
将较深的点向上挪,直到两结点的深度相同
//使a和b在同一个深度上
for(
int i=
20;i>=
0;i--
)
if(dep[f[a][i]]>=
dep[b])
a=
f[a][i];
//倒着循环是因为向上走的步数只会越来越小
4.深度相同后,祖先可能就在上方,再走几步就到了,于是两个点同时向上移
//同一深度后,再向上找公共祖先
for(
int i=
20;i>=
0;i--
)
if(f[a][i]!=
f[b][i])
{
a=
f[a][i];
b=
f[b][i];
}
全部代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
vector<
int> v[
41000];
vector<
int> w[
41000];
int f[
41000][
21];
//f[i][j]表示i点向上2^j层的祖先
int g[
41000][
21];
//g[i][j]表示i点到从i向上2^j层的祖先的距离
int dep[
41000];
int n,m;
void dfs(
int pos,
int pre,
int depth)
{
dep[pos]=
depth;
for(
int i=
0;i<v[pos].size();i++
)
{
int t=
v[pos][i];
if(t==pre)
continue;
f[t][0]=
pos;
g[t][0]=
w[pos][i];
dfs(t,pos,depth+
1);
}
}
int query(
int a,
int b)
{
int sum=
0;
if(dep[a]<dep[b]) swap(a,b);
//深度较深的点
for(
int i=
20;i>=
0;i--)
//找到a在深度dep[b]处的祖先
{
if(dep[f[a][i]]>=
dep[b])
{
sum+=g[a][i];
//a到该祖先的距离
a=
f[a][i];
}
}
if(a==b)
return sum;
//挪到相同深度后如果在同一点直接return
int x;
for(
int i=
20;i>=
0;i--)
//否则a和b一起往上蹦跶
{
if(f[a][i]!=
f[b][i])
{
sum+=
g[a][i];
a=
f[a][i];
sum+=
g[b][i];
b=
f[b][i];
}
}
return sum+g[a][
0]+g[b][
0];
//最后蹦跶到最近公共祖先的下一层,所以要再加上上一层
}
int main()
{
int T;
cin>>
T;
while(T--
)
{
scanf("%d%d",&n,&
m);
memset(dep,-
1,
sizeof dep);
//多组数据我们初始化
memset(f,
0,
sizeof f);
memset(g,0,
sizeof g);
for(
int i=
0;i<n;i++)
//md
v[i].clear(),w[i].clear();
for(
int i=
1;i<n;i++
)
{
int x,y,c;
cin>>x>>y>>
c;
v[x].push_back(y);
w[x].push_back(c);
v[y].push_back(x);
w[y].push_back(c);
}
int xxx=v[
1].size();
dfs(1,
0,
1);
//dfs处理出每个点的深度,以及各种...
for(
int i=
1;
1<<i<=n;i++
)
for(
int j=
1;j<=n;j++
)
f[j][i]=f[f[j][i-
1]][i-
1],
g[j][i]=g[f[j][i-
1]][i-
1]+g[j][i-
1];
for(
int i=
1;i<=m;i++
)
{
int x,y;
cin>>x>>
y;
if(x==y) cout<<
"0"<<
endl;
else cout<<query(x,y)<<
endl;
}
}
return 0;
}
转载请注明原文地址: https://ju.6miu.com/read-672072.html