Solution For ICPC Online#1 Problem F

题意:给定n,k,求长度为k的正整数数列a的个数,要求满足a_{i} \mid a_{i+1}(1\leq i<k)且a_k\leq n,答案对10^9+7取模,其中n,k\leq 10^9

首先有个明显的dp方程,其中dp(n,k)表示考虑到第k个数末尾的数为n的合法数列的方案数:

dp(n,k)=\sum_{d\mid n}dp(d,k-1)

考虑写成狄利克雷卷积形式(把dp(\cdot,k)看作关于第一个参数的数论函数):

dp(\cdot,k)=dp(\cdot,k-1)*I

k等于1时,边界取值均为1,因此本题实际上要求的东西其实是:

\sum_{i=1}^{n}I^k(i)

其中I为恒等于1的常数函数(即I(n)=1),幂次是狄利克雷卷积意义下的幂,即I^k表示k个I做狄利克雷卷积。
由于I为积性函数,I^k亦为积性函数。根据经典方法,只需要知道I^k(p^t)的取值即可套用min_25筛求解本题。
考虑贝尔级数,I_p(x)=\sum_{t\geq 0}x^t=\frac{1}{1-x},故I^k_p(x)=(1-x)^{-k}。I^k(p^t)的取值就是该贝尔级数中[x^t]项的系数,根据广义二项式定理,为\binom{k+t-1}{t}。

比赛的时候我没考虑贝尔级数,手推一下也并不难,本质上是个k次前缀和。

Code
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
// Integer modulo the compile-time constant `mod`.
// The stored residue `x` is always kept in the range [0, mod).
template <int mod>
struct mint {
    int x;

    // Zero residue.
    mint() : x(0) {}

    // Reduces any signed 64-bit value (negative ones included) into [0, mod).
    mint(i64 y) {
        i64 r = y % mod;
        if (r < 0) r += mod;
        x = (int)r;
    }

