AOJ 2556 Integer in Integer

解法

やり方は色々あると思うんですが,僕は桁DPっぽくやりました(正直実装方針かなりミスっていると思う.)

まず,[0..N] までの答え f(N) を求めることができれば,f(B) - f(A - 1) で求められるので,f(N) を考えましょう.

dp[i][j][k] := i 桁目まで見たときに,C と先頭 j 桁が一致していて,その桁まで見たときに N 以下であるかが k であるような,leading zero を除いた場合の数
dp0[i][j][k] := i 桁目まで見たときに,C と先頭 j 桁が一致していて,その桁まで見たときに N 以下であるかが k であるような,leading zero である場合の数

dp2[i][j][k] := i 桁目まで見たときに,N 以下であるかが j であり,leading zero であるかが k であるようなもので,含まれる C の数

まず dp を計算する前に,C と先頭 j 桁が一致しているときに d (1桁の数)を付け加えたら何桁一致することになるかを match[j][d] として求めておきます(これは全探索で求められる).

match[j][d] をつかって各 dp テーブルを更新していくのですが,流れとしては

  • dp2 で,1桁付け加えていく
  • dp と dp0 を更新する
  • i 桁目で dp と dp0 で C と完全一致するパターンを dp2 に付け加える

という感じになります.コードはごちゃごちゃしてるけど遷移は自然.
dp と dp0 は,あくまで i 桁目までみて先頭が C と一致しているような数を求めているだけにすぎないです(i + 1 桁目以降は存在しないとして数えている).
(例.C = 11 なら,3桁目までみて 11x となるのは 10 通り,2桁目まででは 11 の 1 通り.x の値によって,dp に入るか dp2 に入るかが変わってくる)

答えは dp2[i][0][0] + (i != |N| ? dp2[i][1][0] : 0) の総和です.
C = 0 のときはちょっと注意が必要です.

書いてて自分でもよくわからなくなってきた(???)

オーダーは O(2 * 10 * |C| * |A|) の 10^8 ぐらいで実際はもうちょい遅いかな~ぐらいです.

ソースコード

#include <bits/stdc++.h>
using namespace std;

using ll = long long;

constexpr int M = 1e9 + 7;

int solve(string A, string C) {
    int const n = A.size();
    int const m = C.size();
    reverse(A.begin(), A.end());
    reverse(C.begin(), C.end());
    vector<vector<int>> match(m + 1, vector<int>(10));
    for(int d = 0; d <= 9; ++d) {
        char ch = '0' + d;
        for(int i = 0; i <= m; ++i) {
            int len = 0;
            for(int j = min(i + 1, m); j >= 1; --j) {
                if(C.substr(0, j) == C.substr(i - j + 1, j - 1) + ch) {
                    len = j;
                    break;
                }
            }
            match[i][d] = len;
        }
    }
    vector<vector<int>> dp(m + 1, vector<int>(2));
    vector<vector<int>> dp0(m + 1, vector<int>(2)); // leading zero
    vector<vector<vector<int>>> dp2(n + 1, vector<vector<int>>(2, vector<int>(2)));
    dp[0][0] = 1;
    int res = 0;
    for(int i = 0; i < n; ++i) {
        vector<vector<int>> nxt(m + 1, vector<int>(2)), nxt0(m + 1, vector<int>(2));
        for(int k = 0; k < 2; ++k) {
            for(int d = 0; d <= 9; ++d) {
                char ch = '0' + d;
                bool nk = k && ch == A[i] || ch > A[i];
                if(d == 0) {
                    (dp2[i + 1][nk][1] += dp2[i][k][0]) %= M;
                    (dp2[i + 1][nk][1] += dp2[i][k][1]) %= M;
                } else {
                    (dp2[i + 1][nk][0] += dp2[i][k][0]) %= M;
                    (dp2[i + 1][nk][0] += dp2[i][k][1]) %= M;
                }
            }
        }
        for(int j = 0; j <= m; ++j) {
            for(int k = 0; k < 2; ++k) {
                for(int d = 0; d <= 9; ++d) {
                    char ch = '0' + d;
                    bool nk = k && ch == A[i] || ch > A[i];
                    int nj = match[j][d];
                    if(d == 0) {
                        (nxt0[nj][nk] += dp[j][k]) %= M;
                        (nxt0[nj][nk] += dp0[j][k]) %= M;
                    } else {
                        (nxt[nj][nk] += dp[j][k]) %= M;
                        (nxt[nj][nk] += dp0[j][k]) %= M;
                    }
                }
            }
        }
        (dp2[i + 1][0][0] += nxt[m][0]) %= M;
        (dp2[i + 1][0][1] += nxt0[m][0]) %= M;
        (dp2[i + 1][1][0] += nxt[m][1]) %= M;
        (dp2[i + 1][1][1] += nxt0[m][1]) %= M;
        dp = move(nxt);
        dp0 = move(nxt0);
    }

    for(int i = 1; i <= n; ++i) {
        (res += dp2[i][0][0]) %= M;
        if(i != n) {
            (res += dp2[i][1][0]) %= M;
        }
    }
    if(C == "0") {
        (res += 1) % M;
    }
    return res;
}

string decrement(string A) {
    for(int i = A.size() - 1; i >= 0; --i) {
        if(A[i] > '0') {
            A[i]--;
            break;
        } else {
            A[i] = '9';
        }
    }
    if(A[0] == '0' && A.size() > 1) {
        return A.substr(1);
    } else {
        return A;
    }
}

int main() {
    string A, B, C;
    cin >> A >> B >> C;
    if(A != "0") {
        cout << (solve(B, C) - solve(decrement(A), C) + M) % M << endl;
    } else {
        cout << solve(B, C) << endl;
    }
}

感想

通ったからいいけど,桁DPっぽくやったのは失敗だったかなあと思っています.
今真っ白な状態から書くなら,j 桁目が C であるような N 以下の数っていうのを求めていく感じになりそう.[X] ... C ... [Y] の X と Y を N を越えないように数えるみたいな.