본문 바로가기
알고리즘/백준 & swacademy

BOJ 7469 - K번째 수 ( 머지소트트리, pst )

by sun__ 2020. 6. 10.

https://www.acmicpc.net/problem/7469

 

두가지 방법으로 가능

머지소트트리, 퍼시스턴트세그트리

 

pst론 구간 k번째 수를 효과적으로 셀 수 있다.

(세그먼트트리에서 전체 k번째 수를 세는 것이랑 비슷함)

 

<문제설명>

최대 1e5개의 정수 배열이 주어진다. 배열의 원소는 절대값이 1e9보다 작은 정수이다.

 

다음 쿼리를 처리

s,e,k : 구간[s,e]의 k번째 수를 출력

 

<풀이1> - 머지소트트리

머지소트트리를 만든 후,

 

1. 쿼리마다 -1e9~1e9사이에서 k번째 수가 될 수 있는 최소값을 구한다.

구간에서 x미만의 수의 개수가 k-1개이상인 수 중 최소값을 구해주면 된다.

(이분탐색 시 음수범위에 주의)

 

2. 구간에서 x이상 최소값을 찾아 출력해준다.

 

<코드1>

const int MAX = 2e5 + 4;
const int SMAX = (1 << 18);

int n, qq;
vector<int> seg[SMAX];

void construct() {
	for (int i = SMAX / 2 - 1; i >= 1; i--) {
		for (int j : seg[i * 2]) seg[i].push_back(j);
		for (int j : seg[i * 2 + 1]) seg[i].push_back(j);
		sort(seg[i].begin(), seg[i].end());
	}
}

int val1(int s, int e, int x, int i = 1, int ns = 0, int ne = SMAX / 2 - 1) {
	if (e < ns || ne < s) return 0;
	if (s <= ns && ne <= e) 
		return lower_bound(seg[i].begin(), seg[i].end(), x) - seg[i].begin();
	
	int md = (ns + ne)/2;
	return val1(s, e, x, i * 2, ns, md) + val1(s, e, x, i * 2 + 1, md + 1, ne);
}

int val2(int s, int e, int x, int i = 1, int ns = 0, int ne = SMAX / 2 - 1) {
	if (e < ns || ne < s) return 1e9;
	if (s <= ns && ne <= e) {
		auto it = lower_bound(seg[i].begin(), seg[i].end(), x);
		if (it == seg[i].end()) return 1e9;
		return *it;
	}
	int md = (ns + ne)/2;
	return min(val2(s, e, x, i * 2, ns, md), val2(s, e, x, i * 2 + 1, md + 1, ne));
}

int main() {
	FAST;
	cin >> n >> qq;
	for (int i = 0, x; i < n; i++) {
		cin >> x;
		seg[SMAX / 2 + i + 1].push_back(x);
	}
	construct();

	for (int i = 0, s, e, k; i < qq; i++) {
		cin >> s >> e >> k;
		int lo = -1e9, hi = 1e9;
		while (lo+2 < hi) { 
			int x = (lo + hi) / 2;
			if (val1(s, e, x) >= k - 1) hi = x;
			else lo = x+1;
		}
		for (int j = lo; j <= hi; j++) if (val1(s, e, j) >= k - 1) {
			lo = j;
			break;
		}
		cout << val2(s, e, lo) << '\n';
	}
}

 

 

<코드2>

typedef pair<int, int> P;
const int MAX = 1e5 + 4;

struct node {
	node* l = 0, *r = 0;
	int x = 0;
	node(int x) :x(x){ }
	node(node* l, node* r) :l(l), r(r) {
		if (l) x += l->x;
		if (r) x += r->x;
	}
};

vector<node*> pst;

int n, query;
vector<int> va;
int a[MAX], aa[MAX];

node* construct(int ns = 0, int ne = MAX) {
	if (ns == ne) return new node(0);
	int md = (ns + ne) / 2;
	return new node(construct(ns, md), construct(md + 1, ne));
}

node* add(node* i, int idx, int ns = 0, int ne = MAX) {
	if (ns == ne) return new node(i->x + 1);
	int md = (ns + ne) / 2;
	if (idx <= md)
		return new node(add(i->l, idx, ns, md), i->r);
	else
		return new node(i->l, add(i->r, idx, md + 1, ne));
}

int kth(node* s, node* e, int k, int ns = 0, int ne = MAX) {
	if (ns == ne) return ns;

	int delta = e->l->x - s->l->x;
	int md = (ns + ne) / 2;
	if (k <= delta)
		return kth(s->l, e->l, k, ns, md);
	else
		return kth(s->r, e->r, k - delta, md + 1, ne);
}

int kth(int s, int e, int k) {
	return kth(pst[s - 1], pst[e], k);
}


int main() {
	FAST; cin >> n >> query;
	for (int i = 1; i <= n; i++) {
		cin >> a[i];
		va.push_back(a[i]);
	}
	sort(va.begin(), va.end());

	pst.push_back(construct());
	for (int i = 1; i <= n; i++) {//좌표압축
		aa[i] = lower_bound(va.begin(), va.end(), a[i]) - va.begin();
		pst.push_back(pst.back());
		pst[i] = add(pst[i], aa[i]);
	}
	while (pst.size() != MAX) pst.push_back(pst.back());

	while (query--) {
		int i, j, k; cin >> i >> j >> k;
		cout << va[kth(i,j,k)] << '\n';
	}

}