AOJ 2604 - Pattern Language
解法
変数は 10 個しかないので、それぞれの変数を何桁にするか最初に全探索する。
すると、変数の桁ごとに、どの桁と対応するかを決定できる。
対応する集合を union find 木かなにかで管理しておく。
次に面倒なのは2桁目の値を何にするかに応じて1桁目の上限が変わることだが、これも 2^10 通り全探索して間に合う。
すると、union find の各集合が取れる値の範囲が簡単に求まるので、あとは全部掛け合わせるだけ。
O(2^(2m) m) で解ける。
追記:冷静に考えたら 1桁 or 2桁でmaxじゃない or 2桁でmax の3通りで O(3^m m) でいいじゃん。何やってんだか。
ソースコード
#include <bits/stdc++.h> using namespace std; using ll = long long; using pii = pair<int, int>; constexpr int mod = 1e9 + 7; class union_find { public: union_find(int n) : par(n, -1) {} int root(int x) { return par[x] < 0 ? x : par[x] = root(par[x]); } void unite(int x, int y) { x = root(x), y = root(y); if(x == y) return; if(par[x] < par[y]) swap(x, y); par[x] += par[y]; par[y] = x; } private: vector<int> par; }; int main() { int n, m; cin >> n >> m; string s; cin >> s; vector<char> var(m); vector<int> u(m); for(int i = 0; i < m; ++i) { cin >> var[i] >> u[i]; } ll ans = 0; for(int S1 = 0; S1 < (1 << m); ++S1) { // determine 1 digit or 2 digits { bool check = true; for(int i = 0; i < m; ++i) { check &= !(S1 & (1 << i)) || (u[i] >= 10); } if(!check) continue; } vector<pii> vs; for(int i = 0; i < n; ++i) { if(isdigit(s[i])) { vs.emplace_back(-1, s[i] - '0'); } else { const int id = find(begin(var), end(var), s[i]) - begin(var); if(S1 & (1 << id)) { vs.emplace_back(id, 2); } vs.emplace_back(id, 1); } } // integrate auto ss = vs; sort(begin(ss), end(ss)); ss.erase(unique(begin(ss), end(ss)), end(ss)); const int vsz = vs.size(); const int sz = ss.size(); union_find uf(sz); for(int i = 0; i < vsz; ++i) { const int id1 = lower_bound(begin(ss), end(ss), vs[i]) - begin(ss); const int id2 = lower_bound(begin(ss), end(ss), vs[vsz - i - 1]) - begin(ss); uf.unite(id1, id2); } // determine 2nd digit for(int S2 = 0; S2 < (1 << m); ++S2) { { bool check = (S1 & S2) == S2; for(int i = 0; i < m; ++i) { check &= !(S2 & (1 << i)) || u[i] >= 10; } if(!check) continue; } vector<ll> lb(sz, 0), ub(sz, mod); for(int i = 0; i < sz; ++i) { const int id = uf.root(i); const int midx = ss[i].first; ub[id] = min(ub[id], 9LL); if(midx == -1) { lb[id] = max(lb[id], (ll)ss[i].second); ub[id] = min(ub[id], (ll)ss[i].second); } else if(ss[i].second == 2) { lb[id] = max(lb[id], 1LL); if(S2 & (1 << midx)) { lb[id] = max(lb[id], (ll)(u[midx] / 10)); ub[id] = min(ub[id], (ll)(u[midx] / 10)); } else { ub[id] = min(ub[id], (ll)(u[midx] / 10 - 1)); } } else { if(S2 & (1 << midx) || u[midx] < 10) { ub[id] = min(ub[id], (ll)(u[midx] % 10)); } } } ll tans = 1; for(int i = 0; i < sz; ++i) { if(ub[i] == mod) continue; (tans *= max(0LL, ub[i] - lb[i] + 1)) %= mod; } (ans += tans) %= mod; } } cout << ans << endl; }
感想
i やら id やら midx やらわかりにくい実装になっちゃった。反省。