POJ 2104
题意
给定1到n的排列,每次询问某一区间内的第k小值。
样例输入
7 3 1 5 2 6 3 7 4 2 5 3 4 4 1 1 7 3
样例输出
5 6 3
主席树介绍
可持久化线段树,函数式线段树。 有点抽象,能够理解但还不是很熟练,代码不长,但是非常简练,有很多技巧,目前当做黑箱。
可持久化:每次操作尽量用新节点表示而不是修改原节点,这样就能保留所有历史信息。 函数式:函数式编程里变量常常是不变的,线段树的函数式写法就是这样。
我们用区间k小值来解释。
一些预处理
离散化(排序+去重),其实本道题不需要这个操作。 下面的代码用到了STL的很多技巧,用unique()函数去重,用lower_bound()重新映射a数组。
for (i =
1;i <= n; i++) scanf(
"%d",&a[i]);
for (i =
1;i <= n; i++) b[i] = a[i];
sort(b+
1,b+n+
1);
k = unique(b+
1,b+n+
1) - (b+
1);
for (i =
1;i <= n; i++) a[i] = lower_bound(b+
1,b+k+
1,a[i])-b;
假设求整个区间的k小值
这个问题可以用AVL树做,但是这里介绍一种类似平衡树的方法。假设某个节点的区间为[l,r],则这个节点记录的是在a数组中有多少个a[i]满足l<=a[i]<=r。这样搜索第k小值时,如果左孩子数量小于k则k小值在左子树中,反之则在右子树中。复杂度log(n)。
对于任意区间[L,R]
建立n棵线段树,每棵维护[1,i]的数字出现情况。 显然这n棵线段树每个节点代表的区间都是一样的,所以这n棵线段树同构。 用第R棵线段树去“减”第(L-1)棵线段树,得出来的结果就是区间[L,R]的情况,对这棵树套用一遍上面求整个区间的方法就可以求出[L,R]中的k小值。
如何节约空间
上面的方法看起来还是比较具体的,但是会MLE(n棵线段树)。 下面的优化就是主席树的精髓:如何扔掉重复的节点。有点抽象,这段话看懂了就比较轻松了。
我们发现,第i棵线段树和第i+1棵线段树的区别在于加入了a[i+1]这个数,而a[i+1]在第i棵树上从根出发向下走,走过的节点+1就变成了第i+1棵线段树。(你可以自己画一下看看有什么不同)
也就是说相邻两棵线段树之间不同节点个数至多为log(n)个,换句话说剩下这么多的节点都是一样的! 那么重复的节点就可以扔掉了。比如说一个节点的左孩子是重复的,那么我不需要多开一个节点,而是直接连到前一棵树上。 看起来比较复杂,但是编程中有很多技巧,最后代码比普通线段树还短。 P.S. 怕以后忘记这里写的会很详细。
sol
预处理这里就不再写了。
建树
现在连建树都要重新写了TAT。 其实只要建一棵空树即可,后面的树都是连到这棵树上。 但是后面再update和query的时候有一个问题:左孩子和右孩子并不能简单的乘2和乘2加1,如何解决?
//root[i]表示第i棵树的根的位置
void build(int l,int r,int &rt)
{
rt = ++tot;
sum[rt] =
0;
if (l == r)
return;
int m = (l + r) >>
1;
build(l,m,ls[rt]);
build(m+
1,r,rs[rt]);
}
...
tot =
0;
build(
1,k,root[
0]);
用最朴素的方法:一个一个累加! 这里有一个技巧就是用了&,也就是说等到搜到这个点的时候自然会把这个点的位置给传回来。这个技巧剩下了不少代码,在后面的update和query中可以自己体会。
更新
//ls表示左孩子位置 rs表示右孩子位置
last表示前一棵树、当前节点的位置
void update(
int l,
int r,
int &rt,
int last,
int p)
{
rt = ++tot;
ls[rt] = ls[
last]; rs[rt] = rs[
last];
//暂时两个孩子都连到前一棵树的对应孩子上
sum[rt] = sum[
last] +
1;
//这一步可以解释是哪
log(n)个点的值发生了修改!
if (l == r)
return;
int m = (l + r) >>
1;
if (p <=
m) update(l,
m,ls[rt],ls[
last],p);
else update(
m+
1,r,rs[rt],rs[
last],p);
//修改的那个节点开辟出一个新节点 ls/rs会回传新的节点的位置!前面讲到过
}
...
for (i =
1;i <= n; i++) update(
1,k,root[i],root[i-
1],a[i]);
这样一来就把这“n棵线段树”都建好了。可以看出虽然节点总数为nlog(n),但是却把所有的情况都记录下来了,这就是“可持久化”。
查询
int query(
int ss,
int tt,
int l,
int r,
int k)
{
if (l == r)
return l;
int m = (l + r) >>
1;
int cnt = sum[ls[tt]] - sum[ls[ss]];
//用第tt棵线段树减去第ss棵线段树
if (k <= cnt)
return query(ls[ss],ls[tt],l,
m,k);
else return query(rs[ss],rs[tt],
m+
1,r,k-cnt);
}
...
while (
q--)
{
scanf(
"%d%d%d",&ql,&qr,&qk);
int res = query(root[ql-
1],root[qr],
1,k,qk);
printf(
"%d\n",b[res]);
}
有了前面的铺垫,查询就比较简单了。
完整代码
using namespace std;
int a[N],b[N],root[N
*20],ls[N
*20],rs[N
*20],sum[N
*20];
int n,
q,i,tot,k,ql,qr,qk;
void build(
int l,
int r,
int &rt)
{
rt = ++tot;
sum[rt] =
0;
if (l == r)
return;
int m = (l + r) >>
1;
build(l,
m,ls[rt]);
build(
m+
1,r,rs[rt]);
}
void update(
int l,
int r,
int &rt,
int last,
int p)
{
rt = ++tot;
ls[rt] = ls[
last]; rs[rt] = rs[
last];
sum[rt] = sum[
last] +
1;
if (l == r)
return;
int m = (l + r) >>
1;
if (p <=
m) update(l,
m,ls[rt],ls[
last],p);
else update(
m+
1,r,rs[rt],rs[
last],p);
}
int query(
int ss,
int tt,
int l,
int r,
int k)
{
if (l == r)
return l;
int m = (l + r) >>
1;
int cnt = sum[ls[tt]] - sum[ls[ss]];
if (k <= cnt)
return query(ls[ss],ls[tt],l,
m,k);
else return query(rs[ss],rs[tt],
m+
1,r,k-cnt);
}
int main()
{
cin>>n>>
q;
for (i =
1;i <= n; i++) scanf(
"%d",&a[i]);
for (i =
1;i <= n; i++) b[i] = a[i];
sort(b+
1,b+n+
1);
k = unique(b+
1,b+n+
1) - (b+
1);
for (i =
1;i <= n; i++) a[i] = lower_bound(b+
1,b+k+
1,a[i])-b;
tot =
0;
build(
1,k,root[
0]);
for (i =
1;i <= n; i++) update(
1,k,root[i],root[i-
1],a[i]);
while (
q--)
{
scanf(
"%d%d%d",&ql,&qr,&qk);
int res = query(root[ql-
1],root[qr],
1,k,qk);
printf(
"%d\n",b[res]);
}
return 0;
}
转载请注明原文地址: https://ju.6miu.com/read-660223.html