Menu

Fast Matrix Operations

post on 17 Nov 2018 about 4330words require 15min
CC BY 4.0 (除特别声明或转载文章外)
如果这篇博客帮助到你,可以请我喝一杯咖啡~

题目链接 对蓝书上的懒标记进行加强,将对区间的操作统一成区间线性变换 x → x·mul + add:区间赋值就是先乘 0 再加上新值,区间加就是先乘 1 再加。这样敲反而简单很多。

update: 更新了线段树模板。新的模板省去了 build 操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <bits/stdc++.h>
using namespace std;
// Scalar type used for stored values, lazy tags and aggregates.
// NOTE(review): despite the name this is a plain int, not a 64-bit type.
using ll = int;
// Child index meaning "this child node has not been created yet".
const int NPOS = -1;

// Lazily-built segment tree over the inclusive index range given to the
// constructor. Every element starts at 0. Range updates are affine maps
// x -> x * mul + add (mul = 0 assigns, mul = 1 adds); range queries return
// the min, max and sum of the covered elements.
struct SegmentTree
{
	// Aggregate (min / max / sum) of the inclusive index range [l, r].
	struct Val
	{
		int l, r;
		ll min, max, sum;
		// All-zero aggregate for [l, r].
		Val(int l, int r) : l(l), r(r), min(0), max(0), sum(0) {}
		// Merge of two adjacent aggregates: lc on the left, rc on the right.
		Val(const Val &lc, const Val &rc) : l(lc.l), r(rc.r), min(std::min(lc.min, rc.min)), max(std::max(lc.max, rc.max)), sum(lc.sum + rc.sum) {}
		// Apply x -> x * mul + add to every element summarised by this aggregate.
		void upd(ll mul, ll add)
		{
			min = min * mul + add;
			max = max * mul + add;
			sum = sum * mul + add * (r - l + 1);
		}
	};
	// Tree node: an aggregate plus child indices into `v` and the pending
	// (not yet pushed down) affine tag.
	struct Node : Val
	{
		int lc, rc;
		ll mul, add;
		// Intentionally implicit: the tree stores freshly made Val objects as nodes.
		Node(const Val &val) : Val(val), lc(NPOS), rc(NPOS), mul(1), add(0) {}
	};
	// Node pool; index 0 is the root.
	std::vector<Node> v;
	// Creates the root for [l, r] and pre-reserves room for the nodes.
	// Intentionally implicit so that `{l, r}` converts to a SegmentTree.
	SegmentTree(int l, int r) : v{Val(l, r)} { v.reserve((r - l + 9) << 1); }
	// Creates the missing children of `rt` and hands its pending tag to them.
	void down(int rt)
	{
		// l + (r - l) / 2 rounds down even for negative coordinates, whereas
		// (l + r) / 2 truncates toward zero and could produce a left child
		// identical to its parent (endless recursion).
		const int m = v[rt].l + (v[rt].r - v[rt].l) / 2;
		if (v[rt].lc == NPOS)
		{
			v[rt].lc = static_cast<int>(v.size());
			v.push_back(Val(v[rt].l, m));
		}
		if (v[rt].rc == NPOS)
		{
			v[rt].rc = static_cast<int>(v.size());
			v.push_back(Val(m + 1, v[rt].r));
		}
		upd(v[v[rt].lc].l, v[v[rt].lc].r, v[rt].mul, v[rt].add, v[rt].lc);
		upd(v[v[rt].rc].l, v[v[rt].rc].r, v[rt].mul, v[rt].add, v[rt].rc);
		v[rt].mul = 1;
		v[rt].add = 0;
	}
	// Applies x -> x * mul + add to every element in [l, r] that lies inside
	// the subtree rooted at `rt`. Empty or non-overlapping ranges are a no-op.
	void upd(int l, int r, ll mul, ll add, int rt = 0)
	{
		if (l <= v[rt].l && v[rt].r <= r)
		{
			// Fully covered: compose the tag and update the aggregate in place.
			v[rt].mul *= mul;
			v[rt].add = v[rt].add * mul + add;
			v[rt].upd(mul, add);
			return;
		}
		// Nothing to do; descending here would never reach a covered node.
		if (l > r || r < v[rt].l || v[rt].r < l)
			return;
		down(rt);
		// Indices stay valid even if the pool reallocates during recursion.
		const int lc = v[rt].lc, rc = v[rt].rc;
		if (r <= v[lc].r)
			upd(l, r, mul, add, lc);
		else if (l >= v[rc].l)
			upd(l, r, mul, add, rc);
		else
		{
			upd(l, v[lc].r, mul, add, lc);
			upd(v[rc].l, r, mul, add, rc);
		}
		// Recompute only the aggregate part of the node from its children.
		static_cast<Val &>(v[rt]) = Val(v[lc], v[rc]);
	}
	// Returns the aggregate of the elements in [l, r] inside the subtree
	// rooted at `rt`. For an empty or non-overlapping range the result is the
	// merge identity: sum 0, min = largest ll, max = smallest ll.
	Val ask(int l, int r, int rt = 0)
	{
		if (l <= v[rt].l && v[rt].r <= r)
			return v[rt];
		if (l > r || r < v[rt].l || v[rt].r < l)
		{
			Val none(l, r);
			none.min = std::numeric_limits<ll>::max();
			none.max = std::numeric_limits<ll>::lowest();
			return none;
		}
		down(rt);
		const int lc = v[rt].lc, rc = v[rt].rc;
		if (r <= v[lc].r)
			return ask(l, r, lc);
		if (l >= v[rc].l)
			return ask(l, r, rc);
		return Val(ask(l, v[lc].r, lc), ask(v[rc].l, r, rc));
	}
};
// Reads test cases until EOF. Each case starts with "rows cols ops" and is
// followed by `ops` operations on a rows x cols grid of zeros, stored as one
// segment tree per row:
//   1 x1 y1 x2 y2 v  -> add v to every cell of the sub-rectangle
//   2 x1 y1 x2 y2 v  -> set every cell of the sub-rectangle to v
//   3 x1 y1 x2 y2    -> print "sum min max" of the sub-rectangle
int main()
{
	// Require all three header fields; a short or failed read ends the input
	// instead of looping on indeterminate values.
	for (int rows, cols, ops; scanf("%d%d%d", &rows, &cols, &ops) == 3;)
	{
		std::vector<SegmentTree> Tree(rows, SegmentTree(1, cols));
		while (ops-- > 0)
		{
			int op, x1, y1, x2, y2;
			if (scanf("%d%d%d%d%d", &op, &x1, &y1, &x2, &y2) != 5)
				return 0;
			if (op == 3)
			{
				SegmentTree::Val ans(y1, y2);
				// Start min at a large sentinel so the first row's minimum wins.
				ans.min = 1e9;
				for (int row = x1 - 1; row < x2; ++row)
					ans = SegmentTree::Val(ans, Tree[row].ask(y1, y2));
				printf("%d %d %d\n", ans.sum, ans.min, ans.max);
			}
			else
			{
				int val;
				if (scanf("%d", &val) != 1)
					return 0;
				// op 1 adds (mul = 1); any other op assigns (mul = 0).
				for (int row = x1 - 1; row < x2; ++row)
					Tree[row].upd(y1, y2, op == 1, val);
			}
		}
	}
}

