【算法】树状数组维护总结
本文仅对树状数组的使用作一个总结,并非讲解。
这里的操作都对长度为 n n n 的数组 a a a 进行操作。
单点修改,区间查询
-
暴力做法:
- 修改: a [ x ] = y a[x]=y a[x]=y,时间复杂度为 O ( 1 ) O(1) O(1)
- 查询: ∑ i = l r a [ i ] \sum\limits_{i=l}^ra[i] i=l∑ra[i] ,时间复杂度为 O ( n ) O(n) O(n)
-
树状数组:
t r tr tr 数组 对 a a a 数组进行维护-
修改:
void update(int x, int y) { while (x <= n) tr[x] += y, x += (x & (-x)); } update(x, y);
时间复杂度为 O ( log n ) O(\log n) O(logn)
-
查询:
int query(int x) { int ans = 0; while (x >= 1) ans += tr[x], x -= (x & (-x)); return ans; }
时间复杂度为 O ( log n ) O(\log n) O(logn)
-
区间修改,单点查询
-
暴力做法:
- 修改: a [ l ] = a [ l ] + x , ⋯ , a [ r ] = a [ r ] + x a[l]=a[l]+x,\cdots,a[r]=a[r]+x a[l]=a[l]+x,⋯,a[r]=a[r]+x,时间复杂度为 O ( n ) O(n) O(n)
- 查询: a [ x ] a[x] a[x],时间复杂度为 O ( 1 ) O(1) O(1)
-
树状数组:
b b b 数组是 a a a 的差分数组。 t r tr tr 数组对 b b b 数组进行维护- 修改:
时间复杂度为 O ( log n ) O(\log n) O(logn)void update(int x, int y) { while (x <= n) tr[x] += y, x += (x & (-x)); } update(l, x); update(r + 1, -x);
- 查询:
时间复杂度为 O ( log n ) O(\log n) O(logn)int query(int x) { int ans = 0; while (x >= 1) ans += tr[x], x -= (x & (-x)); return ans; } query(x);
- 修改:
区间修改,区间查询
区间查询的的公式为: ∑ i = l r a [ i ] \sum\limits_{i=l}^ra[i] i=l∑ra[i],我们先考虑如何求 ∑ i = 1 p a [ i ] \sum\limits_{i=1}^p a[i] i=1∑pa[i]
问题转换为如何去求解这个公式,暴力情况下,求 a [ i ] a[i] a[i] 是 O ( log n ) O(\log n) O(logn),总时间复杂度为 O ( n log n ) O(n\log n) O(nlogn)。
但是我们可以拆分这个公式:
∑
i
=
1
p
a
[
i
]
=
∑
i
=
1
p
∑
j
=
1
i
b
[
j
]
=
p
×
b
[
1
]
+
(
p
−
1
)
×
b
[
2
]
+
⋯
+
1
×
b
[
p
]
=
(
(
p
+
1
)
×
∑
i
=
1
p
b
[
i
]
)
−
(
1
×
b
[
1
]
+
2
×
b
[
2
]
+
⋯
+
p
×
b
[
p
]
)
\begin{aligned} \sum\limits_{i=1}^p a[i] &=\sum\limits_{i=1}^p \sum\limits_{j=1}^ib[j] \\ &= p \times b[1]+(p-1)\times b[2]+\cdots+1\times b[p] \\ &= ((p+1)\times \sum\limits_{i=1}^p b[i])-(1\times b[1]+2\times b[2]+\cdots+p\times b[p])\\ \end{aligned}
i=1∑pa[i]=i=1∑pj=1∑ib[j]=p×b[1]+(p−1)×b[2]+⋯+1×b[p]=((p+1)×i=1∑pb[i])−(1×b[1]+2×b[2]+⋯+p×b[p])
所以我们再用一个额外的树状数组去维护 i × b [ i ] i\times b[i] i×b[i] 即可。
// 区间修改[x, n]
const int N =100010;
int tr1[N], tr2[N];
void update(int x, int y) {
int val2 = x * y;
while (x <= n) {
tr1[x] += y;
tr2[x] += val2;
x += (x & (-x));
}
}
update(l, d); // a[l, n] += d;
update(r + 1, -d); // a[r + 1, n] -= d
// 区间查询(1, x)
int query(int x) {
int p = x;
int val1 = 0, val2 = 0;
while (x >= 1) {
val1 += tr1[x];
val2 += tr2[x];
x -= (x & -x);
}
return (p + 1) * val1 - val2;
}
// 查询 a[l, r]
query(1, r) - query(1, l - 1);
时间复杂度为 O ( log n ) O(\log n) O(logn)