본문 바로가기
알고리즘/코드포스

codeforces #608 div2 D - Portals (dp)

by sun__ 2019. 12. 21.

https://codeforces.com/contest/1271/problem/D

 

dp인건 파악했는데 중요한 포인트를 놓치고 식을 너무 복잡하게 짜서 틀린 문제.

 

<문제설명>

성 n개, (성의 방어력 a, 성에서 영입되는 병사 b, 성에 인원 1명을 배치했을 때 받는 점수 c), 성과성을 잇는 간선 m개, 초기 병사 수 k명이 주어질 때 얻을 수 있는 최대 점수를 출력하는 문제다.

 

<풀이>

pos번째 성에 인원을 1명 배치하는 경우는 최대한 늦게 하는 것이 최적이다. 예를들어 3번째 성으로 4번 5번 성에 포탈이 있다면 3번째 성에서 바로 인원을 배치하는 것보단 4번 성에서 포탈을 태워 보내는 경우가 최적이고, 그보다 5번째 성에서 3번째 성으로 인원을 배치하는 것이 최적이란 뜻이다.

 

위 아이디어를 자료구조로 잘 구현해 보면, 일단 입력받은 간선(역방향)에 추가로 u->u 간선을 모든 성에 대해 이어줘서 임시 인접리스트(tadj)를 하나 만든다. 그리고 시점(역방향이니 실제론 종점)이 같은 간선들 중 가장 종점이 큰 요소만 사용해서 실제로 사용할 인접리스트(adj)를 만들어준다. 그리고 모든 인접리스트를 c에 대해 내림차순으로 정렬해준다. (이 성질 때문에 우선순위 큐를 사용해도 된다. 하지만 코드가 약간 복잡해진다.) 이를 코드로 구현해보면 다음과 같다.

 

//역방향 인접리스트 만들기
for (int i = 0,u,v; i < m; i++) {
	cin >> u >> v;
	tadj[v].push_back(u);
}
//종점이 같은 간선들 중 시점이 가장 큰 간선만 adj에 따로 뽑아두기
for (int i = 1; i <= n; i++) {
	ll mx = 0;
	tadj[i].push_back(i);
	for (ll u : tadj[i]) mx = max(mx, u);
	adj[mx].push_back(i);
}
//c에 대해 내림차순 정렬
for (int i = 1; i <= n; i++) {
	sort(adj[i].begin(), adj[i].end(), [](int i1, int i2) {
		return c[i1] > c[i2];
		});
}

 

f(pos, sol) : pos번째 성에 대해 병사가 sol명 있을 때 얻을 수 있는 최대 점수라고 두자. 

 

sol>=a[pos]인 경우, 아무도 배치하지 않고 그냥 지나가는 경우와 배치하는 경우로 나눠서 그 최대를 구하면 된다.

 

배치하는 경우(맨 밑식), pos와 인접한 성에 c값이 큰 순서대로 방문하게 되는데, cnt명의 사람을 배치하는 경우를 식으로 쓴 것이다. 

 

이 점화식을 코드로 구현해보면 다음과 같다. (pos대신 i, sol대신 j, i대신 f 등등 변수명 주의)

(bottom-up)

for (int i = n; i >= 1; i--) {
	for (int j = 0; j <= 5000; j++) {
		dp[i][j] = -INF;
		if (j < a[i]) continue;
		ll sum = 0;
		dp[i][j] = max(dp[i][j],dp[i + 1][j + b[i]]);
		for (int f = 0; f < adj[i].size(); f++) {
			sum += c[adj[i][f]];
			if(j - f - 1 + b[i]>=0) dp[i][j] = max(dp[i][j], dp[i + 1][j - f - 1 + b[i]] + sum);
		}
	}
}

 

(top-down)

ll f(int pos, int sol) {
	if (pos == n + 1) {
		if (sol >= 0) return 0;
		else return -INF;
	}

	ll& rst = dp[pos][sol];
	if (rst != -1) return rst;

	if (sol < a[pos]) return rst = -INF;

	rst = f(pos + 1, sol + b[pos]);
	sol += b[pos];
	ll sum = 0;
	for (int i = 0; i < adj[pos].size(); i++) {
		sum += c[adj[pos][i]];
		rst = max(rst, f(pos + 1, sol - i - 1) + sum);
	}
	return rst;
}

 

 

