본문 바로가기
알고리즘/백준 & swacademy

BOJ 15977 - 조화로운 행렬 (분할정복, 세그트리)

by sun__ 2020. 5. 30.

https://www.acmicpc.net/problem/15977

정해는 다차원 세그트리나 이분탐색+set이다.

 

https://codeforces.com/blog/entry/43319

koosaga님의 아이디어를 구현.

 

<문제설명>

m=2인 경우 정렬 후 LIS

 

m=3인 경우 정렬 후 LIS on pair (x,y값이 모두 증가하는 형태)

 

<풀이>

링크의 설명이 너무 명료하므로 큰 그림은 생략. 

 

[m+1,e] 범위의 dp를 [s,m] 범위의 dp 값으로 초기화 하는 과정만 기록

 

세그트리는 y를 인덱스로 하고 dp값을 값으로 갖는다.

1. 하나의 벡터에 {x,y,idx}를 [s,e]범위 모두 넣어준 후 x값 기준으로 정렬한다.

2. idx가 md이하인 경우 세그트리를 업데이트 해준다.

3. idx가 md초과인 경우 세그트리의 정보를 이용하여 dp를 업데이트해준다.

4. 세그트리를 초기화해준다.

 

 

<코드>

const int MAX = 2e5 + 4;
const int SMAX = (1 << 20);

int n, m, a[3][MAX], inp[3][MAX], dp[MAX];
int seg[SMAX];
vector<T> b;
vector<P> c;

void update(int i, int x) {
	i += SMAX / 2;
	seg[i] = x;
	while (i > 1) {
		i /= 2;
		seg[i] = max(seg[i * 2], seg[i * 2 + 1]);
	}
}

int val(int s, int e, int i, int ns, int ne) {
	if (e < ns || ne < s) return 0;
	if (s <= ns && ne <= e) return seg[i];
	int md = (ns + ne) / 2;
	return max(val(s, e, i * 2, ns, md), val(s, e, i * 2 + 1, md + 1, ne));
}
int val(int e) {return val(0, e, 1, 0, SMAX / 2 - 1);}

int solve(int s, int e) {
	if (s == e) return 0;
	
	int ret = 0;
	int md = (s + e) / 2;
	ret = max(ret, solve(s, md));

	vector<T> temp;
	for (int i = s; i <= e; i++)
		temp.push_back({ get<1>(b[i]), get<2>(b[i]), i });
	sort(temp.begin(), temp.end());

	for (T tt : temp) {
		int x, y, idx; tie(x, y, idx) = tt;
		dp[idx] = max(dp[idx], 1);
		if (idx > md) {
			dp[idx] = max(dp[idx], val(y) + 1);
			ret = max(ret, dp[idx]);
		}
		else {
			update(y, dp[idx]);
		}
	}
	for (T tt : temp) {
		int x, y, idx; tie(x, y, idx) = tt;
		if (idx <= md) update(y, 0);
	}

	ret = max(ret, solve(md + 1, e));
	return ret;
}

int main() {
	FAST; cin >> m >> n;
	for (int i = 0; i < m; i++) for (int j = 0; j < n; j++) {
		cin >> inp[i][j];
		a[i][j] = inp[i][j];
	}
	for (int i = 0; i < m; i++) sort(inp[i], inp[i] + n);
	for (int i = 0; i < m; i++) for (int j = 0; j < n; j++) 
		a[i][j] = lower_bound(inp[i], inp[i] + n, a[i][j]) - inp[i];
	
	
	if (m == 3) {
		for (int i = 0; i < n; i++) 
			b.push_back({ a[0][i], a[1][i], a[2][i] });
		
		sort(b.begin(), b.end());

		int ans = solve(0, n - 1);
		cout << ans << '\n';
	}
	else if (m == 2) {
		for (int i = 0; i < n; i++) 
			c.push_back({ a[0][i], a[1][i] });
		
		sort(c.begin(), c.end());
		vector<int> v(1,-1);
		for (int i = 0,x,y; i < n; i++) {
			tie(x, y) = c[i];
			if (v.back() < y) v.push_back(y);
			else *lower_bound(v.begin(), v.end(), y) = y;
		}
		cout << v.size() - 1 << '\n';
	}
}