模板分享:线段树(2)
Code
先放代码:
#include<iostream>
#include<vector>
using namespace std;
/// Segment tree with lazy propagation over a 0-indexed array a[0..n-1].
///
/// Info: per-range summary. Must be default-constructible and copyable,
///       support `Info + Info` (merge of two adjacent ranges, left + right)
///       and `info.apply(tag)` (update the summary with a range tag).
/// Tag:  pending range update. Must be default-constructible, and Tag()
///       must be the identity tag (pushdown applies the current tag to the
///       children and then resets it to Tag()); must support `tag.apply(t)`
///       (compose the later tag t on top of this one).
///
/// Nodes are laid out implicitly: the root is 0 and the children of u are
/// 2u+1 and 2u+2, so 4n slots are enough.
template<class Info, class Tag>
struct lazy_segment{
private:
    struct Node{
        int l = 0, r = 0;   // index range [l, r] covered by this node
        Info info;          // summary of [l, r]; already includes `tag`
        Tag tag;            // update not yet propagated to the children
    };
    std::vector<Node> tr;

    // Children of node u (replaces the former `ls` / `rs` macros).
    static int lchild(int u){ return u * 2 + 1; }
    static int rchild(int u){ return u * 2 + 2; }
    // Split point of node u's range: left child gets [l, mid].
    int middle(int u) const { return (tr[u].l + tr[u].r) >> 1; }

public:
    using info_type = Info;
    using tag_type = Tag;

