부모 자식 관계가 없는 트리 노드들의 집합들 중에서 최대 크기를 찾는 문제입니다.
집합의 노드들이 간선으로 이어져 있으면 안 되고 집합에 있는 정점의 가중치 합의 최댓값을 구하는 문제입니다.
트리 노드들을 탐색하며 부모 노드가 포함되어 있으면 넘어가고 부모 노드가 포함되어 있지 않으면 포함하거나 넘어갈 수 있습니다.
이렇게 해서 모든 경우를 봐서 최댓값을 구하면 되겠지만 N이 최대 10000이므로 O(2^N)의 브루트 포스를 하면 안 됩니다.
시간을 줄여야 하는데… DP로 하면 될 것 같네요.
DP[노드 번호][노드 포함 여부] = 최대 가중치 합
으로 하면 트리를 한번 탐색하며 가중치 합을 저장하면 됩니다.
now = 현재 노드, next = 자식 노드
예를 들어 DP[now]를 구하겠습니다.
now 노드가 포함되지 않았을 때 next 노드를 포함해도 되고 하지 않아도 됩니다.
그러므로 DP[now][0]에 next가 포함될 경우와 포함되지 않은 경우를 비교하여 최댓값을 더해주면 됩니다.
DP[now][0] += max(DP[next][0], DP[next][1])
now 노드가 포함되었을 때는 next 노드가 포함되면 안 되므로
DP[now][1] += DP[next][0]
대입하는 것이 아니라 더해주는 이유는 가중치의 합을 구하는 문제이므로 DP[next]들의 합을 구해야 합니다.
이제 포함된 노드들을 정렬해서 출력을 해줘야 하는데 DP에 저장된 값들로 역추적을 하겠습니다.
DP는 다른 문제와 비슷한 문제인데 역추적이 있네요… 역추적 싫은데…
아까 탐색을 시작했던 노드부터 역추적을 시작하겠습니다.
현재 노드를 포함했을 때 DP 값이 포함하지 않았을 때 DP 값보다 크다면 그 노드를 포함해야 최댓값을 구할 수 있습니다.
DP[now][0] > DP[now][1]이라면 now를 포함해야 최댓값이므로 now는 출력을 해야 합니다.
그리고 현재 노드를 추가했다면 다음 자식 노드는 집합에 추가하면 안 되므로 집합에 부모 노드 포함 여부를 저장하는 변수를 선언합니다.
부모가 포함되었을 때는 현재 노드는 포함하면 안 되므로 다음으로 넘어갑니다.
이 과정으로 모든 노드를 탐색하면 집합에 포함된 정점의 노드를 구할 수 있고 정렬하여 출력하면 됩니다.
간단하죠?
위에서는 이차원 배열이 이해도 쉽고 설명도 편하게 하려고
이차원 배열로 DP를 저장했지만 저는 시간을 조금이나마 줄여볼까 하고 일차원 배열에 pair를 썼습니다.
편하신 방법으로 코딩하세요!!
#include <bits/stdc++.h>
#define FIO ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define FOR(i,a,b) for(int i=a;i<=b;i++)
#define vi vector <int>
#define vvi vector <vi>
#define pii pair<int,int>
#define all(x) (x).begin(), (x).end()
#define fs first
#define sd second
#define eb emplace_back
using namespace std;
int main() {
FIO;
int n; cin >> n;
vi v(n + 1), path;
vvi adj(n + 1);
vector <pii> dp(n + 1); // first = 노드 포함 x, second = 노드 포함 o
FOR(i, 1, n) cin >> v[i];
FOR(i, 1, n - 1) {
int a, b; cin >> a >> b;
adj[a].eb(b);
adj[b].eb(a);
}
function <pii(int, int)> DP = [&](int now, int par) {
pii& ret = dp[now] = { 0,v[now] };
for (int next : adj[now]) {
if (next == par) continue;
pii val = DP(next, now);
ret.fs += max(val.fs, val.sd);
ret.sd += val.fs;
}
return ret;
};
function <void(int,int,int)> trace = [&](int now, int par, int chk){ // 역추적
if (!chk && dp[now].fs < dp[now].sd) path.eb(now), chk = 1;
else chk = 0;
for (int next : adj[now]) {
if (next == par) continue;
trace(next, now, chk);
}
};
pii ans = DP(1, 0);
cout << max(ans.fs, ans.sd) << '\n';
trace(1, 0, 0);
sort(all(path));
for (int i : path) cout << i << ' ';
return 0;
}