    mint &operator+=(const mint &rhs) {
        x += rhs.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint &operator-=(const mint &rhs) {
        // Add the complement so the intermediate value never goes negative.
        x += mod - rhs.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint &operator*=(const mint &rhs) {
        // Widen to 64 bits before multiplying to avoid int overflow.
        const i64 wide = 1ll * x * rhs.x;
        x = (int)(wide % mod);
        return *this;
    }
    mint &operator/=(const mint &rhs) {
        // Division is multiplication by the modular inverse.
        return *this *= rhs.inverse();
    }

    mint operator-() const {
        mint negated;
        if (x != 0) negated.x = mod - x;
        return negated;
    }
    mint operator+(const mint &rhs) const {
        mint out = *this;
        out += rhs;
        return out;
    }
    mint operator-(const mint &rhs) const {
        mint out = *this;
        out -= rhs;
        return out;
    }
    mint operator*(const mint &rhs) const {
        mint out = *this;
        out *= rhs;
        return out;
    }
    mint operator/(const mint &rhs) const {
        mint out = *this;
        out /= rhs;
        return out;
    }
    bool operator==(const mint &rhs) const { return x == rhs.x; }
    bool operator!=(const mint &rhs) const { return !(x == rhs.x); }

    // Modular inverse via the extended Euclidean algorithm.
    // The result is meaningful only when gcd(x, mod) == 1.
    mint inverse() const {
        int a = x, b = mod;
        int u = 1, v = 0;
        while (b > 0) {
            const int q = a / b;
            // (a, b) <- (b, a - q * b)
            const int rem = a - q * b;
            a = b;
            b = rem;
            // (u, v) <- (v, u - q * v)
            const int next = u - q * v;
            u = v;
            v = next;
        }
        return mint(u);
    }

    friend ostream &operator<<(ostream &os, const mint &value) {
        os << value.x;
        return os;
    }
    friend istream &operator>>(istream &is, mint &value) {
        i64 raw;
        is >> raw;
        value = mint<mod>(raw);
        return is;
    }

    int get() const { return x; }
    static constexpr int get_mod() { return mod; }
};
// Problem input, read in main(): n is the upper limit handed to the mf solver,
// k is the Dirichlet-convolution exponent that f() reads.
i64 n, k;
// Returns all primes <= N in increasing order.
//
// Sieve over the numbers coprime to 6 only: slot i of the table stands for
// the i-th candidate of the sequence 1, 5, 7, 11, 13, ... (slot index equals
// candidate / 3). 2 and 3 are added by hand afterwards.
vector<int> prime_en(int N) {
    vector<int> wheel(N / 3 + 1, 1);

    const int root = sqrt(N);
    int cand = 5, gap = 4, slot = 1;
    while (cand <= root) {
        if (wheel[slot]) {
            // Cross out the multiples of cand that are themselves coprime
            // to 6, starting at cand^2; the slot stride alternates between
            // two values whose sum is 2 * cand.
            int pos = cand * cand / 3;
            int stride = gap * cand / 3 + (gap * cand % 3 == 2);
            const int twice = 2 * cand;
            const int limit = wheel.size();
            while (pos < limit) {
                wheel[pos] = 0;
                stride = twice - stride;
                pos += stride;
            }
        }
        // Candidates advance by alternating gaps: 5 -> 7 -> 11 -> 13 -> ...
        gap = 6 - gap;
        cand += gap;
        ++slot;
    }

    vector<int> primes{2, 3};
    cand = 5;
    gap = 4;
    slot = 1;
    while (cand <= N) {
        if (wheel[slot]) primes.push_back(cand);
        gap = 6 - gap;
        cand += gap;
        ++slot;
    }
    // For N < 3 the hand-added seeds may exceed N; trim them.
    while (!primes.empty() && primes.back() > N) primes.pop_back();
    return primes;
}
// Modulus used for every modular computation in this solution.
constexpr int mod = 1e9 + 7;
// Modular integer type specialised for that modulus.
using Z = mint< mod >;
// Largest index of the inverse table below.
constexpr int N = 60;
// inv[i] holds the modular inverse of i for 1 <= i <= N; filled in main()
// before any call to f().
vector<Z> inv(N + 1);
// Value of the multiplicative function at the prime power p^t:
//     C(k + t - 1, t)  (mod `mod`)
// The value does not depend on the prime p itself; the parameter exists only
// because the mf template expects the signature T(i64, i64).
// Requires t <= N so that inv[1..t] is available.
Z f(i64 p, i64 t) {
    (void)p;
    Z acc = 1;
    // Multiply the t falling factors (k+t-1), (k+t-2), ..., k and divide by
    // t! one term at a time using the precomputed inverses.
    i64 factor = k + t - 1;
    for (i64 step = 1; step <= t; ++step, --factor) {
        acc *= Z(factor);
        acc *= inv[step];
    }
    return acc;
}
// Prefix sum of a multiplicative function over 1..M.
//
// Template parameters:
//   T - value type (must support +, -, *, inverse() and construction from
//       integers, e.g. a modular integer);
//   f - f(p, e) returns the function value at the prime power p^e.
//
// Usage: build a table of prefix sums of the function over primes only
// (pi_table() / prime_sum_table() produce the two basic ingredients, which
// the caller may scale or combine), then pass it to run().
//
// Table layout (size s): for the distinct values v = floor(M / i),
//   v >  sq  is stored at index floor(M / v)   (indices 1 .. hls-1),
//   v <= sq  is stored at index s - v.
// Index 0 is unused. idx() performs this mapping.
template < typename T, T (*f)(i64, i64) >
struct mf {
    // M: summation limit; sq: floor(sqrt(M)); s: table size.
    i64 M = 0, sq = 0, s = 0;
    vector< int> p;  // primes <= sq
    int ps = 0;      // number of primes in p
    vector<T> buf;   // prime prefix-sum table installed by run()
    T ans;           // accumulator used by run() / dfs()

    // A negative limit describes an empty range and is treated like 0, so
    // that sqrt() is never applied to a negative number.
    mf(i64 m) : M(m < 0 ? 0 : m) {
        if (M == 0) return;
        sq = sqrt(M);
        // Correct possible floating-point error in either direction.
        while (sq * sq > M) sq--;
        while ((sq + 1) * (sq + 1) <= M) sq++;
        s = large_count() + sq;
        p = prime_en(sq);
        ps = p.size();
        ans = T{};
    }

    // One past the number of "large" entries: indices 1 .. result-1 hold the
    // values floor(M / i) that are strictly greater than sq.
    i64 large_count() const {
        i64 hls = md(M, sq);
        if (hls != 1 && md(M, hls - 1) == sq) hls--;
        return hls;
    }

    // Table (in the layout above) of the number of primes <= v for every
    // distinct v = floor(M / i). Returns an empty vector when M == 0.
    vector<T> pi_table() {
        if (M == 0) return {};
        const i64 hls = large_count();

        // hl[i]: count for v = floor(M / i) (large values);
        // hs[n]: count for v = n (small values).
        // Both start as "number of integers in [2, v]".
        vector<i64> hl(hls);
        for (int i = 1; i < hls; ++i) hl[i] = md(M, i) - 1;

        vector<int> hs(sq + 1);
        iota(begin(hs), end(hs), -1);

        // pi = number of primes strictly below the current prime x.
        int pi = 0;
        for (auto &x : p) {
            i64 x2 = i64(x) * x;
            // Only values v >= x^2 lose anything when sieving by x.
            i64 imax = min<i64>(hls, md(M, x2) + 1);
            for (i64 i = 1, ix = x; i < imax; ++i, ix += x) {
                hl[i] -= (ix < hls ? hl[ix] : hs[md(M, ix)]) - pi;
            }
            // Descending order so that hs[n / x] is still the pre-update value.
            for (int n = sq; n >= x2; --n) hs[n] -= hs[md(n, x)] - pi;
            pi++;
        }

        vector<T> res;
        res.reserve(2 * sq + 10);
        for (auto &x : hl) res.push_back(x);
        for (int i = hs.size(); --i;) res.push_back(hs[i]);
        return res;
    }

    // Table (in the layout above) of the sum of primes <= v for every
    // distinct v = floor(M / i). Returns an empty vector when M == 0.
    vector<T> prime_sum_table() {
        if (M == 0) return {};
        const i64 hls = large_count();
        vector<T> h(s);
        // Computed once and reused by both initialisation loops.
        const T inv2 = T{2}.inverse();
        // Start from 2 + 3 + ... + v = v (v + 1) / 2 - 1.
        for (int i = 1; i < hls; i++) {
            T x = md(M, i);
            h[i] = x * (x + 1) * inv2 - 1;
        }
        for (int i = 1; i <= sq; i++) {
            T x = i;
            h[s - i] = x * (x + 1) * inv2 - 1;
        }
        for (auto &x : p) {
            T xt = x;
            // Sum of primes strictly below x, i.e. the entry for v = x - 1.
            T pi = h[s - x + 1];
            i64 x2 = i64(x) * x;
            i64 imax = min<i64>(hls, md(M, x2) + 1);
            i64 ix = x;
            for (i64 i = 1; i < imax; ++i, ix += x) {
                h[i] -= ((ix < hls ? h[ix] : h[s - md(M, ix)]) - pi) * xt;
            }
            for (int n = sq; n >= x2; n--) {
                h[s - n] -= (h[s - md(n, x)] - pi) * xt;
            }
        }
        return h;
    }

    // Enumerates numbers <= M by their prime factorisation.
    //   i    - index of the largest prime p[i] in the current number,
    //   c    - exponent of p[i] in it,
    //   prod - the current number,
    //   cur  - function value of prod with the p[i]^c factor left out.
    // Adds to `ans` the values of all numbers that extend prod.
    void dfs(int i, int c, i64 prod, T cur) {
        // prod * p[i]: exponent of p[i] becomes c + 1.
        ans += cur * f(p[i], c + 1);
        i64 lim = md(M, prod);
        // Raise the exponent further while prod * p[i]^2 still fits.
        if (lim >= 1ll * p[i] * p[i]) dfs(i, c + 1, p[i] * prod, cur);
        // From here on cur is the function value of prod itself.
        cur *= f(p[i], c);
        // prod * q for every prime q with p[i] < q <= lim.
        ans += cur * (buf[idx(lim)] - buf[idx(p[i])]);
        int j = i + 1;
        // Primes whose cube fits: a full recursive step is needed.
        for (; j < ps && 1ll * p[j] * p[j] * p[j] <= lim; j++) {
            dfs(j, 1, prod * p[j], cur);
        }
        // Primes whose square fits but cube does not: handle the two
        // possible shapes prod * p[j]^2 and prod * p[j] * q directly.
        for (; j < ps && 1ll * p[j] * p[j] <= lim; j++) {
            T sm = f(p[j], 2);
            int id1 = idx(md(lim, p[j])), id2 = idx(p[j]);
            sm += f(p[j], 1) * (buf[id1] - buf[id2]);
            ans += cur * sm;
        }
    }

    // Returns the sum of the function over 1..M (0-valued T when M == 0).
    // fprime must be a table of size s in the layout above holding the
    // prefix sums of the function over primes. NOTE: the table is swapped
    // into the solver, so the caller's vector does not keep its contents.
    T run(vector<T> &fprime) {
        if (M == 0) return {};
        // A table of the wrong size would be indexed out of bounds below.
        assert((i64)fprime.size() == s);
        set_buf(fprime);
        // Value at 1 plus the values at all primes <= M.
        ans = buf[idx(M)] + 1;
        for (int i = 0; i < ps; i++) dfs(i, 1, p[i], 1);
        return ans;
    }

    // floor(n / d) computed through a double division; presumably exact only
    // while n is small enough to be represented exactly in a double.
    i64 md(i64 n, i64 d) const { return double(n) / d; }
    // Table index of the entry for value n (n must be some floor(M / i)).
    i64 idx(i64 n) const { return n <= sq ? s - n : md(M, n); }
    void set_buf(vector<T> &_buf) { swap(buf, _buf); }
};
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    for (int i = 1; i <= 60; i++) {
        inv[i] = Z(i).inverse();
    }

    cin >> n >> k;

    mf<Z, f> solve(n);
    auto q = solve.pi_table();

    for (auto &qq : q) qq *= Z(k);
    auto ans = solve.run(q);
    cout << ans;

    return 0;
}