Fork me on GitHub

树状数组简介

树状数组的基本用法

现在有一个数组,我们要对其元素进行$m$次修改和$n$次区间查询。

如果我们用常规数组,那么修改很方便,是$O(m)$,但是区间查询的时候是$O(nq)$,其中$q$是查询的长度;

如果我们用前缀和,那么查询很方便,是$O(n)$,但是修改很麻烦。

所以,对于这种问题,我们用树状数组,其修改和查询的复杂度均为$O(logN)$。

lowbit函数

lowbit函数就是要求给定的一个十进制数的二进制中最后的一个1所对应的十进制,比如lowbit(110) = 2,lowbit(101) = 1,lowbit(1100) = 4。

我们知道:求一个负数的补码可以先找到其对应的正数的补码的最后一个1,然后将这个1之前的数均取反,这样就得到了负数的补码。

所以我们将负数的补码和正数的补码做与运算,这样就只保留了最后一位的1,其他均是0了。

1
2
3
int lowbit(int x){
return x & (-x);
}

树状数组的思想

我们先看图:

1

图中绿色的块与下标一一对应,我们设数组$v$,其中数组$v$如图所示,即表示的是部分区间的和,可以发现,这个区间的长度是$lowbit(x)$。

要这个部分区间的和干什么呢?

我们可以用它来代替前缀和。比如要求前缀和$sum(6)$,那么我们可以这样求:$sum(6) = v[6] + v[4]$,其中$4 = 6 - lowbit(6)$;在比如$sum(11) = v[11] + v[10] + v[8]$。这样,我们就得到了查询前缀和的代码

1
2
3
4
5
6
7
8
int getSum(int x , const vector<int> &v){
int res = 0;
while(x > 0){ // 树状数组的下标从1开始
res += v[x];
x -= lowbit(x);
}
return res;
}

可以看出,这样求前缀和是很快的。

除了代替前缀和,用$v$还有什么好处吗?

从图中可以看出,当数组中一个元素变化时,$v$所需要的修改的地方比前缀和要少。比如修改了第一个位置,那么我们需要变的地方是$v[2]、v[4]、v[8]…..$,即我们只需要修改其祖先的点的值就可以了,我们还可以发现:$2 = 1 + lowbit(1), 4 = 2 + lowbit(2), 8 = 4 + lowbit(4)$,这样,我们就可以得到更新的代码

1
2
3
4
5
6
7
void update(int x , int val , vector<int> & v){
int n = v.size();
while(x < n){
v[x] += val;
x += lowbit(x);
}
}

可以看出,更新的复杂度小了。

完整代码

我们先输入n个数,然后进行m次查询,输入是’F’,则是查询,输入是’U’,则是修改。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include<bits/stdc++.h>
using namespace std;


int lowbit(int x){
return x & (-x);
}

int getSum(int x , const vector<int> &v){
int res = 0;
while(x > 0){
res += v[x];
x -= lowbit(x);
}
return res;
}

void update(int x , int val , vector<int> &v){
int n = v.size();
while(x < n){
v[x] += val;
x += lowbit(x);
}
}

int main()
{
freopen("input.txt","r",stdin);
//freopen("output.txt","w",stdout);
int n;
cin>>n;
vector<int> v(n+1 , 0);
for(int i = 1 ; i<=n ; ++i){
int a;
cin>>a;
update(i , a , v);
//for(auto i: v) cout<<i<<' ';
//cout<<endl;
}
for(auto i: v) cout<<i<<' ';
cout<<endl;
int m;
cin>>m;
for(int i = 0 ; i<m ; ++i){
char c;
cin>>c;
if(c == 'F'){
int a;
cin>>a;
cout<<getSum(a , v)<<endl;
}
else {
int a , b;
cin>>a>>b;
update(a , b - v[a] , v);
for(auto i: v) cout<<i<<' ';
cout<<endl;
}
}

}