0%

树状数组与线段树

本篇讲一些其他博客没有展开讲的部分。

Q:什么是lowbit函数?

A:lowbit可以取出一个数字的最低位1,如lowbit(10102{1010}_2) = 00102{0010}_2

lowbit(5) = 1,lowbit(6) = 2

Q:为什么树状数组的划分使用lowbit?

树状数组和线段树一样,需要将一个区域[1,n]分成多段

线段树的划分是:二分,需要4倍内存(粗略4倍)

而树状数组的划分是:二进制分解

假设n=2i1+2i2++2imn = 2^{i_1}+2^{i_2}+…+2^{i_m}

那么对于[1,n][1,n]的区域,划分的区间为:

[1,2i1][1,2^{i_1}]

[2i1+1,2i1+2i2][2^{i_1}+1,2^{i_1}+2^{i_2}]

[2i1+2i2++2im1+1,2i1+2i2++2im],[2^{i_1}+2^{i_2}+…+2^{i_{m-1}}+1,2^{i_1}+2^{i_2}+…+2^{i_{m}}],

x=7=20+21+22x = 7 = 2^0+2^1+2^2

1
2
3
4
5
6
7
8
9
10
x = 7;
lowbit(x); // lowbit(x) = 1
x-=lowbit(x); // x = 6
lowbit(x); // lowbit(x) = 2
x-=lowbit(x); // x = 4
lowbit(x); //lowbit(x) = 4
x-=lowbit(x); // x = 0
[1,7]被划分为[1,4],[5,6],[7,7]

树状数组c[x]保存的则是:[x-lowbit(x)+1,x]区间中所有数的和
1
2
3
4
5
6
7
8
9
10
11
12
13
14
int ask(int x) {
int ans = 0;
while(x) {
ans+=c[x];
x-=lowbit(x);
// x = 7->[7,7] , c[7] = a[7]
// x = 6->[5,6] , c[6] = a[5]+a[6]
// x = 4->[1,4] , c[4] = a[1]+a[2]+a[3]+a[4]
// x = 0->end
//由此看出,可以保证不重合的划分
// ans=c[7]+c[6]+c[4]
}
return ans;
}

同时,由以上的介绍可以推出,对于c[x]节点,它的父亲结点应该是c[x+lowbit(x)]

那么更新时则为

1
2
3
4
5
6
7
8
void add(int x,int k) {
while(x<=n) {
c[x]+=k;
x+=lowbit(x);
// Q: 为什么是x = x+lowbit(x)?
// A: 因为它需要更新它的父节点,使得包含它的区间得以被更新
}
}

以上的add(int x,int k)和ask(int x)实现的功能是:单点更新、区间查询

那么如果要实现单点查询和区间更新呢?考虑利用差分的思想,请看代码:

PS:单点查询为query(x),而不是query(x)-query(x-1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 100;

int n;
int a[maxn];
int c[maxn];


int lowbit(int x) {
return x & (-x);
}

void add(int i,int k) { // 单点更新
while(i<=n) {
c[i] += k;
i += lowbit(i);
}
}
int query(int x) { // [1,x] 区间查询
int ans = 0;
while(x) {
ans += c[x];
x -= lowbit(x);
}
return ans;
}


int main() {
int m;
cin >> n >> m;
for (int i = 1;i<=n;i++) {
cin >> a[i];
}
for (int i = 1;i<=m;i++) {
char c;
cin >> c;
if(c == 'C') {
int l, r, d;
cin >> l >> r >> d;
// 区间更新
add(l, d);
add(r + 1, -d);
} else {
int x;cin>>x;
// 单点查询
cout << a[x] + query(x)<<endl;
}

}
}

Codeforces1324DPair of Topics
https://codeforces.com/contest/1324/problem/D

给定a,ba,b序列,求有多少对ibi+bjib_i+b_j

解法1:按照值从大到小插入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 5;
int a[maxn];
int b[maxn];
int c[maxn];
struct Node {
int v, index;
bool operator<(const Node& b) const {
if(v==b.v) {
return index <b.index;//注意这里!!
}
return v < b.v; //从小到大排序
}
} node[maxn];
int n;
void add(int i) {
while(i<=n*2)
{
c[i]++;
i+=i&(-i);
}
}
int sum(int i) {
int res=0;
while(i>0)
{
res+=c[i];
i-=i&(-i);
}
return res;
}
vector<int> v;
int main() {
scanf("%d", &n);
for(int i=1;i<=n;i++) {
scanf("%d", &a[i]);
}
for(int i=1;i<=n;i++) {
scanf("%d", &b[i]);
}
for (int i = 1; i <= n;i++) {
node[i].v = a[i] - b[i];
node[i].index = i;
node[i + n].v = b[i] - a[i];
node[i + n].index = i + n;
}
long long ans = 0;
sort(node + 1, node + 1 + 2 * n);
for (int i = 2*n; i >=1;i--) {
int id = node[i].index;
int v = node[i].v;
if(id<=n) {
add(node[i].index);
continue;
}
if(v<0)
ans--;
ans += sum(id - n);
}

cout<<ans<<endl;
}

解法2:按照序列顺序从小到大

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 100;
int a[maxn], b[maxn], c[maxn * 2];
int lowbit(int x) {
return x & (-x);
}

void add(int i,int d) {
while(i<maxn*2) {
c[i] += d;
i += lowbit(i);
}
}

int query(int i) {
int ret = 0;
while(i>0) {
ret += c[i];
i -= lowbit(i);
}
return ret;
}

vector<int> v;

int main() {
int n;
scanf("%d", &n);
for(int i=1;i<=n;i++) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= n;i++) {
scanf("%d", &b[i]);
}
for (int i = 1; i <= n;i++) {
v.push_back(a[i] - b[i]);
v.push_back(b[i] - a[i]);
}
long long ans = 0;
sort(v.begin(), v.end());
for (int i = 1;i<=n;i++) {
int x = b[i] - a[i];
int y = a[i] - b[i];
int id = upper_bound(v.begin(), v.end(), x) - v.begin() + 1;
int id2 = upper_bound(v.begin(), v.end(), y) - v.begin() + 1;
ans += (i - 1 - query(id));
add(id2, 1);
}
cout << ans << endl;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 100;
vector<int> v;
int a[maxn],b[maxn];

int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n;i++) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= n;i++) {
scanf("%d", &b[i]);
v.push_back(a[i] - b[i]);
}
sort(v.begin(), v.end());
long long ans = 0;
for (int i = 1; i <= n;i++) {
int id = upper_bound(v.begin(), v.end(), b[i] - a[i]) - v.begin();
ans += (n - id);
if(a[i]-b[i]>b[i]-a[i])
ans--;
}
ans /= 2;
cout << ans << endl;
}

线段树
模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
struct SegTree {
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
ll a[maxn << 2];
int tag[maxn << 2];
void pushup(int rt) {
a[rt] = min(a[rt << 1], a[rt << 1 | 1]);
}
void push(int rt,int x) {
a[rt] += x;
tag[rt] += x;
}
void pushdown(int rt) {
if(!tag[rt])
return;
push(rt << 1, tag[rt]);
push(rt << 1 | 1, tag[rt]);
tag[rt] = 0;
}
void build(int l = 1,int r = n,int rt = 1) {
if(l==r) {
a[rt] = 0;
return;
}
int mid = (l + r) >> 1;
build(lson), build(rson);
pushup(rt);
}
void update(int L, int R, int x, int l = 1, int r = n,int rt = 1) {
if(L<=l&&r<=R) {
push(rt, x);
return;
}
int mid = (l + r) >> 1;
pushdown(rt);
if(L<=m)
update(L, R, x, lson);
if(R>m)
update(L, R, x, rson);
pushup(rt);
}
ll query(int L, int R, int l = 1, int r = n,int rt = 1) {
if(L<=l&&r<=R) {
return a[rt];
}
int m = (l + r) >> 1;
pushdown(rt);
if(R<=m)
return query(L, R, lson);
if(L>m)
return query(L, R, rson);
return min(query(L, R, lson), query(L, R, rson));
}
} f;