TopCoder SRM 711 Div2 Hard TreeMovingDiv2

問題概要

与えられた引数にしたがって, 頂点数が n の m 個の木を構築します.
それぞれの木を T(i) とします.
各 i = 0, 1, ..., m-1 について,辺 e(i) ∈ T(i) を一つ選びます.
e(i) を T(i) から取り除き,T(i+1) に加えます.T(m-1) のものは T(0) に加えます.
この操作を行った後,すべての T(i) がまだ木であるような操作の仕方は何通りありますか.
1e9+7 で割った余りを求めなさい.

・ 制約
2 <= n <= 50
2 <= m <= 50

解法

m, n の制約が緩いので,愚直な DP で通ります.
dp[i][j][k] := i 番目の木から j 番目の辺を選び,かつ 0 番目の木からは k 番目の辺を選んでいた場合の数
とします.
あとは,T(i) に対して,T(i-1) の j 番目の辺を選んだときに,T(i) で k 番目の木を選んだ場合,T(i) が木になっているかを Union-find 木で確認します.
木になっていれば,各 l に対して dp[i][k][l] += dp[i-1][j][l] とすればよいです.

最後に,0 番目の木に戻ってきたときは処理が特別になります.
dp[0][i][i] を一旦 0 に戻します.
同じように j, k を選んだ時に,木になっていれば,今度は各 l に対してではなく,k のみ dp[i][k][k] += dp[i-1][j][k]; とします.
こうしないと,最初に k 番目の辺を選んでいたのに,違う辺を選んでいた場合の数が含まれてしまうためです.

あとは,dp[0][i][i] の総和を取れば,答えになります.

計算量は O(mn^3 * (union-findの分)) です.

ソースコード

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

//
// union_find の実装は本質でないので省略
//

using ll = long long;

using edge = pair<int, int>;

constexpr ll mod = 1e9+7;

class TreeMovingDiv2 {
public:
    int count(int n, vector<int> roots, vector<int> a, vector<int> b, vector<int> c) {
        const int m = roots.size();
        vector<vector<edge>> es(m, vector<edge>(n-1));
        for(int i=0; i<m; ++i) {
            vector<int> x(n-1);
            x[0] = c[i];
            for(int k=1; k<n-1; ++k) {
                x[k] = ((ll)a[i] * x[k-1] + b[i]) % mod;
            }
            for(int j=0; j<n-1; ++j) {
                es[i][j].first = ((ll)roots[i] + j + 1) % n;
                es[i][j].second = ((ll)roots[i] + (x[j] % (j+1))) % n;
            }
        }
        vector<vector<vector<ll>>> dp(m+1, vector<vector<ll>>(n-1, vector<ll>(n-1)));
        for(int i=0; i<n-1; ++i) {
            dp[0][i][i] = 1;
        }
        for(int i=0; i<m; ++i) {
            int next = (i + 1) % m;
            if(next == 0) {
                for(int j=0; j<n-1; ++j) {
                    dp[0][j][j] = 0;
                }
            }
            for(int j=0; j<n-1; ++j) {
                int prev_v = es[i][j].first, prev_u = es[i][j].second;
                for(int k=0; k<n-1; ++k) {
                    union_find uf(n);
                    uf.unite(prev_v, prev_u);
                    for(int l=0; l<n-1; ++l) {
                        if(k == l) {
                            continue;
                        }
                        uf.unite(es[next][l].first, es[next][l].second);
                    }
                    if(uf.size(prev_v) == n) {
                        if(next == 0) {
                            (dp[next][k][k] += dp[i][j][k]) %= mod;
                        } else {
                            for(int l=0; l<n-1; ++l) {
                                (dp[next][k][l] += dp[i][j][l]) %= mod;
                            }
                        }
                    }
                }
            }
        }
        ll res = 0;
        for(int i=0; i<n-1; ++i) {
            (res += dp[0][i][i]) %= mod;
        }
        return res;
    }
};

感想

本番で通せそうで時間が足りなかった問題.全完したかった.

しかしこれ,入力に悪意があって,木の構築段階で int だとオーバーフローするような入力があるくせに,関数定義は vector(int) を強制されてます.嫌がらせかな???

Div 1 の 1 <= n <= 300 だと O(n^3) ってことなんだろうけど,方法が全然わからない.