AOJ 2255 6/2(1+2)

解法

BNFどおりに実装する必要はなくて,区間DPをすればよい.
その区間で取りうる値の集合を持たせておく.
区間 [l..r] の演算子を全て見ていって [l..i-1] と [i+1..r] に分ける.
この時,演算子で分けて括弧の対応が崩れるなら一旦スルーしておく必要がある.
たとえば,括弧が ( .. ( .. op_i .. ) .. ) となっていると,op_i で左右に分けることができない.
これは最初に括弧の対応付を調べておけばよい.
また,( expr ) となっているときは, expr に置き換えて処理するとよい.

ソースコード

#include <bits/stdc++.h>
using namespace std;

set<int> memo[201][201];
vector<int> par_pos;

int number(string const& s, int p) {
    int res = 0;
    while(p < s.size() && isdigit(s[p])) {
        res *= 10;
        res += s[p++] - '0';
    }
    return res;
}

set<int> solve(string const& s, int l, int r) {
    auto& res = memo[l][r];
    if(res.size() != 0) {
        return res;
    }
    bool f = true;
    for(int i=l; i<=r; ++i) {
        f &= isdigit(s[i]);
    }
    if(f) {
        res.insert(number(s, l));
        return res;
    }
    if(par_pos[l] == r) {
        return res = solve(s, l+1, r-1);
    }
    vector<int> pare(s.size());
    for(int i=l; i<=r; ++i) {
        if(s[i] == '(') {
            pare[i] = 1;
        }
        if(s[i] == ')') {
            pare[i] = -1;
        }
        if(i != l) {
            pare[i] += pare[i-1];
        }
    }
    for(int i=l+1; i<r; ++i) {
        char c = s[i];
        if(pare[i] != 0) {
            continue;
        }
        if(c == '+' || c == '-' || c == '*' || c == '/') {
            set<int> s1 = solve(s, l, i-1);
            set<int> s2 = solve(s, i+1, r);
            for(auto x : s1) {
                for(auto y : s2) {
                    if(c == '+') {
                        res.insert(x + y);
                    } else if(c == '-') {
                        res.insert(x - y);
                    } else if(c == '*') {
                        res.insert(x * y);
                    } else {
                        if(y == 0) {
                            continue;
                        }
                        res.insert(x / y);
                    }
                }
            }
        }
    }
    return res;
}

int main() {
    string s;
    while(cin >> s, s != "#") {
        for(int i=0; i<201; ++i) {
            for(int j=0; j<201; ++j) {
                memo[i][j].clear();
            }
        }
        par_pos.assign(s.size(), -1);
        stack<int> p;
        for(int i=0; i<s.size(); ++i) {
            if(s[i] == '(') {
                p.push(i);
            }
            if(s[i] == ')') {
                par_pos[p.top()] = i;
                p.pop();
            }
        }
        cout << solve(s, 0, s.size()-1).size() << endl;
    }
}