<전체코드 - bottom up>

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 5e3 + 5;
const ll INF = 1e18 + 7;

ll a[MAXN], b[MAXN], c[MAXN], dp[MAXN][MAXN];
vector<ll> tadj[MAXN], adj[MAXN];

int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);

	int n, m, k; cin >> n >> m >> k;
	for (int i = 1; i <= n; i++)
		cin >> a[i] >> b[i] >> c[i];
	for (int i = 0,u,v; i < m; i++) {
		cin >> u >> v;
		tadj[v].push_back(u);
	}
	for (int i = 1; i <= n; i++) {
		ll mx = 0;
		tadj[i].push_back(i);
		for (ll u : tadj[i]) mx = max(mx, u);
		adj[mx].push_back(i);
	}

	for (int i = 1; i <= n; i++) {
		sort(adj[i].begin(), adj[i].end(), [](int i1, int i2) {
			return c[i1] > c[i2];
			});
	}
	
	for (int i = n; i >= 1; i--) {
		for (int j = 0; j <= 5000; j++) {
			dp[i][j] = -INF;
			if (j < a[i]) continue;
			ll sum = 0;
			dp[i][j] = max(dp[i][j],dp[i + 1][j + b[i]]);
			for (int f = 0; f < adj[i].size(); f++) {
				sum += c[adj[i][f]];
				if(j - f - 1 + b[i]>=0) dp[i][j] = max(dp[i][j], dp[i + 1][j - f - 1 + b[i]] + sum);
			}
		}
	}

	cout << (dp[1][k] < 0 ? -1 : dp[1][k]) << '\n';
}

 

<전체코드 - top-down>

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 5e3 + 5;
const ll INF = 1e18 + 7;

int n, m, k;
ll a[MAXN], b[MAXN], c[MAXN], dp[MAXN][MAXN];
vector<ll> tadj[MAXN], adj[MAXN];

ll f(int pos, int sol) {
	if (pos == n + 1) {
		if (sol >= 0) return 0;
		else return -INF;
	}

	ll& rst = dp[pos][sol];
	if (rst != -1) return rst;

	if (sol < a[pos]) return rst = -INF;

	rst = f(pos + 1, sol + b[pos]);
	sol += b[pos];
	ll sum = 0;
	for (int i = 0; i < adj[pos].size(); i++) {
		sum += c[adj[pos][i]];
		rst = max(rst, f(pos + 1, sol - i - 1) + sum);
	}
	return rst;
}

int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);

	cin >> n >> m >> k;
	for (int i = 1; i <= n; i++)
		cin >> a[i] >> b[i] >> c[i];
	for (int i = 0, u, v; i < m; i++) {
		cin >> u >> v;
		tadj[v].push_back(u);
	}
	for (int i = 1; i <= n; i++) {
		ll mx = 0;
		tadj[i].push_back(i);
		for (ll u : tadj[i]) mx = max(mx, u);
		adj[mx].push_back(i);
	}

	for (int i = 1; i <= n; i++) {
		sort(adj[i].begin(), adj[i].end(), [](int i1, int i2) {
			return c[i1] > c[i2];
			});
	}

	fill(&dp[0][0], &dp[MAXN - 1][MAXN], -1);
	ll ans = f(1, k);
	cout << (ans < 0 ? -1 : ans) << '\n';
}

 

 

<생각>

dp풀때마다 느끼는거지만, 점화식이 확실히 안나온 상태에서 코드에 손대면 너무 복잡해지고 오히려 시간이 더 드는것 같다.

 

adj를 돌면서 INF값이 나오는 순간 break해도 되지만 크게 봤을 때 전체 시간복잡도에 영향이 미미하므로 안해줘도 된다. 

 

시간복잡도는 O(n * 최대병사수)이다. adj를 도는 것을 생각하면 O(n*n 최대병사수) 아닌가? 할 수 있지만 adj를 이어줬을 때를 생각해보면 한 종점에 대해 이어진 간선은 단 하나이므로 그 영향이 가려진다. 

 

q개의 질의에 대해 입력되는 n의 최대값은 각각 1e5이고 , 그 모든 n의 값의 합도 최대 1e5라고 하면 전체 질의에 대한 시간복잡도는 q*n인 것과 비슷한 상황이라고 이해했다.