题目大意
有 n 个宇宙,每个宇宙都连出去一条单向边。边的长度都是 1。
现在要新加一些单向边(长度为 1),使得从 1 号节点出发到任意节点的最短路长度不超过 k。求最少加多少边。
n<=5e5, k<=2e4
题解
它是个环套树(森林),那么树上的叶子节点肯定要跟 1 号节点连边。然后通过叶子往上推,树的部分就做完了。
然后单独看每个环,问题变成:有一个环,其中有些点已经染过色了,你要用长度为 k 的链去染这个环,求最少的链数。
把环破成链,倍长一下。我们枚举一个起点,从这个起点开始往后跳(即预处理一个 next 数组,表示第 i 个点往后走 k 步之后第一个没被染色的点是谁),跳到距离大于环长为止,跳的次数就是这个点的答案。
假设有 i< j 且 next[i]=next[j],那么 j 跳到 next[j] 的时候,后面的事情跟 i 是一样的(当然 j 可能还要继续往后跳)。所以我们用并查集把跳过的缩起来,这样就可以保证每个点只会被跳一次了。
代码
using namespace std;
const
int maxn=
5e5+
5;
int n,K,ans;
int tot,go[maxn],
next[maxn],f1[maxn],fa[maxn];
void ins(
int x,
int y)
{
go[++tot]=
y;
next[tot]=f1[
x];
f1[
x]=tot;
}
int d[maxn],com[maxn];
bool roll[maxn];
void topo()
{
int j=
0;
fo(i,
1,n)
if (!com[i]) d[++j]=i;
for(
int i=
1; i<=j; i++)
{
if (--com[fa[d[i]]]==
0) d[++j]=fa[d[i]];
}
fo(i,
1,n)
if (com[i]) roll[i]=
1;
}
int tim[maxn];
bool bz[maxn];
void dfs(
int k)
{
bz[k]=
1;
tim[k]=K+
500;
for(
int p=f1[k]; p; p=
next[p])
{
dfs(go[p]);
tim[k]=min(tim[k],tim[go[p]]+
1);
}
if (k==
1) tim[k]=
0;
else if (tim[k]>K && !roll[k])
{
tim[k]=
1;
ans++;
}
}
int ga[
2*maxn],stp[
2*maxn];
int get(
int x)
{
if (ga[
x]==
x)
return x;
int t=ga[
x];
ga[
x]=get(ga[
x]);
stp[
x]+=stp[t];
return ga[
x];
}
int c
0,c[
2*maxn],f[
2*maxn];
// f 就是上面说的
next
void calc(
int x)
{
c[c
0=
1]=
x;
for(
int i=fa[
x]; i!=
x; i=fa[i]) c[++c
0]=i;
fo(i,
1,c
0) c[c
0+i]=c[i];
fo(i,
1,c
0) dfs(c[i]);
fo(i,
1,
2*c0) tim[c[i]]=min(tim[c[i]],tim[c[i-
1]]+
1);
fo(i,
1,
2*c0) ga[i]=i, stp[i]=
0;
fd(i,
2*c0,
2*c0-K+
1) f[i]=
2*c0+
1;
int last=
2*c0+
1;
fd(i,
2*c0,K+
1)
{
if (tim[c[i]]>K)
last=i;
f[i-K]=
last;
}
int nmin=n+
500;
fo(i,
1,c
0)
if (tim[c[i]]>K)
{
int ans1=
0,
last=
0;
for(
int j=i; j && j-i+
1<=c
0; j=f[j])
{
int t2=get(j);
ans1+=
1+stp[j];
if (
last)
{
int t1=get(
last);
ga[t1]=t2;
stp[t1]+=stp[j]+
1;
}
j=t2;
last=j;
}
nmin=min(nmin,ans1);
}
ans+=(nmin==n+
500) ?
0 :nmin ;
}
int main()
{
scanf(
"%d %d",&n,&K);
fo(i,
1,n)
{
int x,
y;
scanf(
"%d %d",&
x,&
y);
fa[
x]=
y;
com[
y]++;
}
topo();
fo(i,
1,n)
if (!roll[i]) ins(fa[i],i);
fo(i,
1,n)
if (roll[i] && !bz[i]) calc(i);
printf(
"%d\n",ans);
}
转载请注明原文地址: https://ju.6miu.com/read-673428.html