본문 바로가기
알고리즘/메모

lazy propagation - 세그먼트 트리 확장

by sun__ 2019. 9. 18.

코드의 상당부분은 kks227(라이)님의 코드를 참고했음을 밝힙니다.

 

기본문제

https://www.acmicpc.net/problem/10999

https://www.acmicpc.net/problem/12844

https://www.acmicpc.net/problem/1395

 

구간 합을 빠르게 구하기 위해서 prefix 배열을 이용

-> update 등의 다수의 쿼리가 발생했을 때도 구간 합을 빠르게 구하기 위해서 segment tree 사용

 

-> 구간 update 등 다수의 쿼리가 발생했을 때 구간 합을 빠르게 구하기 위해 lazy배열과 propagation 사용

 

필요한 것:

 segment tree, lazy 배열

 propagate(node, ns, ne) : add나 val 함수에서의 node에 대해 lazy값이 존재한다면 자식들에게 lazy값 전파하고 자신의 

                                   값(arr[node])에 lazy값을 반영하는 함수.

                                   arr[node] += lazy[node] * (ne-ns+1) 이부분이 문제에 따라 판이하게 바뀐다.

 

구간 연산(add)와 세그먼트트리값 반환(val)의 propagate() 호출 시기를 눈여겨 보자.

#include <cstdio>
#include <algorithm>
using namespace std;
const int MAX = 1 << 21;
typedef long long ll;

ll arr[MAX], lazy[MAX];

void construct() {
	for (int i = MAX / 2 - 1; i > 0; i--)
		arr[i] = arr[i * 2] + arr[i * 2 + 1];
}

void propagate(int node, int ns, int ne) {
	//lazy값이 존재하면 실행
    if(lazy[node]==0) return;
	
	//리프노드가 아니면 자식들에게 미루기
	if (node < MAX / 2) {
		lazy[node * 2] += lazy[node];
		lazy[node * 2+1] += lazy[node];
	}
	//자신에 해당하는 만큼의 값을 더함
	arr[node] += lazy[node] * (ne - ns+1);
	lazy[node] = 0;
}

void add(int s, int e, int k, int node, int ns, int ne) {
	propagate(node, ns, ne);

	if (e < ns || ne < s) return;
	if (s <= ns && ne <= e) {
		//이 노드가 구간에 완전히 포함되면 lazy부여후 propagate
		lazy[node] += k;
		propagate(node, ns, ne);
		return;
	}
	int mid = (ns + ne) / 2;
	add(s, e, k, node * 2, ns, mid);
	add(s, e, k, node * 2 + 1, mid+1, ne);
	arr[node] = arr[node * 2] + arr[node * 2 + 1];
}
void add(int s, int e, int k) { add(s, e, k, 1, 0, MAX/2-1); }
//[ns,ne) -> 초기값 : 1번노드의 범위


ll sum(int s, int e, int node, int ns, int ne) {
	propagate(node, ns, ne);

	if (e < ns || ne < s) return 0;
	if (s <= ns && ne <= e) return arr[node];
	int mid = (ns + ne) / 2;
	return sum(s, e, node * 2, ns, mid) + sum(s, e, node * 2 + 1, mid+1, ne);
}
ll sum(int s, int e) { return sum(s, e, 1, 0, MAX / 2-1); }


int main() {
	int n, m, k;
	scanf("%d %d %d", &n, &m, &k);
	for (int i = 0; i < n; i++) scanf("%lld", &arr[MAX/2+i]);
	construct();

	for (int i = 0; i < m + k; i++) {
		int a, b, c, d;
		scanf("%d", &a);
		if (a == 1) {
			scanf("%d %d %d", &b, &c, &d);
			add(b - 1, c-1, d);
		}
		else {
			scanf("%d %d", &b, &c);
			printf("%lld\n", sum(b - 1, c-1));
		}
	}
}

 

 

다음은 스위치문제(세번째 기본문제)의 핵심 코드

int arr[MAX];
bool lazy[MAX];

void construct() {
	for (int i = MAX / 2 - 1; i > 0; i--)
		arr[i] = arr[i * 2] + arr[i * 2 + 1];
}

void propagate(int node, int ns, int ne) {
	if (lazy[node] != 0) {
		if (node < MAX / 2) {
			lazy[node * 2] ^= 1;
			lazy[node * 2 + 1] ^= 1;
		}
		arr[node] = (ne - ns + 1) - arr[node];

		lazy[node] = 0;
	}
}

void oper(int s, int e, int node, int ns, int ne) {
	propagate(node, ns, ne);

	if (e < ns || ne < s) return;
	if (s <= ns && ne <= e) {
		lazy[node] ^= 1;
		propagate(node, ns, ne);
		return;
	}
	int mid = (ns + ne) / 2;
	oper(s, e, node*2, ns, mid);
	oper(s, e, node*2+1, mid + 1, ne);
	arr[node] = arr[node * 2] + arr[node * 2 + 1];
}
void oper(int s, int e) {oper(s, e, 1, 0, MAX / 2 - 1);}

int val(int s, int e, int node, int ns, int ne) {
	propagate(node, ns, ne);

	if (e < ns || ne < s) return 0;
	if (s <= ns && ne <= e) return arr[node];
	int mid = (ns + ne) / 2;
	return val(s, e, node * 2, ns, mid) + val(s, e, node * 2 + 1, mid + 1, ne);
}
int val(int s, int e) {return val(s, e, 1, 0, MAX / 2 - 1);}

 

올림피아드 어려운 문제도 기본 형태가 크게 바뀌지 않는다.

 * 구간 합 세그먼트 트리에 구간반전연산(구간크기 - 현재켜진 전구 수)

 

 

 

** 세그먼트 트리가 성립하려면 적용되는 연산이 교환법칙을 만족해야 한다.

자세한 설명: https://ingu9981.blog.me/221461732359

'알고리즘 > 메모' 카테고리의 다른 글

행렬의 표현  (0) 2019.09.27
KMP  (0) 2019.09.18
냅색, knapsack  (0) 2019.08.19
기하 - 두 선분 사이의 거리  (0) 2019.08.19
기하1 - 외적, 두 선분의 교차  (0) 2019.08.19