近期做的一些Atcoder题

大四老年人完全没有智力,于是做一些Atcoder题来预防老年痴呆。

1. AtCoder Beginner Contest 269 Ex

题意:Link

做法:考虑每个点的生成函数:

G_{u}(x)=\sum_{i} g_{ui} x^i

其中,g_{ui}表示在u号点的子树中选出大小为i的good vertex set(即两两之间没有祖先关系的点集)的方案数。转移时要么只选u本身(此时u的子树内不能再选其他点,贡献x),要么不选u、各儿子的子树独立地选(方案数做卷积,即生成函数相乘),于是有:

G_{u}(x)=x+\prod_{v\in son[u]}{G_{v}(x)}

需要计算的就是G_{root}.
直接暴力卷积,复杂度不太行,考虑启发式合并,或者叫HLD(轻重链剖分)。
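轻重链剖分相当于是每条重链单独处理,轻儿子所在的重链递归算完后再“跳”回当前重链。设当前处理的重链为\{v_{i_1},v_{i_2},\cdots,v_{i_t}\}(按深度从浅至深,v_{i_t}为叶子),并记H_{v}(x)=\prod_{w\in lightson[v]}G_{w}(x)为v所有轻儿子生成函数之积(没有轻儿子时为1)。那么该重链最浅的点v_{top}=v_{i_1}对其父亲贡献的生成函数为:

\begin{aligned}
G_{v_{top}}(x)=&(((H_{v_{i_t}}(x)+x)\\
&\times H_{v_{i_{t-1}}}(x)+x)\\
&\times\cdots)\times H_{v_{i_1}}(x) +x
\end{aligned}

这样暴力算,复杂度还是不太对,此时考虑分治。考虑仿射函数\mathcal{F}_{i_k}(\Box)=H_{v_{i_k}}(x)\times\Box +x,那么贡献可以改写为复合:

G_{v_{top}}(x)=\mathcal{F}_{i_1}\circ\mathcal{F}_{i_2}\circ\cdots\circ\mathcal{F}_{i_t}(1)

仿射函数的复合仍是仿射函数:若\mathcal{F}_1(\Box)=A_1\Box+B_1,\mathcal{F}_2(\Box)=A_2\Box+B_2,则\mathcal{F}_1\circ\mathcal{F}_2(\Box)=A_1A_2\Box+(A_1B_2+B_1)。因此只需维护A\times\Box +B中的(A,B)两个多项式,对每条重链直接分治NTT即可。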
时间复杂度:\mathcal{O}(n\log^3n)
代码如下:

Code
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

template <class T>
T power(T a, int b) {
    // Binary exponentiation: returns a^b using O(log b) multiplications.
    // Callers pass b >= 0; the result for negative b is not meaningful.
    T res = 1;
    while (b) {
        if (b % 2) {
            res *= a;
        }
        b /= 2;
        // Square only while bits remain. Squaring once more after the last bit
        // is wasted work and can overflow (UB) when T is a signed integer type.
        if (b) {
            a *= a;
        }
    }
    return res;
}
/// Integer modulo the compile-time constant `mod`; the value x is kept in [0, mod).
/// += and -= assume 2 * mod - 1 fits in int (true for mod = 998244353).
template <int mod>
struct mint {
    int x;
    mint() : x(0) {}
    /// Reduces any 64-bit integer into [0, mod). Works on the remainder directly
    /// instead of negating y, which overflowed for y == INT64_MIN.
    mint(int64_t y) {
        int64_t r = y % mod;
        if (r < 0) r += mod;
        x = static_cast<int>(r);
    }
    mint &operator+=(const mint &p) {
        if ((x += p.x) >= mod) x -= mod;
        return *this;
    }
    mint &operator-=(const mint &p) {
        if ((x += mod - p.x) >= mod) x -= mod;
        return *this;
    }
    mint &operator*=(const mint &p) {
        x = (int)(1LL * x * p.x % mod);
        return *this;
    }
    mint &operator/=(const mint &p) {
        *this *= p.inverse();
        return *this;
    }
    mint operator-() const { return mint(-x); }
    mint operator+(const mint &p) const { return mint(*this) += p; }
    mint operator-(const mint &p) const { return mint(*this) -= p; }
    mint operator*(const mint &p) const { return mint(*this) *= p; }
    mint operator/(const mint &p) const { return mint(*this) /= p; }
    bool operator==(const mint &p) const { return x == p.x; }
    bool operator!=(const mint &p) const { return x != p.x; }
    /// Modular inverse via the extended Euclidean algorithm.
    /// Requires gcd(x, mod) == 1 (x == 0 yields 0).
    mint inverse() const {
        int a = x, b = mod, u = 1, v = 0, t;
        while (b > 0) {
            t = a / b;
            swap(a -= t * b, b);
            swap(u -= t * v, v);
        }
        return mint(u);
    }
    friend ostream &operator<<(ostream &os, const mint &p) { return os << p.x; }
    friend istream &operator>>(istream &is, mint &a) {
        int64_t t;
        is >> t;
        a = mint<mod>(t);
        return (is);
    }
    int get() const { return x; }
    static constexpr int get_mod() { return mod; }
};

/// Radix-2 number-theoretic transform over Z (a modular-int type exposing
/// get_mod()), using `rt` as a primitive root of Z::get_mod().
/// Transform lengths must be powers of two that divide Z::get_mod() - 1.
template <class Z, int rt>
struct NTT {
    // Bit-reversal permutation for the most recently used length.
    vector<int> rev;
    // Twiddle table: roots[k + j] = w^j, where w is a primitive (2k)-th root of
    // unity, for every power of two k below the largest length seen so far.
    vector<Z> roots{0, 1};

    /// In-place forward transform of a; a.size() must be a power of two.
    void dft(vector<Z> &a) {
        int n = a.size();
        // Length 0/1 transforms are the identity. Returning here also avoids
        // the shift by k = -1 (UB) that the bit-reversal setup did for n == 1.
        if (n <= 1) {
            return;
        }

        if (int(rev.size()) != n) {
            int k = __builtin_ctz(n) - 1;
            rev.resize(n);
            for (int i = 0; i < n; i++) {
                rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
            }
        }

        for (int i = 0; i < n; i++) {
            if (rev[i] < i) {
                swap(a[i], a[rev[i]]);
            }
        }
        // Extend the twiddle table on demand, one power of two at a time.
        if (int(roots.size()) < n) {
            int k = __builtin_ctz(roots.size());
            roots.resize(n);
            while ((1 << k) < n) {
                Z e = power(Z(rt), (Z::get_mod() - 1) >> (k + 1));
                for (int i = 1 << (k - 1); i < (1 << k); i++) {
                    roots[2 * i] = roots[i];
                    roots[2 * i + 1] = roots[i] * e;
                }
                k++;
            }
        }
        // Iterative Cooley-Tukey butterflies; k is the half-length of a block.
        for (int k = 1; k < n; k *= 2) {
            for (int i = 0; i < n; i += 2 * k) {
                for (int j = 0; j < k; j++) {
                    Z u = a[i + j];
                    Z v = a[i + j + k] * roots[k + j];
                    a[i + j] = u + v;
                    a[i + j + k] = u - v;
                }
            }
        }
    }

    /// In-place inverse transform (forward transform on the reversed tail,
    /// then scaling by n^{-1}).
    void idft(vector<Z> &a) {
        int n = a.size();
        reverse(a.begin() + 1, a.end());
        dft(a);
        // n divides mod - 1, so (1 - mod) / n is an exact integer q with
        // n * q = 1 - mod ≡ 1 (mod p): q is n^{-1}.
        Z inv = (1 - Z::get_mod()) / n;
        for (int i = 0; i < n; i++) {
            a[i] *= inv;
        }
    }

    /// Returns the full product of polynomials a and b
    /// (size a.size() + b.size() - 1; empty if either input is empty).
    vector<Z> multiply(vector<Z> a, vector<Z> b) {
        if (a.empty() || b.empty()) {
            return {};
        }
        int sz = 1, tot = a.size() + b.size() - 1;
        while (sz < tot) {
            sz *= 2;
        }

        a.resize(sz), b.resize(sz);
        dft(a), dft(b);

        for (int i = 0; i < sz; ++i) {
            a[i] = a[i] * b[i];
        }

        idft(a);
        a.resize(tot);
        return a;
    }
};

/// Dense polynomial with coefficients in Z; a[i] is the coefficient of x^i.
/// Products are computed with NTT<Z, rt>, so Z must be an NTT-friendly
/// modular type and rt a primitive root of its modulus.
template <class Z, int rt>
struct Poly {
    vector<Z> a;
    Poly() {}
    Poly(int sz, Z val) { a.assign(sz, val); }
    Poly(const vector<Z> &a) : a(a) {}
    // Takes ownership of a temporary coefficient vector without copying it.
    Poly(vector<Z> &&a) : a(std::move(a)) {}
    Poly(const initializer_list<Z> &a) : a(a) {}
    int size() const { return a.size(); }
    void resize(int n) { a.resize(n); }
    /// Read-only coefficient access; indices at or past size() read as 0.
    Z operator[](int idx) const {
        if (idx < size()) {
            return a[idx];
        } else {
            return 0;
        }
    }
    /// Mutable coefficient access; idx must be < size() (not bounds-checked).
    Z &operator[](int idx) { return a[idx]; }
    /// Returns this * x^k.
    Poly mulxk(int k) const {
        auto b = a;
        b.insert(b.begin(), k, 0);
        return Poly(std::move(b));
    }
    /// Returns this mod x^k, i.e. the first min(k, size()) coefficients.
    Poly modxk(int k) const {
        k = min(k, size());
        return Poly(vector<Z>(a.begin(), a.begin() + k));
    }
    /// Returns this / x^k with the low k coefficients dropped.
    Poly divxk(int k) const {
        if (size() <= k) {
            return Poly();
        }
        return Poly(vector<Z>(a.begin() + k, a.end()));
    }
    friend Poly operator+(const Poly &a, const Poly &b) {
        vector<Z> res(max(a.size(), b.size()));
        for (int i = 0; i < int(res.size()); i++) {
            res[i] = a[i] + b[i];
        }
        return Poly(std::move(res));
    }
    friend Poly operator-(const Poly &a, const Poly &b) {
        vector<Z> res(max(a.size(), b.size()));
        for (int i = 0; i < int(res.size()); i++) {
            res[i] = a[i] - b[i];
        }
        return Poly(std::move(res));
    }

    /// Full polynomial product (an empty operand yields an empty result).
    friend Poly operator*(Poly a, Poly b) {
        if (a.size() == 0 || b.size() == 0) {
            return Poly();
        }
        static NTT<Z, rt> ntt;
        // a and b are already by-value copies: move their storage into
        // multiply() instead of copying every coefficient a second time.
        return ntt.multiply(std::move(a.a), std::move(b.a));
    }
    friend Poly operator*(Z a, Poly b) {
        for (int i = 0; i < int(b.size()); i++) {
            b[i] *= a;
        }
        return b;
    }
    friend Poly operator*(Poly a, Z b) {
        for (int i = 0; i < int(a.size()); i++) {
            a[i] *= b;
        }
        return a;
    }
    Poly &operator+=(const Poly &b) { return (*this) = (*this) + b; }
    Poly &operator-=(const Poly &b) { return (*this) = (*this) - b; }
    Poly &operator*=(Poly b) { return (*this) = (*this) * std::move(b); }
};
int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);

    constexpr int mod = 998244353;
    using Z = mint<mod>;
    using poly = Poly<Z, 3>;
    int n;
    cin >> n;

    vector<vector<int>> t(n + 1);
    for (int i = 2; i <= n; i++) {
        int p;
        cin >> p;
        t[i].push_back(p);
        t[p].push_back(i);
    }

    vector<int> siz(n + 1), son(n + 1), f(n + 1);

    function<void(int, int)> dfs = [&](int u, int fa) {
        siz[u] = 1;
        f[u] = fa;
        for (auto v : t[u]) {
            if (v == fa) continue;
            dfs(v, u);
            siz[u] += siz[v];
            if (siz[v] > siz[son[u]]) son[u] = v;
        }
    };
    dfs(1, 0);
    vector<poly> dp(n + 1);
    function<void(int)> dfs2 = [&](int u) {
        vector<int> lt;
        int now = u;
        while (now) {
            lt.push_back(now);
            now = son[now];
        }
        for (auto it : lt) {
            vector<poly> res;
            for (auto v : t[it]) {
                if (v == f[it] || v == son[it]) continue;
                dfs2(v);
                res.emplace_back(std::move(dp[v]));
                res.rbegin()->resize(res.rbegin()->size() + 1);
            }
            function<poly(int, int)> dc = [&](int l, int r) -> poly {
                if (r - l <= 1) {
                    if (l == r) return {1, 0};
                    return res[l];
                }
                int mid = (l + r) / 2;
                return dc(l, mid) * dc(mid, r);
            };
            dp[it] = dc(0, res.size());
        }
        vector<poly> res;
        for (auto &&it : lt) res.emplace_back(std::move(dp[it]));

        function<pair<poly, poly>(int, int)> dc2 =
            [&](int l, int r) -> pair<poly, poly> {
            if (r - l == 1) {
                return {{0, 1}, res[l]};
            }
            int mid = (l + r) / 2;
            auto [al, bl] = dc2(l, mid);
            auto [ar, br] = dc2(mid, r);
            return {ar * bl + al, bl * br};
        };

        auto [x, y] = dc2(0, res.size());
        dp[u] = x + y;
        dp[u].resize(dp[u].size() + 1);
    };
    dfs2(1);
    for (int i = 1; i <= n; i++) cout << dp[1][i] << '\n';

    return 0;
}

2. AtCoder Beginner Contest 268 Ex

题意:Link
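做法:对所有T_j建出AC自动机。对S的每个位置i,求出以i结尾、能匹配上的最短的T_j,得到区间[i-|T_j|+1,i](这个区间内至少要替换一个字符;以i结尾的更长匹配对应的区间包含它,所以只需保留最短的)。“最短匹配”可以在fail树上预处理:对每个状态,取它的fail链上(含自身)离根最近的终止节点。于是问题转化为区间选点:选最少的点,使每个区间内至少有一个点。按右端点从小到大贪心,若当前区间内还没有选中的点,就选它的右端点。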
代码如下:

Code
#include <bits/stdc++.h>

using namespace std;
using i64 = long long;

// One trie / Aho-Corasick state: `id` is the node's own index, `dep` its depth
// (length of the string it spells), and nxt[c] the transition on letter 'a' + c.
// While the plain trie is being built, 0 doubles as "no child" (node 0 is the root).
struct TrieNode {
    TrieNode() : id(0), dep(0), nxt{} {}
    TrieNode(int node_id, int depth) : id(node_id), dep(depth) {}

    int id;
    int dep;
    array<int, 26> nxt = {};

    // Mutable access to the transition for letter index x (0..25).
    int& operator[](const int x) { return nxt[x]; }
};
/// Plain trie over lowercase strings; node 0 is the root.
/// Node must be constructible as Node() and Node(id, depth), expose `dep`,
/// and provide int& operator[](int) for its 26 child links (0 = no child).
template <class Node>
struct trie {
    vector<Node> tr;
    trie() { tr.push_back(Node()); }

    /// Inserts s (characters assumed to be in 'a'..'z') and returns the index
    /// of the node where s ends (0 for the empty string).
    int add(const string& s) {
        int p = 0;
        for (char ch : s) {
            int c = ch - 'a';
            if (!tr[p][c]) {
                // Compute the new node's data up front rather than passing
                // tr[p][c] (a reference into tr) to emplace_back, so nothing
                // depends on the library tolerating arguments that alias the
                // vector's own storage across a reallocation.
                int id = static_cast<int>(tr.size());
                int depth = tr[p].dep + 1;
                tr[p][c] = id;
                tr.emplace_back(id, depth);
            }
            p = tr[p][c];
        }
        return p;
    }

    int size() const { return static_cast<int>(tr.size()); }
};

/// Aho-Corasick automaton on top of trie<Node>.
/// After BuildAC(): fail[u] is u's failure link, and tr[u][c] is the complete
/// goto function (missing edges are redirected along failure links).
template <class Node>
struct ACAutomaton : public trie<Node> {
    vector<int> fail;
    // The base trie() constructor already creates the root (node 0). The old
    // constructor pushed a second, unreachable empty node here.
    ACAutomaton() = default;

    void BuildAC() {
        auto& tr = this->tr;
        fail.assign(tr.size(), 0);  // root and depth-1 nodes fail to the root
        queue<int> q;
        for (int c = 0; c < 26; c++) {
            if (tr[0][c]) q.push(tr[0][c]);
        }
        // BFS by depth, so fail[u] is final before u's children are handled.
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            for (int c = 0; c < 26; c++) {
                int v = tr[u][c];
                if (v) {
                    fail[v] = tr[fail[u]][c];
                    q.push(v);
                } else {
                    tr[u][c] = tr[fail[u]][c];
                }
            }
        }
    }
};
int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);

    string s;
    cin >> s;

    int m;
    cin >> m;
    ACAutomaton<TrieNode> ac;
    // pos[i] = automaton node at which pattern T_i ends.
    vector<int> pos(m + 1);
    for (int i = 1; i <= m; i++) {
        string pattern;
        cin >> pattern;
        pos[i] = ac.add(pattern);
    }
    ac.BuildAC();

    const int states = ac.size();
    vector<int> vis(states);  // vis[u] = 1 iff some pattern ends exactly at u
    for (int i = 1; i <= m; i++) vis[pos[i]] = 1;

    // Children lists of the failure-link tree (the root is its own fail).
    vector<vector<int>> adj(states);
    for (int i = 0; i < states; i++) {
        if (ac.fail[i] != i) adj[ac.fail[i]].push_back(i);
    }

    // to[u] = the pattern-end node closest to the root on u's fail chain
    // (including u), i.e. the shortest pattern that is a suffix of u's
    // string; -1 if there is none.
    vector<int> to(states, -1);
    queue<int> q;
    q.push(0);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        if (to[u] == -1 && vis[u]) to[u] = u;
        for (int v : adj[u]) {
            to[v] = to[u];
            q.push(v);
        }
    }

    // Walk S through the automaton. When a pattern ends at position i, the
    // shortest one gives the tightest interval [l, i] that must contain a
    // replaced character. Intervals arrive in increasing order of their right
    // end, so greedy interval stabbing needs no sort: if the interval is not
    // hit yet, replace its right end.
    int p = 0;
    int ans = 0;
    int last = -1;  // position of the most recently replaced character
    for (int i = 0; i < int(s.size()); i++) {
        p = ac.tr[p][s[i] - 'a'];
        if (to[p] == -1) continue;
        int l = i - ac.tr[to[p]].dep + 1;
        if (l > last) {
            last = i;
            ans++;
        }
    }

    cout << ans << '\n';

    return 0;
}

3. AtCoder Regular Contest 151 B

题意:Link
做法:可以一步一步地推式子,但是有个更加巧妙的做法。记B_i=A_{p_i}。把每个A_i换成m+1-A_i,所有元素之间的大小关系整体翻转,于是“A的字典序严格小于B”与“A的字典序严格大于B”的序列之间一一对应(方案数相同),只需要计算A=B(字典序相等)的情况即可。A=B等价于对所有i有A_i=A_{p_i},即A在排列p的每个环上取值相同。设p有c个环,则相等的方案数为m^c,答案为:

ans=\frac{m^n-m^c}{2}

Code
#include <bits/stdc++.h>

using namespace std;
using i64 = long long;

template <class T>
T power(T a, int b) {
    // Binary exponentiation: returns a^b using O(log b) multiplications.
    // Callers pass b >= 0; the result for negative b is not meaningful.
    T res = 1;
    while (b) {
        if (b % 2) {
            res *= a;
        }
        b /= 2;
        // Square only while bits remain. Squaring once more after the last bit
        // is wasted work and can overflow (UB) when T is a signed integer type.
        if (b) {
            a *= a;
        }
    }
    return res;
}

/// Integer modulo the compile-time constant `mod`; the value x is kept in [0, mod).
/// += and -= assume 2 * mod - 1 fits in int (true for mod = 998244353).
template <int mod>
struct mint {
    int x;
    mint() : x(0) {}
    /// Reduces any 64-bit integer into [0, mod). Works on the remainder directly
    /// instead of negating y, which overflowed for y == INT64_MIN.
    mint(int64_t y) {
        int64_t r = y % mod;
        if (r < 0) r += mod;
        x = static_cast<int>(r);
    }
    mint &operator+=(const mint &p) {
        if ((x += p.x) >= mod) x -= mod;
        return *this;
    }
    mint &operator-=(const mint &p) {
        if ((x += mod - p.x) >= mod) x -= mod;
        return *this;
    }
    mint &operator*=(const mint &p) {
        x = (int)(1LL * x * p.x % mod);
        return *this;
    }
    mint &operator/=(const mint &p) {
        *this *= p.inverse();
        return *this;
    }
    mint operator-() const { return mint(-x); }
    mint operator+(const mint &p) const { return mint(*this) += p; }
    mint operator-(const mint &p) const { return mint(*this) -= p; }
    mint operator*(const mint &p) const { return mint(*this) *= p; }
    mint operator/(const mint &p) const { return mint(*this) /= p; }
    bool operator==(const mint &p) const { return x == p.x; }
    bool operator!=(const mint &p) const { return x != p.x; }
    /// Modular inverse via the extended Euclidean algorithm.
    /// Requires gcd(x, mod) == 1 (x == 0 yields 0).
    mint inverse() const {
        int a = x, b = mod, u = 1, v = 0, t;
        while (b > 0) {
            t = a / b;
            swap(a -= t * b, b);
            swap(u -= t * v, v);
        }
        return mint(u);
    }
    friend ostream &operator<<(ostream &os, const mint &p) { return os << p.x; }
    friend istream &operator>>(istream &is, mint &a) {
        int64_t t;
        is >> t;
        a = mint<mod>(t);
        return (is);
    }
    int get() const { return x; }
    static constexpr int get_mod() { return mod; }
};

constexpr int mod = 998244353;
using Z = mint<mod>;
int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    cin >> n >> m;

    vector<int> p(n + 1);
    for (int i = 1; i <= n; i++) cin >> p[i];

    // Count the cycles of the permutation p.
    vector<char> vis(n + 1, 0);
    int cycles = 0;
    for (int i = 1; i <= n; i++) {
        if (vis[i]) continue;
        cycles++;
        for (int now = i; !vis[now]; now = p[now]) vis[now] = 1;
    }

    // A[i] == A[p[i]] for all i exactly when A is constant on every cycle
    // (m^cycles sequences). The other m^n - m^cycles sequences split evenly
    // between "less than" and "greater than". For m == 1 this evaluates to 0,
    // so no special case is needed.
    cout << (power<Z>(m, n) - power<Z>(m, cycles)) / Z(2) << '\n';
    return 0;
}