并集运算的线段树维护方式
题目
给定 $n$ 个区间 $[l,r]$,满足 $1 \le l \le r \le n$,求它们的并集的长度。
线段树做法
设计一个线段树维护以下操作:
- 区间修改。
- 单点查询。
其中线段树的叶子 $i$ 代表左闭右开区间 $[l,r)$(离散化后即 $[val_i, val_{i+1})$);没有被任何左闭右开小段覆盖的单个端点,则另外使用 set 或 map 维护。
如果值域过大,使用离散化。
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#include <queue>
#include <ctime>
#include <random>
#include <set>
#include <map>
#include <bitset>
// Competitive-programming shortcut: from here on every `int` token expands to `long long`
// (which is why the entry point further down is spelled `signed main`).
#define int long long
using namespace std;
// Generic "infinity" sentinel; not referenced by the code visible in this file.
const int INF = 0x3f3f3f3f;
/**
 * Reads the next decimal integer from stdin via getchar().
 *
 * Every non-digit character before the number is skipped; if a '-' is seen
 * while skipping, the result is negated. Digits are then consumed until the
 * first non-digit character (that character is consumed as well).
 *
 * @return the parsed value, or 0 if end of input is reached before any digit.
 */
inline int read()
{
    int magnitude = 0, sign = 1;
    // Keep getchar()'s result in an integer type: stored in a `char`, EOF is
    // indistinguishable from data and the skip loop below would never end.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    // EOF is negative, so it also terminates the digit loop.
    while (ch >= '0' && ch <= '9')
    {
        magnitude = magnitude * 10 + (ch - '0');
        ch = getchar();
    }
    return magnitude * sign;
}
/**
 * Prints x in decimal to stdout with putchar(), with a leading '-' for
 * negative values and no trailing separator.
 *
 * @param x value to print.
 */
inline void write(int x)
{
    // Work on the unsigned magnitude: negating the most negative signed value
    // directly (`x = -x`) overflows, which is undefined behaviour.
    unsigned long long magnitude = static_cast<unsigned long long>(x);
    if (x < 0)
    {
        putchar('-');
        magnitude = 0 - magnitude;
    }
    // Collect digits least-significant first, then emit them in reverse.
    char digits[24];
    int count = 0;
    do
    {
        digits[count++] = static_cast<char>('0' + magnitude % 10);
        magnitude /= 10;
    } while (magnitude != 0);
    while (count > 0) putchar(digits[--count]);
}
// Capacity of the per-interval arrays; the coordinate and tree arrays are sized as multiples of it.
const int maxn = 8e5 + 5;
// sl / sr : left / right endpoint of interval i exactly as read from input.
// nl / nr : 1-based rank of those endpoints in the sorted, de-duplicated coordinate list.
// val     : endpoint coordinates; once main has sorted and de-duplicated them,
//           val[1..cntn] holds the distinct coordinates in ascending order.
int sl[maxn], sr[maxn], nl[maxn], nr[maxn], val[maxn * 2];
// Number of entries stored in val (all endpoints at first, only the distinct ones after unique()).
int cntn;
// One segment-tree node; the nodes live in the global array `tr`, root at index 1.
struct node
{
    int l;   // first index of the range this node covers (inclusive)
    int r;   // last index of the range this node covers (inclusive)
    int val; // value most recently assigned to this node's whole range
    int tag; // pending assignment to hand down to the children; -1 means none
} tr[maxn * 4];
// Debug helper: writes the four fields of a node as "l r val tag" on one line and flushes.
void print(node x)
{
    cout << x.l << " " << x.r << " ";
    cout << x.val << " " << x.tag;
    cout << endl;
}
// Index of the left child of node p (children of p are stored at 2p and 2p + 1).
int ls(int p)
{
    return 2 * p;
}
// Index of the right child of node p (children of p are stored at 2p and 2p + 1).
int rs(int p)
{
    return 2 * p + 1;
}
void pushdown(int p)
{
int tag = tr[p].tag;
if (tr[p].tag != -1)
{
tr[ls(p)].val = tag;
tr[rs(p)].val = tag;
tr[ls(p)].tag = tag;
tr[rs(p)].tag = tag;
tr[p].tag = -1;
}
}
/**
 * Builds the subtree rooted at node p so that it covers the index range [l, r].
 *
 * Every node in the subtree (not just the leaves) is reset to val = 0 with no
 * pending tag (tag = -1), so the "-1 means nothing pending" convention used by
 * pushdown() holds for internal nodes too and the tree can be rebuilt safely.
 *
 * @param p index of the node in `tr` (the root is 1).
 * @param l first index covered by this node (inclusive).
 * @param r last index covered by this node (inclusive).
 */
