TopCoder SRM 711 Div2 Hard TreeMovingDiv2
問題概要
与えられた引数にしたがって, 頂点数が n の m 個の木を構築します.
それぞれの木を T(i) とします.
各 i = 0, 1, ..., m-1 について,辺 e(i) ∈ T(i) を一つ選びます.
e(i) を T(i) から取り除き,T(i+1) に加えます.T(m-1) のものは T(0) に加えます.
この操作を行った後,すべての T(i) がまだ木であるような操作の仕方は何通りありますか.
1e9+7 で割った余りを求めなさい.
・ 制約
2 <= n <= 50
2 <= m <= 50
解法
m, n の制約が緩いので,愚直な DP で通ります.
dp[i][j][k] := i 番目の木から j 番目の辺を選び,かつ 0 番目の木からは k 番目の辺を選んでいた場合の数
とします.
あとは,T(i) に対して,T(i-1) の j 番目の辺を選んだときに,T(i) で k 番目の木を選んだ場合,T(i) が木になっているかを Union-find 木で確認します.
木になっていれば,各 l に対して dp[i][k][l] += dp[i-1][j][l] とすればよいです.
最後に,0 番目の木に戻ってきたときは処理が特別になります.
dp[0][i][i] を一旦 0 に戻します.
同じように j, k を選んだ時に,木になっていれば,今度は各 l に対してではなく,k のみ dp[i][k][k] += dp[i-1][j][k]; とします.
こうしないと,最初に k 番目の辺を選んでいたのに,違う辺を選んでいた場合の数が含まれてしまうためです.
あとは,dp[0][i][i] の総和を取れば,答えになります.
計算量は O(mn^3 * (union-findの分)) です.
ソースコード
#include <iostream> #include <vector> #include <cmath> using namespace std; // // union_find の実装は本質でないので省略 // using ll = long long; using edge = pair<int, int>; constexpr ll mod = 1e9+7; class TreeMovingDiv2 { public: int count(int n, vector<int> roots, vector<int> a, vector<int> b, vector<int> c) { const int m = roots.size(); vector<vector<edge>> es(m, vector<edge>(n-1)); for(int i=0; i<m; ++i) { vector<int> x(n-1); x[0] = c[i]; for(int k=1; k<n-1; ++k) { x[k] = ((ll)a[i] * x[k-1] + b[i]) % mod; } for(int j=0; j<n-1; ++j) { es[i][j].first = ((ll)roots[i] + j + 1) % n; es[i][j].second = ((ll)roots[i] + (x[j] % (j+1))) % n; } } vector<vector<vector<ll>>> dp(m+1, vector<vector<ll>>(n-1, vector<ll>(n-1))); for(int i=0; i<n-1; ++i) { dp[0][i][i] = 1; } for(int i=0; i<m; ++i) { int next = (i + 1) % m; if(next == 0) { for(int j=0; j<n-1; ++j) { dp[0][j][j] = 0; } } for(int j=0; j<n-1; ++j) { int prev_v = es[i][j].first, prev_u = es[i][j].second; for(int k=0; k<n-1; ++k) { union_find uf(n); uf.unite(prev_v, prev_u); for(int l=0; l<n-1; ++l) { if(k == l) { continue; } uf.unite(es[next][l].first, es[next][l].second); } if(uf.size(prev_v) == n) { if(next == 0) { (dp[next][k][k] += dp[i][j][k]) %= mod; } else { for(int l=0; l<n-1; ++l) { (dp[next][k][l] += dp[i][j][l]) %= mod; } } } } } } ll res = 0; for(int i=0; i<n-1; ++i) { (res += dp[0][i][i]) %= mod; } return res; } };
感想
本番で通せそうで時間が足りなかった問題.全完したかった.
しかしこれ,入力に悪意があって,木の構築段階で int だとオーバーフローするような入力があるくせに,関数定義は vector(int) を強制されてます.嫌がらせかな???
Div 1 の 1 <= n <= 300 だと O(n^3) ってことなんだろうけど,方法が全然わからない.