AOJ 2556 Integer in Integer
解法
やり方は色々あると思うんですが,僕は桁DPっぽくやりました(正直実装方針かなりミスっていると思う.)
まず,[0..N] までの答え f(N) を求めることができれば,f(B) - f(A - 1) で求められるので,f(N) を考えましょう.
dp[i][j][k] := i 桁目まで見たときに,C と先頭 j 桁が一致していて,その桁まで見たときに N 以下であるかが k であるような,leading zero を除いた場合の数
dp0[i][j][k] := i 桁目まで見たときに,C と先頭 j 桁が一致していて,その桁まで見たときに N 以下であるかが k であるような,leading zero である場合の数
dp2[i][j][k] := i 桁目まで見たときに,N 以下であるかが j であり,leading zero であるかが k であるようなもので,含まれる C の数
まず dp を計算する前に,C と先頭 j 桁が一致しているときに d (1桁の数)を付け加えたら何桁一致することになるかを match[j][d] として求めておきます(これは全探索で求められる).
match[j][d] をつかって各 dp テーブルを更新していくのですが,流れとしては
という感じになります.コードはごちゃごちゃしてるけど遷移は自然.
dp と dp0 は,あくまで i 桁目までみて先頭が C と一致しているような数を求めているだけにすぎないです(i + 1 桁目以降は存在しないとして数えている).
(例.C = 11 なら,3桁目までみて 11x となるのは 10 通り,2桁目まででは 11 の 1 通り.x の値によって,dp に入るか dp2 に入るかが変わってくる)
答えは dp2[i][0][0] + (i != |N| ? dp2[i][1][0] : 0) の総和です.
C = 0 のときはちょっと注意が必要です.
書いてて自分でもよくわからなくなってきた(???)
オーダーは O(2 * 10 * |C| * |A|) の 10^8 ぐらいで実際はもうちょい遅いかな~ぐらいです.
ソースコード
#include <bits/stdc++.h> using namespace std; using ll = long long; constexpr int M = 1e9 + 7; int solve(string A, string C) { int const n = A.size(); int const m = C.size(); reverse(A.begin(), A.end()); reverse(C.begin(), C.end()); vector<vector<int>> match(m + 1, vector<int>(10)); for(int d = 0; d <= 9; ++d) { char ch = '0' + d; for(int i = 0; i <= m; ++i) { int len = 0; for(int j = min(i + 1, m); j >= 1; --j) { if(C.substr(0, j) == C.substr(i - j + 1, j - 1) + ch) { len = j; break; } } match[i][d] = len; } } vector<vector<int>> dp(m + 1, vector<int>(2)); vector<vector<int>> dp0(m + 1, vector<int>(2)); // leading zero vector<vector<vector<int>>> dp2(n + 1, vector<vector<int>>(2, vector<int>(2))); dp[0][0] = 1; int res = 0; for(int i = 0; i < n; ++i) { vector<vector<int>> nxt(m + 1, vector<int>(2)), nxt0(m + 1, vector<int>(2)); for(int k = 0; k < 2; ++k) { for(int d = 0; d <= 9; ++d) { char ch = '0' + d; bool nk = k && ch == A[i] || ch > A[i]; if(d == 0) { (dp2[i + 1][nk][1] += dp2[i][k][0]) %= M; (dp2[i + 1][nk][1] += dp2[i][k][1]) %= M; } else { (dp2[i + 1][nk][0] += dp2[i][k][0]) %= M; (dp2[i + 1][nk][0] += dp2[i][k][1]) %= M; } } } for(int j = 0; j <= m; ++j) { for(int k = 0; k < 2; ++k) { for(int d = 0; d <= 9; ++d) { char ch = '0' + d; bool nk = k && ch == A[i] || ch > A[i]; int nj = match[j][d]; if(d == 0) { (nxt0[nj][nk] += dp[j][k]) %= M; (nxt0[nj][nk] += dp0[j][k]) %= M; } else { (nxt[nj][nk] += dp[j][k]) %= M; (nxt[nj][nk] += dp0[j][k]) %= M; } } } } (dp2[i + 1][0][0] += nxt[m][0]) %= M; (dp2[i + 1][0][1] += nxt0[m][0]) %= M; (dp2[i + 1][1][0] += nxt[m][1]) %= M; (dp2[i + 1][1][1] += nxt0[m][1]) %= M; dp = move(nxt); dp0 = move(nxt0); } for(int i = 1; i <= n; ++i) { (res += dp2[i][0][0]) %= M; if(i != n) { (res += dp2[i][1][0]) %= M; } } if(C == "0") { (res += 1) % M; } return res; } string decrement(string A) { for(int i = A.size() - 1; i >= 0; --i) { if(A[i] > '0') { A[i]--; break; } else { A[i] = '9'; } } if(A[0] == '0' && A.size() > 1) { return A.substr(1); } else { return A; } } int main() { string A, B, C; cin >> A >> B >> C; if(A != "0") { cout << (solve(B, C) - solve(decrement(A), C) + M) % M << endl; } else { cout << solve(B, C) << endl; } }
感想
通ったからいいけど,桁DPっぽくやったのは失敗だったかなあと思っています.
今真っ白な状態から書くなら,j 桁目が C であるような N 以下の数っていうのを求めていく感じになりそう.[X] ... C ... [Y] の X と Y を N を越えないように数えるみたいな.