Codeforces Round #429 (Div. 2) E. On the Bench

問題概要

数列 {a_n} が与えられる.
任意の 1 <= i < n について a[i] * a[i + 1] が平方数にならないような,数列{a_n} の並び替え方は何通りあるか.
並び替えた結果が同じものでも,並び替え方が違えば別として数えるものとする.

・制約
1 <= n <= 300
1 <= a[i] <= 10^9

解法

まず,掛け合わせることで平方数になってしまうという関係は,同値関係になります.実際,a * b が平方数になる,を aRb とかくと,
aRa
bRa <-> aRb
aRb, bRc -> aRc
が成り立ちます.3番目については,(a * b) * (b * c) == (a * c) * b^2 で,左辺が平方数だから (a * c) も平方数となります.

よって,平方数ができないような並び替え方というのは,まずこれらの同値類をどこに配置するかを考え,具体的に同値類の中でどう並べるか(同じ数でも区別するので,これは単なる順列の問題)になります.

同値類の配置の仕方については,TDPCの以下の問題と全く同じなので省略します.
O: 文字列 - Typical DP Contest | AtCoder

ソースコード

#include <bits/stdc++.h>
using namespace std;

using ll = long long;

constexpr ll M = 1e9 + 7;

bool is_square(ll x) {
    ll t = sqrt(x);
    return t * t == x;
}
 
int main() {
    int n;
    cin >> n;
    map<ll, ll> cnt;
    for(int i = 0; i < n; ++i) {
        ll a;
        cin >> a;
        cnt[a]++;
    }
    // 同値類にまとめる.
    map<ll, vector<int>> num;
    for(auto& a : cnt) {
        bool update = false;
        for(auto& p : num) {
            if(is_square(a.first * p.first)) {
                num[p.first].push_back(a.second);
                update = true;
                break;
            }
        }
        if(!update) {
            num[a.first].push_back(a.second);
        }
    }

    vector<ll> fact(301);
    fact[0] = 1;
    for(int i = 1; i <= 300; ++i) {
        fact[i] = (fact[i - 1] * i) % M;
    }
    vector<vector<ll>> comb(301, vector<ll>(301));
    for(int i = 0; i < 301; ++i) {
        for(int j = 0; j <= i; ++j) {
            if(i == j || j == 0) {
                comb[i][j] = 1;
            } else {
                comb[i][j] = (comb[i - 1][j - 1] + comb[i - 1][j]) % M;
            }
        }
    }

    vector<vector<ll>> dp(num.size() + 1, vector<ll>(n + 1));
    dp[0][0] = 1;
    ll res = 1;
    ll sum = 0;
    int i = 0;
    for(auto& p : num) {
        int m = accumulate(p.second.begin(), p.second.end(), 0);
        // 同値類の中でどう配置するか
        res = (res * fact[m]) % M;
        // 同値類の中身を無視した配置の仕方
        for(int j = 0; j <= sum; ++j) {
            for(int k = 1; k <= m; ++k) {
                for(int l = 0; l <= min(j, k); ++l) {
                    ll t = comb[m - 1][k - 1];
                    t = (t * comb[j][l]) % M;
                    t = (t * comb[sum - j + 1][k - l]) % M;
                    t = (t * dp[i][j]) % M;
                    (dp[i + 1][j - l + m - k] += t) %= M;
                }
            }
        }

        sum += m;
        i++;
    }
    res = (res * dp[num.size()][0]) % M;

    cout << res << endl;
}

感想

本番で通せなかった問題.TDPCをやっておけばよかったと後悔しました.