【SinGuLaRiTy-1006】 ZKW-RegTree ZKW线段树

    xiaoxiao2021-03-25  38

     Code By WenJian

    {作者还在理解开区间的问题,一旦有新的想法将会及时更新}

    【关于ZKW线段树】

    Zkw线段树是清华大学张昆玮发明非递归线段树的写法。实践证明,这种线段树常数更小,速度更快,写起来也并不复杂。

    【建树】

    ZKW线段树本质上就是依赖于满二叉树中父节点与子节点的编号关系。

    如上图中的一个简单的满二叉树,我们可以发现如下规律:

    1>父子节点编号关系: 假设父节点的编号为 n ,那么,它的两个子节点的编号就分别为 n*2(n<<1)、n*2+1(n<<1|1);

    2>二叉树层数与底层叶子节点数的关系:假设这个二叉树的层数为 m ,那么,这个二叉树的底层叶子节点数(由于是满二叉树,这也就是所有的叶子节点了)就是2^(m-1),同时,我们还可以知道,所有叶子节点中编号最小的,即在这个满二叉树左下角的叶子节点的编号也为 2^(m-1);

    通过以上的两大关系,我们在存储一个数组的初始数据时,就可以直接将初始数据存储在满二叉树的底层。假设数组中有 x 个元素,那么这 x 个元素在这个满二叉树中的编号就是2^(m-1)~2^(m-1)+x-1,访问起来就很方便了。

    于是就有了建树代码:(其中n代表的是初始数组中的元素个数,M代表的是最底层的叶子节点个数)

    inline void Build() { for(M=1;M<n;M<<=1) ;//由于要构建一个满二叉树,所以我们不能直接让二叉树的叶子节点数等于元素个数,M可能会大于n;本层循环使底层叶子节点数在满足“满二叉树的前提下最小” for(int i=M;i<n+M;i++)//由于M也同样是本层最左侧叶子节点的编号,所以直接从这里开始赋值 Tree[i]=Read(); } (许多其他的博客总是会在这里自问自答“建完了吗?没有。”,对于这种有点SB的行为,我表示无法理解)

    不过确实,到目前为止,建树还未完成,我们还需要从下往上更新其它节点的值。当然,知道了父子节点编号的关系,这个操作就非常好用了。、

    inline void upgrade() { for(int i=M-1;i;i--) { Tree[i]=Tree[i<<1|1]+Tree[i<<1];//维护为区间和 } } 当然,你也可以将其维护为最大值,最小值之类的,代码都大同小异:

    最大值:

    Tree[i]=max(Tree[i<<1|1],Tree[i<<1]);
    最小值:
    Tree[i]=min(Tree[i<<1|1],Tree[i<<1]);

    到目前为止,我们才算是完成了ZKW线段树的建树工作。

    <ZKW线段树中的差分思想>

    在建ZKW线段树的过程中,可以用的Tree[i]表示父子节点的差值,也同样可以达到存储数据的目的。

    void Build(int n) { for(M=1;M<=n+1;M<<=1); for(int i=M;i<M+n;i++) Tree[i]=in(); for(int i=M-1;i;--i) Tree[i]=min(Tree[i<<1],Tree[i<<1|1]),Tree[i<<1]-=Tree[i],Tree[i<<1|1]-=Tree[i]; } 觉得稍微复杂了一些?但这样的存储方式可以为RMQ问题做准备。

    <关于空间>

    我们都知道,在建线段树时,需要开的数组(或是结构体)的大小是 4n ;在这里 , 我们来计算一下ZKW线段树的所需要的空间。(设初始数据中元素个数为 n )

    最好的情况: 当 n=2^k 时,由于此时刚好可以把最底层排满,则数组大小大概为 2n ;

    最坏的情况: 当 n=2^k+1时,即底层刚好多出一个,仍需要把底层排满时,则数组大小大概为 4n-1 ;

    因此,即使是最坏的情况,ZKW线段树也比普通线段树的空间表现要好。

    【单点查询】

    假设数组中有 x 个元素,二叉树层数为 m ,那么这 x 个元素在这个满二叉树中的编号就是2^(m-1)~2^(m-1)+x-1,访问起来很方便。

    <单点查询-差分版>

    其实差分版单点查询写起来也不是很复杂,也比较利于理解。

    void Sum(int x,int res=0) { while(x) res+=Tree[x],x>>=1; return res; }.

    【区间查询】

    <区间求和>

    先看一下代码:

    int Sum(int s,int t,int Ans=0) { s+=M-1,t+=M-1; Ans+=d[s]+d[t]; if(s==t) return Ans-=d[s]; for(;s^t^1;s>>=1,t>>=1)//s^t^1 : 如果s与t在同一个父亲节点以下,就说明我们已经累加到这棵树的根部了。当s与t在同一个父亲节点下时,t-s=1,那么s^t=1,s^t^1=0,此时就退出循环。 { if(~s&1)//这里在下面解释 Ans+=d[s^1]; //d[s^1]是d[s]的兄弟 if(t&1)//这里在下面解释 Ans+=d[t^1];//d[t^1]是d[t]的兄弟 } return Ans; }
    <关于代码中的 ~s&1 与 t&1>

    首先我们可以将这两个位运算式转化为好理解一点的式子:

    if(~s&1) ->  if(s%2==0)

    if(t&1) -> if(t%2!=0) 也就是说,这里是在判断奇偶,结合满二叉树的编号规律我们很容易发现:若编号为奇,则为右儿子;若编号为偶,则为左儿子。那么,这里判断左/右儿子有什么用呢?

    我们看上面的这幅图。如果我们知道要查询的区间的两个端点为编号4、7的节点,由于这是满二叉树,因此我们可以在图中寻找位于4号节点右边且位于7号节点左边的节点,这些节点一定位于我们要查询的区间之中。而我们又知道,在两个兄弟节点A,B之中,若A为左儿子,那么B一定在A的右边;若A为右儿子,那么B一定在A的左边。也就是说,如果我们知道区间的两个端点是左儿子还是右儿子,我们就可以知道它们的兄弟节点是否在区间的覆盖范围之内。又由于在ZKW线段树中,我们已经将父节点维护成为包含其子节点信息的节点,因此不用担心有漏算的情况。(要注意是开区间还是闭区间)

    我们不妨画个图来验证一下:

    (注:图中的红点为累加过的点,橙色为访问过的点)

    图中的累加节点覆盖了所有的查询范围。

    <区间查询最大值>

    和 区间求和 的代码思路差不多,直接上代码:

    void Sum(int s,int t,int L=0,int R=0) { for(s=s+M-1,t=t+M-1;s^t^1;s>>=1,t>>=1) { L+=d[s],R+=d[t]; if(~s&1) L=max(L,d[s^1]); if(t&1) R=max(R,d[t^1]); } int res=max(L,R); while(s) res+=d[s>>=1]; }
    <区间查询最小值>
    void Add(int s,int t,int v,int A=0) { for(s=s+M-1,t=t+M-1;s^t^1;s>>=1,t>>=1) { if(~s&1) d[s^1]+=v; if(t&1) d[t^1]+=v; A=min(d[s],d[s^1]);d[s]-=A,d[s^1]-=A,d[s>>1]+=A; A=min(d[t],d[t^1]);d[t]-=A,d[t^1]-=A,d[t>>1]+=A; } while(s) A=min(d[s],d[s^1]),d[s]-=A,d[s^1]-=A,d[s>>=1]+=A; }

    【单点更新】

    void Change(int x,int v) { d[x=M+x-1]+=v; while(x) d[x>>=1]=d[x<<1]+d[x<<1|1]; }

    【区间更新】

    举个模板题例子。结合题目来看看代码吧。

    区间修改的RMQ问题
    题目描述
    给出N(1 ≤ N ≤ 50,000)个数的序列A,下标从1到N,每个元素值均不超过1000。有两种操作: (1)  Q i j:询问区间[i, j]之间的最大值与最小值的差值 (2) C i j k:将区间[i, j]中的每一个元素增加k,k是一个整数,k的绝对值不超过1000。 一共有M (1 ≤ M ≤ 200,000) 次操作,对每个Q操作,输出一行,回答提问。

    输入
    第1行:2个整数N, M 第2行:N个元素 接下来M行,每行一个操作
    输出

    对每个Q操作,在一行上输出答案

    样例输入
    5 4 1 2 3 4 5 Q 2 4 C 1 1 1 C 1 3 2 Q 1 5
    样例输出
    2 1 #include<cstdio> #include<algorithm> using namespace std; #define lson pos << 1 #define rson pos << 1 | 1 #define fa(x) (x >> 1) const int MAXN = 50000; int d1[MAXN << 2], d2[MAXN << 2], M = 1, n, m; // d1 -> max // d2 -> min inline void Read(int &Ret){ char ch; int flg = 1; while(ch = getchar(), ch < '0' || ch > '9') if(ch == '-') flg = -1; Ret = ch - '0'; while(ch = getchar(), ch >= '0' && ch <= '9') Ret = Ret * 10 + ch - '0'; Ret *= flg; } void build(int n){ while(M < n) M <<= 1; int pos = M --; while(pos <= M + n){ Read(d1[pos]); d2[pos] = d1[pos]; pos ++; } pos = M; while(pos){ d1[pos] = max(d1[lson], d1[rson]); d2[pos] = min(d2[lson], d2[rson]); d1[lson] -= d1[pos]; d1[rson] -= d1[pos]; d2[lson] -= d2[pos]; d2[rson] -= d2[pos]; pos --; } } inline void Insert(int L, int R, int v){//区间更新 L += M; R += M; int A; if(L == R){ d1[L] += v; d2[L] += v; while(L > 1){ A = max(d1[L], d1[L ^ 1]); d1[L] -= A; d1[L ^ 1] -= A; d1[fa(L)] += A; A = min(d2[L], d2[L ^ 1]); d2[L] -= A; d2[L ^ 1] -= A; d2[fa(L)] += A; L >>= 1; } return; } d1[L] += v; d1[R] += v; d2[L] += v; d2[R] += v; while(L ^ R ^ 1){ if(~L & 1) d1[L ^ 1] += v, d2[L ^ 1] += v; if(R & 1) d1[R ^ 1] += v, d2[R ^ 1] += v; A = max(d1[L], d1[L ^ 1]); d1[L] -= A; d1[L ^ 1] -= A; d1[fa(L)] += A; A = max(d1[R], d1[R ^ 1]); d1[R] -= A; d1[R ^ 1] -= A; d1[fa(R)] += A; A = min(d2[L], d2[L ^ 1]); d2[L] -= A; d2[L ^ 1] -= A; d2[fa(L)] += A; A = min(d2[R], d2[R ^ 1]); d2[R] -= A; d2[R ^ 1] -= A; d2[fa(R)] += A; L >>= 1; R >>= 1; } while(L > 1){ A = max(d1[L], d1[L ^ 1]); d1[L] -= A; d1[L ^ 1] -= A; d1[fa(L)] += A; A = min(d2[L], d2[L ^ 1]); d2[L] -= A; d2[L ^ 1] -= A; d2[fa(L)] += A; L >>= 1; } } inline int getans(int L, int R){ L += M; R += M; int ans1 = 0, ans2 = 0; if(L == R){ while(L){ ans1 += d1[L]; ans2 += d2[L]; L >>= 1; } return ans1 - ans2; } int l1 = 0, r1 = 0, l2 = 0, r2 = 0; while(L ^ R ^ 1){ l1 += d1[L]; r1 += d1[R]; l2 += d2[L]; r2 += d2[R]; if(~L & 1) l1 = max(l1, d1[L ^ 1]), l2 = min(l2, d2[L ^ 1]); if(R & 1) r1 = max(r1, d1[R ^ 1]), r2 = min(r2, d2[R ^ 1]); L >>= 1; R >>= 1; } l1 += d1[L]; r1 += d1[R]; l2 += d2[L]; r2 += d2[R]; ans1 = max(l1, r1); ans2 = min(l2, r2); while(L > 1){ L >>= 1; ans1 += d1[L]; ans2 += d2[L]; } //printf("max=%d min=%d\n",ans1, ans2); return ans1 - ans2; } int main(){ int a, b, c; char id[3]; Read(n); Read(m); build(n); while(m --){ scanf("%s",id); Read(a); Read(b); switch(id[0]){ case 'C': Read(c), Insert(a, b, c); break; default: printf("%d\n",getans(a, b)); } } return 0; } By SinGuLaRiTy.

    转载请注明原文地址: https://ju.6miu.com/read-26125.html

    最新回复(0)