Codeforces Round #433 (Div. 2) D. Jury Meeting

問題概要

街が n + 1 あり,それぞれナンバリングされている.また人が n 人いて,人 i は街 i に存在する.0番目の街は首都であり,人はいない.
ここで,飛行機のスケジュールが m 個与えられる.スケジュールは以下のようになっている.
d[i] 日に出発してその日のうちに到着し,街 f[i] から 街 t[i] に行くが,これにはコストが c[i] かかる.また,f[i] == 0 または t[i] == 0 が必ず成り立つ.(どちらかは必ず首都である)
この時,n 人すべてが首都に同時に k 日間いるような期間が存在し,かつ自分の街まで帰るような飛行機の利用の仕方で,コストを最小化せよ.

・制約
1 <= n <= 10^5
0 <= m <= 10^5
1 <= k <= 10^6
1 <= d[i] <= 10^6
0 <= f[i], t[i] <= 10^6
1 <= c[i] <= 10^6

解法

最初に d[i] でスケジュールをソートしておきます.

まず,コストを無視して全員が最短で首都につくタイミングを lb,全員が自分の街まで帰ることができるギリギリを ub とします.(後者は,スケジュールを逆順に見れば良いです.)

lb と ub の日の差が k + 1 より小さければ,同時に k 日間滞在できません.
そうでない場合は,幅が k + 1 より小さくならないように,lb と ub をいい感じに伸ばしてやって,その中で一番良いものを選ぶことになります.

これはちょうどしゃくとりっぽい発想なので,それで実装します.
go[i] := i 番目の人が行きで使えるコストの集合 (multiset)
ret[i] := i 番目の人が帰りで使えるコストの集合 (multiset)
とし,伸ばすときは insert して,縮めるときは erase します.
値の更新は begin() の値を足し引きするだけなので,難しくはないです.

計算量は O(mlogm + n) .

ソースコード

#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;

constexpr int INF = 1e9;
constexpr long double eps = 1e-8;

struct flight {
    int d, f, t, c;

    bool operator<(flight const& f) const {
        return d < f.d;
    }
};

int main() {
    int n, m, k;
    cin >> n >> m >> k;
    vector<flight> fs(m);
    for(int i = 0; i < m; ++i) {
        scanf("%d %d %d %d", &fs[i].d, &fs[i].f, &fs[i].t, &fs[i].c);
    }
    sort(begin(fs), end(fs));

    vector<multiset<ll>> go(n + 1), ret(n + 1);
    set<int> gs, rs;
    int lb = -1, ub = -1;
    for(int i = 0; i < m; ++i) {
        if(fs[i].t == 0) {
            go[fs[i].f].insert(fs[i].c);
            gs.insert(fs[i].f);
        }
        if(gs.size() == n) {
            lb = i;
            break;
        }
    }
    for(int i = m - 1; i >= 0; --i) {
        if(fs[i].f == 0) {
            ret[fs[i].t].insert(fs[i].c);
            rs.insert(fs[i].t);
        }
        if(rs.size() == n) {
            ub = i;
            break;
        }
    }

    if(lb == -1 || ub == -1 || fs[ub].d - fs[lb].d < k + 1) {
        cout << -1 << endl;
        return 0;
    }

    int l = lb, r = ub;
    while(fs[ub].d - fs[l + 1].d >= k + 1) {
        if(fs[l + 1].t == 0) {
            go[fs[l + 1].f].insert(fs[l + 1].c);
        }
        l++;
    }

    ll t = 0;
    for(int i = 1; i <= n; ++i) {
        t += *go[i].begin() + *ret[i].begin();
    }

    ll res = t;
    while(fs[r - 1].d - fs[lb].d >= k + 1) {
        r--;
        while(fs[r].d - fs[l].d < k + 1) {
            if(fs[l].t == 0) {
                t -= *go[fs[l].f].begin();
                go[fs[l].f].erase(fs[l].c);
                t += *go[fs[l].f].begin();
            }
            l--;
        }
        if(fs[r].f == 0) {
            t -= *ret[fs[r].t].begin();
            ret[fs[r].t].insert(fs[r].c);
            t += *ret[fs[r].t].begin();
        }
        res = min(res, t);
    }

    cout << res << endl;
}