AOJ 2605 Social Monsters
解法
制約をよく読むと,(使えない辺を含める)与えられたグラフにおける各連結成分は,1つのパスまたは閉路になっていることがわかります.
各連結成分において計算して,最後にまとめあげることにします.
1つのパスについては,以下のDPで端から順番に計算します.
dp[i][j][k] := i 番目の頂点までみたとき,それまでに j 個の頂点を用いていて,直前の頂点を使ったかのフラグが k であるときの最大値
隣接してる2つが使えないなら,k == 0 となる場合しか今見ている頂点は使えない,みたいな遷移になります.
閉路については,追加で情報が必要です.閉路の始点は適当に決めて良いです.
dp[i][j][k][l] := i 番目の頂点まで見た時,それまでに j 個の頂点を用いていて,直前の頂点を使ったかのフラグが k であり,最初の頂点を使ったかのフラグが l であるときの最大値
最後に各連結成分でもとめた dp の値を合わせて全体の結果を得ます.
ソースコード
#include <bits/stdc++.h> using namespace std; constexpr int INF = 1e9; struct edge { int to, cost; }; using edges = vector<edge>; using graph = vector<edges>; int N, M, K; int cost[2001][2001]; // true -> circuit bool dfs(graph const& g, int v, int p, vector<bool>& visited, vector<int>& res) { res.push_back(v); visited[v] = true; bool circuit = false; for(auto& e : g[v]) { if(e.to == p) { continue; } if(visited[e.to]) { circuit = true; } else { circuit |= dfs(g, e.to, v, visited, res); } } return circuit; } vector<int> calc_line(vector<int> const& line) { vector<vector<int>> dp(K + 1, vector<int>(2, -INF)); dp[0][0] = dp[1][1] = 0; for(int i = 1; i < (int)line.size(); ++i) { int pre = line[i - 1]; int v = line[i]; vector<vector<int>> nxt(K + 1, vector<int>(2, -INF)); for(int k = 0; k <= K; ++k) { if(dp[k][0] != -INF) { if(k + 1 <= K) { nxt[k + 1][1] = max(nxt[k + 1][1], dp[k][0]); } nxt[k][0] = max(nxt[k][0], dp[k][0]); } if(dp[k][1] != -INF) { if(cost[pre][v] != 0 && k + 1 <= K) { nxt[k + 1][1] = max(nxt[k + 1][1], dp[k][1] + cost[pre][v]); } nxt[k][0] = max(nxt[k][0], dp[k][1]); } } dp.swap(nxt); } vector<int> res(K + 1); for(int i = 0; i < K + 1; ++i) { res[i] = max(dp[i][0], dp[i][1]); } return res; } vector<int> calc_cycle(vector<int> const& cycle) { // (use, pre, first) vector<vector<vector<int>>> dp(K + 1, vector<vector<int>>(2, vector<int>(2, -INF))); dp[0][0][0] = dp[1][1][1] = 0; int const first = cycle[0]; for(int i = 1; i < (int)cycle.size(); ++i) { int pre = cycle[i - 1]; int v = cycle[i]; vector<vector<vector<int>>> nxt(K + 1, vector<vector<int>>(2, vector<int>(2, -INF))); for(int k = 0; k <= K; ++k) { for(int fuse = 0; fuse <= 1; ++fuse) { if(i == cycle.size() - 1 && fuse == 1) { if(dp[k][0][fuse] != -INF) { if(k + 1 <= K && cost[first][v] != 0) { nxt[k + 1][1][fuse] = max(nxt[k + 1][1][fuse], dp[k][0][fuse] + cost[first][v]); } nxt[k][0][fuse] = max(nxt[k][0][fuse], dp[k][0][fuse]); } if(dp[k][1][fuse] != -INF) { if(cost[first][v] != 0 && cost[pre][v] != 0 && k + 1 <= K) { nxt[k + 1][1][fuse] = max(nxt[k + 1][1][fuse], dp[k][1][fuse] + cost[pre][v] + cost[first][v]); } nxt[k][0][fuse] = max(nxt[k][0][fuse], dp[k][1][fuse]); } } else { if(dp[k][0][fuse] != -INF) { if(k + 1 <= K) { nxt[k + 1][1][fuse] = max(nxt[k + 1][1][fuse], dp[k][0][fuse]); } nxt[k][0][fuse] = max(nxt[k][0][fuse], dp[k][0][fuse]); } if(dp[k][1][fuse] != -INF) { if(cost[pre][v] != 0 && k + 1 <= K) { nxt[k + 1][1][fuse] = max(nxt[k + 1][1][fuse], dp[k][1][fuse] + cost[pre][v]); } nxt[k][0][fuse] = max(nxt[k][0][fuse], dp[k][1][fuse]); } } } } dp.swap(nxt); } vector<int> res(K + 1); for(int i = 0; i < K + 1; ++i) { res[i] = max({dp[i][0][0], dp[i][0][1], dp[i][1][0], dp[i][1][1]}); } return res; } int main() { cin >> N >> M >> K; graph g(N); for(int i = 0; i < M; ++i) { int a, b, c; cin >> a >> b >> c; a--; b--; g[a].push_back(edge{b, c}); g[b].push_back(edge{a, c}); cost[a][b] = cost[b][a] = c; } vector<vector<int>> line, cycle; vector<bool> visited(N); vector<int> sz; for(int i = 0; i < N; ++i) { if(!visited[i]) { vector<int> c; bool circuit = dfs(g, i, -1, visited, c); if(!circuit) { for(auto v : c) { visited[v] = false; } for(auto v : c) { if(g[v].size() == 1) { c.clear(); dfs(g, v, -1, visited, c); break; } } sz.push_back(c.size()); line.push_back(move(c)); } else { cycle.push_back(move(c)); } } } for(auto& c : cycle) { sz.push_back(c.size()); } vector<vector<int>> dp2; for(auto& l : line) { dp2.push_back(calc_line(l)); } for(auto& c : cycle) { dp2.push_back(calc_cycle(c)); } vector<int> dp(K + 1, -INF); dp[0] = 0; int total_sz = 0; for(int i = 0; i < (int)dp2.size(); ++i) { vector<int> nxt(K + 1, -INF); for(int k1 = 0; k1 <= total_sz; ++k1) { if(dp[k1] == -INF) { continue; } for(int k2 = 0; k2 <= sz[i] && k1 + k2 <= K; ++k2) { if(dp2[i][k2] != -INF) { nxt[k1 + k2] = max(nxt[k1 + k2], dp[k1] + dp2[i][k2]); } } } total_sz += sz[i]; dp.swap(nxt); } if(dp[K] == -INF) { cout << "Impossible" << endl; } else { cout << dp[K] << endl; } }
感想
実装が重い.うまくやれば実装量は減りそうだが,半分ぐらいは同じことをしてるコードなので,コード長の割には楽だったかもしれない.