原:目测是全网最短解法(Vjudge 上 Length1705,比第二名少了 312,同时性能也比较强)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include<bits/stdc++.h>
using namespace std;
// Scalar type used for stored values, lazy tags and aggregates.
// NOTE(review): despite the name this is a plain int, not a 64-bit type.
using ll = int;

// Array-based segment tree over indices 1..N (root at v[1], children of rt
// at rt*2 and rt*2+1). Every element starts at 0. Range updates are affine
// maps x -> x * mul + add (mul = 0 assigns, mul = 1 adds); range queries
// return the min, max and sum of the covered elements.
struct SegmentTree
{
	// One tree node: the inclusive range [l, r], the pending affine tag
	// (mul, add) for its children, and the aggregates of its elements.
	struct Node
	{
		int l = 0, r = 0;
		// Identity tag by default, so a fresh node never alters its children.
		ll mul = 1, add = 0;
		ll min = 0, max = 0, sum = 0;
		// Apply x -> x * m + a to this node: compose the tag, update aggregates.
		void set(ll m, ll a)
		{
			mul *= m;
			add = add * m + a;
			min = min * m + a;
			max = max * m + a;
			sum = sum * m + a * (r - l + 1);
		}
		// Recompute the aggregates from two parts. Each field is read before
		// it is written, so passing *this as `lc` is safe.
		void up(const Node &lc, const Node &rc)
		{
			min = std::min(lc.min, rc.min);
			max = std::max(lc.max, rc.max);
			sum = lc.sum + rc.sum;
		}
		// Hand the pending tag to both children and reset it to the identity.
		void down(Node &lc, Node &rc)
		{
			lc.set(mul, add);
			rc.set(mul, add);
			mul = 1;
			add = 0;
		}
	};
	// Node storage; v[0] is unused.
	std::vector<Node> v;
	// Builds the tree for 1..N. At least four slots are allocated so that
	// the root exists even when N <= 0.
	// Intentionally implicit: callers construct a tree directly from an int.
	SegmentTree(int N) : v(static_cast<std::size_t>(N > 0 ? N : 1) << 2)
	{
		build(1, N);
	}
	// Initialises the subtree rooted at `rt` to cover [l, r] with all zeros.
	void build(int l, int r, int rt = 1)
	{
		v[rt] = Node();
		v[rt].l = l;
		v[rt].r = r;
		if (l >= r)
			return;
		const int m = l + (r - l) / 2;
		build(l, m, rt << 1);
		build(m + 1, r, rt << 1 | 1);
		v[rt].up(v[rt << 1], v[rt << 1 | 1]);
	}
	// Applies x -> x * mul + add to every element in [l, r] that lies inside
	// the subtree rooted at `rt`. Empty or non-overlapping ranges are a no-op.
	void set(int l, int r, ll mul, ll add, int rt = 1)
	{
		if (l <= v[rt].l && v[rt].r <= r)
		{
			v[rt].set(mul, add);
			return;
		}
		// Descending here would walk past the leaves and out of the array.
		if (l > r || r < v[rt].l || v[rt].r < l)
			return;
		v[rt].down(v[rt << 1], v[rt << 1 | 1]);
		const int m = v[rt].l + (v[rt].r - v[rt].l) / 2;
		if (m >= r)
			set(l, r, mul, add, rt << 1);
		else if (m < l)
			set(l, r, mul, add, rt << 1 | 1);
		else
		{
			set(l, m, mul, add, rt << 1);
			set(m + 1, r, mul, add, rt << 1 | 1);
		}
		v[rt].up(v[rt << 1], v[rt << 1 | 1]);
	}
	// Returns a node whose min / max / sum describe the elements in [l, r]
	// inside the subtree rooted at `rt`. For an empty or non-overlapping
	// range the result is the merge identity: sum 0, min = largest ll,
	// max = smallest ll.
	Node ask(int l, int r, int rt = 1)
	{
		if (l <= v[rt].l && v[rt].r <= r)
			return v[rt];
		if (l > r || r < v[rt].l || v[rt].r < l)
		{
			Node none;
			none.l = l;
			none.r = r;
			none.min = std::numeric_limits<ll>::max();
			none.max = std::numeric_limits<ll>::lowest();
			return none;
		}
		v[rt].down(v[rt << 1], v[rt << 1 | 1]);
		const int m = v[rt].l + (v[rt].r - v[rt].l) / 2;
		if (m >= r)
			return ask(l, r, rt << 1);
		if (m < l)
			return ask(l, r, rt << 1 | 1);
		// Merge into a local node rather than a shared slot of `v`, so a
		// query never writes to tree storage beyond pushing tags down.
		Node merged;
		merged.l = l;
		merged.r = r;
		merged.up(ask(l, m, rt << 1), ask(m + 1, r, rt << 1 | 1));
		return merged;
	}
};
// Reads test cases until EOF. Each case starts with "rows cols ops" and is
// followed by `ops` operations on a rows x cols grid of zeros, stored as one
// segment tree per row:
//   1 x1 y1 x2 y2 v  -> add v to every cell of the sub-rectangle
//   2 x1 y1 x2 y2 v  -> set every cell of the sub-rectangle to v
//   3 x1 y1 x2 y2    -> print "sum min max" of the sub-rectangle
int main()
{
	// Require all three header fields; a short or failed read ends the input
	// instead of looping on indeterminate values.
	for (int rows, cols, ops; scanf("%d%d%d", &rows, &cols, &ops) == 3;)
	{
		std::vector<SegmentTree> tree(rows, SegmentTree(cols));
		while (ops-- > 0)
		{
			int op, x1, y1, x2, y2;
			if (scanf("%d%d%d%d%d", &op, &x1, &y1, &x2, &y2) != 5)
				return 0;
			if (op == 3)
			{
				SegmentTree::Node ans;
				// Large sentinel for min; max and sum start from zero.
				ans.min = 1e9;
				ans.max = ans.sum = 0;
				for (int row = x1 - 1; row < x2; ++row)
					ans.up(ans, tree[row].ask(y1, y2));
				printf("%d %d %d\n", ans.sum, ans.min, ans.max);
			}
			else
			{
				int val;
				if (scanf("%d", &val) != 1)
					return 0;
				// op 1 adds (mul = 1); any other op assigns (mul = 0).
				for (int row = x1 - 1; row < x2; ++row)
					tree[row].set(y1, y2, op == 1, val);
			}
		}
	}
}
Loading comments...