    lazy_segment() {}
    /// Build a tree holding n copies of v.
    lazy_segment(int n, Info v = Info()){
        std::vector<Info> a(n, v);
        init(a);
    }
    /// Build a tree from a; each element must be assignable to Info.
    template<class T>
    lazy_segment(const std::vector<T> &a){
        init(a);
    }
    /// (Re)build the tree from a, discarding all previous contents.
    /// An empty a yields an empty tree on which no query is valid.
    template<class T>
    void init(const std::vector<T> &a){
        // assign() rather than resize(): resize() keeps the old nodes,
        // including stale lazy tags that build() never overwrites.
        tr.assign(a.size() * 4, Node());
        if(a.empty()) return;   // build(0, 0, -1, a) would index an empty tr
        build(0, 0, static_cast<int>(a.size()) - 1, a);
    }

private:
    // Recompute u's summary from its two children.
    void pushup(int u){
        tr[u].info = tr[lchild(u)].info + tr[rchild(u)].info;
    }
    // Apply tag v to the whole range of node u: update the summary now and
    // record v so it can later be pushed down to the children.
    void apply(int u, const Tag &v){
        tr[u].info.apply(v);
        tr[u].tag.apply(v);
    }
    // Propagate u's pending tag to both children and reset it to identity.
    void pushdown(int u){
        apply(lchild(u), tr[u].tag);
        apply(rchild(u), tr[u].tag);
        tr[u].tag = Tag();
    }
    template<class T>
    void build(int u, int l, int r, const std::vector<T> &a){
        tr[u].l = l;
        tr[u].r = r;
        if(l == r){
            tr[u].info = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(lchild(u), l, mid, a);
        build(rchild(u), mid + 1, r, a);
        pushup(u);
    }
    void modify(int u, int x, const Info &v){
        if(tr[u].l == tr[u].r){
            tr[u].info = v;
            return;
        }
        pushdown(u);
        if(x <= middle(u)) modify(lchild(u), x, v);
        else modify(rchild(u), x, v);
        pushup(u);
    }
    void add(int u, int x, const Info &v){
        if(tr[u].l == tr[u].r){
            tr[u].info = tr[u].info + v;
            return;
        }
        // Must push pending tags first: otherwise pushup() below rebuilds
        // u's summary without them, and the tag would later be applied on
        // top of this (newer) point update in the wrong order.
        pushdown(u);
        if(x <= middle(u)) add(lchild(u), x, v);
        else add(rchild(u), x, v);
        pushup(u);
    }
    Info query(int u, int l, int r){
        if(l <= tr[u].l && r >= tr[u].r) return tr[u].info;
        pushdown(u);
        int mid = middle(u);
        if(r <= mid) return query(lchild(u), l, r);
        else if(l > mid) return query(rchild(u), l, r);
        return query(lchild(u), l, mid) + query(rchild(u), mid + 1, r);
    }
    Info get(int u, int x){
        if(tr[u].l == tr[u].r) return tr[u].info;
        pushdown(u);
        if(x <= middle(u)) return get(lchild(u), x);
        else return get(rchild(u), x);
    }
    void apply(int u, int l, int r, const Tag &v){
        if(l <= tr[u].l && r >= tr[u].r){
            apply(u, v);
            return;
        }
        pushdown(u);
        int mid = middle(u);
        if(l <= mid) apply(lchild(u), l, r, v);
        if(r > mid) apply(rchild(u), l, r, v);
        pushup(u);
    }

public:
    /// a[x] = v. Requires 0 <= x < n.
    void modify(int x, const Info &v){
        modify(0, x, v);
    }
    /// a[x] = a[x] + v, where + is Info's merge operator. Requires 0 <= x < n.
    void add(int x, const Info &v){
        add(0, x, v);
    }
    /// Apply tag v to every a[i] with i in [l, r]. Requires 0 <= l <= r < n.
    void apply(int l, int r, const Tag &v){
        apply(0, l, r, v);
    }
    /// Merged summary a[l] + ... + a[r]. Requires 0 <= l <= r < n.
    Info query(int l, int r){
        return query(0, l, r);
    }
    /// Current value a[x]. Requires 0 <= x < n.
    Info get(int x){
        return get(0, x);
    }
};
Info
$\texttt{Info}$ 需可默认构造,并支持 $+$(合并相邻两段,左 $+$ 右)和 $\operatorname{apply}(t)$(用标记 $t$ 更新数据)操作.
Tag
$\texttt{Tag}$ 需要支持 $\operatorname{apply}(t)$(将自己叠加上标记 $t$)操作;默认构造的 $\texttt{Tag}()$ 必须是单位标记(不改变任何数据),因为下传标记后节点上的标记会被重置为 $\texttt{Tag}()$.
大部分函数同线段树(1).
Apply
void lazy_segment<Info,Tag>::apply(int l, int r, const Tag &v)
对每个 $i \in [l,r]$ 执行 $a_i \gets \operatorname{apply}(a_i, v)$,即在该区间上打上标记 $v$.
要求 $0 \le l \le r < n$.
Example
给定序列 $a=(a_1,a_2,\cdots,a_n)$,有 $m$ 个操作,分三种:
- $\operatorname{assign}(l,r,v)$:对每个 $i \in [l,r]$ 执行 $a_i \gets v$.
- $\operatorname{add}(l,r,v)$:对每个 $i \in [l,r]$ 执行 $a_i \gets a_i+v$.
- $\operatorname{query}(l,r)$:求 $\max\limits_{i=l}^r a_i$.
下面的程序中,操作编号 1、2、3 依次对应 assign、add、query,输入的 $l,r$ 从 1 开始编号.
// Lazy tag for "assign, then add". When `used` is set, the range is
// overwritten with tag1 and then increased by tag2; otherwise it is only
// increased by tag2. Tag() = (0, 0, false) changes nothing.
struct Tag {
    int tag1, tag2;
    bool used;
    Tag(int _tag1 = 0, int _tag2 = 0, bool _used = false)
        : tag1(_tag1), tag2(_tag2), used(_used) {}
    // Compose the later tag t on top of this one.
    void apply(Tag t) {
        if (!t.used) {
            // A pure add stacks onto whatever has been accumulated.
            tag2 += t.tag2;
            return;
        }
        // A later assignment discards everything accumulated so far.
        *this = Tag(t.tag1, t.tag2, true);
    }
};
struct Info {
int mx;
Info(int _mx = 0): mx(_mx) {}
void apply(Tag t, int len) {
if (t.used) mx = t.tag1 + t.tag2;
else mx += t.tag2;
}
};
// Merge two adjacent ranges: the maximum of the union is the larger maximum.
Info operator+(const Info& lhs, const Info& rhs) {
    const int best = (lhs.mx < rhs.mx) ? rhs.mx : lhs.mx;
    return Info(best);
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
int n, m;
cin >> n >> m;
vector<Info> a(n);
for (auto &i: a) cin >> i.mx;
lazy_segment<Info, Tag> seg(a);
for (int i = 0, op, l, r, x; i < m; i++) {
cin >> op >> l >> r;
l--, r--;
if (op == 1) {
cin >> x;
seg.apply(l, r, Tag(x, 0, true));
}
if (op == 2) {
cin >> x;
seg.apply(l, r, Tag(0, x, false));
}
if (op == 3) cout << seg.query(l, r).mx << endl;
}
return 0;
}