void build(int p, int l, int r)
{
    // Empty range: nothing to build. Without this guard a call such as
    // build(1, 1, 0) keeps recursing on the same (l, r) pair forever.
    if (l > r) return;

    tr[p].l = l, tr[p].r = r;
    tr[p].val = 0;
    tr[p].tag = -1;
    if (l == r) return;

    int mid = (l + r) >> 1;
    build(ls(p), l, mid);
    build(rs(p), mid + 1, r);
}
void update(int p, int l, int r, int x)
{
if (l <= tr[p].l && tr[p].r <= r)
{
tr[p].val = x;
tr[p].tag = x;
return;
}
pushdown(p);
int mid = (tr[p].l + tr[p].r) / 2;
if (mid >= l) update(ls(p), l, r, x);
if (mid < r) update(rs(p), l, r, x);
}
// Point query: walks from node p down to the leaf covering position x,
// pushing pending tags down along the way, and returns that leaf's val.
int query(int p, int x)
{
    while (!(tr[p].l == x && tr[p].r == x))
    {
        pushdown(p);
        const int mid = (tr[p].l + tr[p].r) / 2;
        p = (x <= mid) ? ls(p) : rs(p);
    }
    return tr[p].val;
}
using namespace std;
// n: number of intervals, read by main. m is declared but not used in the visible code.
int n, m;
// Per-interval endpoint record with room for 200005 entries;
// nothing in the visible code reads from it.
struct sss
{
int l, r;
} t[200005];
// Declared but not referenced by the visible code.
map<int, int> d;
// Filled by main with the indices i of the elementary segments [val[i], val[i + 1])
// that are covered by at least one interval.
vector<int> ans;
// Auxiliary array; never read in the visible code.
int mp[maxn * 2];
signed main()
{
n = read();
for (int i = 1; i <= n; ++i)
{
t[i].l = sl[i] = read(), t[i].r = sr[i] = read();
++cntn;
val[cntn] = mp[cntn] = sl[i];
++cntn;
val[cntn] = mp[cntn] = sr[i];
}
sort(val + 1, val + cntn + 1);
cntn = unique(val + 1, val + cntn + 1) - val - 1;
for (int i = 1; i <= n; ++i)
{
nl[i] = lower_bound(val + 1, val + cntn + 1, sl[i]) - val;
nr[i] = lower_bound(val + 1, val + cntn + 1, sr[i]) - val;
}
for (int i = 1; i <= cntn; ++i)
{
mp[i] = lower_bound(val + 1, val + cntn + 1, val[i]) - val;
}
build(1, 1, cntn);
map<int, int> dpoint;
for (int i = 1; i <= n; ++i)
{
if (nl[i] == nr[i])
{
dpoint[nl[i]] = 1;
continue;
}
if (nl[i] != nr[i])
{
update(1, nl[i], nr[i] - 1, 1);
dpoint[nr[i]] = 1;
if (dpoint.find(nl[i]) != dpoint.end())
{
dpoint.erase(dpoint.find(nl[i]));
}
}
}
for (int i = 1; i < cntn; ++i)
{
if (query(1, i) == 1)
{
if (dpoint.find(i) != dpoint.end())
{
dpoint.erase(dpoint.find(i));
}
ans.push_back(i);
}
}
if (ans.empty() && dpoint.empty())
{
puts("0");
return 0;
}
else if (ans.empty())
{
cout << dpoint.size() << endl;
return 0;
}
int tot = 0;
for (int i = 0; i < ans.size(); ++i)
{
tot += val[ans[i] + 1 ] - val[ans[i]];
}
cout << tot + dpoint.size() << endl;
return 0;
}
拓展
问:如何求出具体区间?
答:最后一步中,ans 记录了所有被覆盖的离散化小段 $[val_i, val_{i+1})$,dpoint 记录了剩余的单个端点;把编号相邻的小段(以及紧接其后的端点)合并起来,即可得到并集中的各个具体区间。