0%

线段树

线段树

建树与维护

对于每个父亲节点i,它的两个儿子节点为2i2i+1,故写函数lsrs来获取当前节点的儿子节点:

1
2
inline int ls(int p){ return p*2; }
inline int rs(int p){ return p*2+1; }

inline 可以有效防止无需入栈的信息入栈,节省时间和空间。

维护线段树采用push_up函数,功能是维护父子节点之间的逻辑关系,实际上是在合并两个子节点的信息:

1
2
3
4
5
// 向上维护区间操作
inline void push_up(int p)
{
ans[p]=ans[ls(p)]+ans[rs(p)];
}

再建树的递归中,需要先去整合子节点的信息,再向它们的祖先回溯整合之后的信息。建树的build函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
void build(int p, int l, int r)
{
tag[p]=0;
if(l==r){ // 如果区间的左右相同,则必为叶子节点,
ans[p]=a[l];
return;
}
int mid=(l+r)/2;
build(ls(p), l, mid);
build(rs(p), mid+1, r);
push_up(p); // 由于是通过子节点维护父节点,push_up需要再最后进行
}

区间修改

在区间修改中,引入懒标记,在下方代码在中记为tag,作用是记录每次、每个节点要更新的值,从而降低区间更新的时间复杂度。

下方的f函数用于对线段树的某个节点所代表的区间进行更新操作,p为节点编号,tag[p]

1
2
3
4
5
inline void f(int p, int l, int r, int k)
{
tag[p]=tag[p]+k;
ans[p]=ans[p]+k*(r-l+1);
}

向下传递信息时使用push_down来维护线段树:

1
2
3
4
5
6
7
8
inline void push_down(int p, int l, int r)
{
int mid=(l+r)/2;
// 每次push_down都需要更新两个子节点的数值
f(ls(p), l, mid, tag[p]);
f(rs(p), mid+1, r, tag[p]);
tag[p]=0;
}

下方是update函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// nl、nr为需要修改的区间;l,r,p为当前节点所存储的区间以及节点的编号
inline void update(int nl, int nr, int l, int r, int p, int k)
{
if(nl<=l&&r<=nr)
{
ans[p]+=k*(r-l+1);
tag[p]+=k;
return;
}
push_down(p, l, r);
int mid=(l+r)/2;
if(nl<=mid) update(nl, nr, l, mid, ls(p), k);
if(nr>mid) update(nl, nr, mid+1, r,rs(p), k);
push_up(p);
}

区间查询

结构与区间更新类似,利用到了分块的思想

1
2
3
4
5
6
7
8
9
10
int query(int q_x, int q_y, int l, int r, int p)
{
int res=0;
if(q_x<=l && r<=q_y) return ans[p];
int mid=(l+r)/2;
push_down(p,l,r);
if(q_x<=mid) res+=query(q_x, q_y, l, mid, ls(p));
if(q_y>mid) res+=query(q_x, q_y, mid+1, r, rs(p));
return res;
}

完整代码示例

模板:洛谷 P3372【模板】线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<bits/stdc++.h>
#define int long long
#define MAXN 1000001
using namespace std;

int n,m;
int ans[MAXN*4], a[MAXN], tag[MAXN*4];
inline int ls(int p){ return p*2; }
inline int rs(int p){ return p*2+1; }

inline void push_up(int p)
{
ans[p]=ans[ls(p)]+ans[rs(p)];
}

void build(int p, int l, int r)
{
tag[p]=0;
if(l==r){
ans[p]=a[l];
return;
}
int mid=(l+r)/2;
build(ls(p), l, mid);
build(rs(p), mid+1, r);
push_up(p);
}

inline void f(int p, int l, int r, int k)
{
tag[p]=tag[p]+k;
ans[p]=ans[p]+k*(r-l+1);
}

inline void push_down(int p, int l, int r)
{
int mid=(l+r)/2;
f(ls(p), l, mid, tag[p]);
f(rs(p), mid+1, r, tag[p]);
tag[p]=0;
}

inline void update(int nl, int nr, int l, int r, int p, int k)
{
if(nl<=l&&r<=nr)
{
ans[p]+=k*(r-l+1);
tag[p]+=k;
return;
}
push_down(p, l, r);
int mid=(l+r)/2;
if(nl<=mid) update(nl, nr, l, mid, ls(p), k);
if(nr>mid) update(nl, nr, mid+1, r,rs(p), k);
push_up(p);
}

int query(int q_x, int q_y, int l, int r, int p)
{
int res=0;
if(q_x<=l && r<=q_y) return ans[p];
int mid=(l+r)/2;
push_down(p,l,r);
if(q_x<=mid) res+=query(q_x, q_y, l, mid, ls(p));
if(q_y>mid) res+=query(q_x, q_y, mid+1, r, rs(p));
return res;
}

signed main()
{
int ttt,b,c,d,e,f;
cin>>n>>m;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
build(1,1,n);
while(m--)
{
cin>>ttt;
if(ttt==1)
{
cin>>b>>c>>d;
update(b,c,1,n,1,d);
}
else
{
cin>>e>>f;
cout<<query(e,f,1,n,1)<<endl;
}
}
return 0;
}