题意:给定$n,k$,求长度为$k$、每一项都是$[1,n]$内的正整数且满足$a_{i} \mid a_{i+1}$的数列$a$的个数(答案对$10^9+7$取模),其中$n,k\leq 10^9$。
首先有个明显的dp方程,其中$dp(n,k)$表示长度为$k$且末尾的数为$n$的合法数列的方案数:
$$dp(n,k)=\sum_{d\mid n}dp(d,k-1)$$
考虑写成狄利克雷卷积形式(卷积作用在第一个参数上):
$$dp(\cdot,k)=dp(\cdot,k-1)*I$$
当$k=1$时,边界取值均为1,即$dp(\cdot,1)=I$,于是$dp(\cdot,k)=I^k$,因此本题要求的其实是:
$$\sum_{i=1}^{n}I^k(i)$$
其中$I$为恒等于1的常函数,幂次按狄利克雷卷积理解,即$I^k$是$k$个$I$做狄利克雷卷积的结果。
由于I为积性函数,I^k亦为积性函数。根据经典方法,只需要知道I^k(p^t)的取值即可套用min_25筛求解本题。
考虑贝尔级数,$I_p(x)=\frac{1}{1-x}$,故$I^k_p(x)=(1-x)^{-k}$,而$I^k(p^t)$的取值就是该贝尔级数中$x^t$项的系数$[x^t]I^k_p(x)$,根据广义二项式定理,为$\binom{k+t-1}{t}$。
比赛的时候我没考虑贝尔级数,手推一下也并不难,本质上是个k次前缀和。
Code
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
// Residue of an integer modulo the compile-time constant `mod`.
//
// The stored value `x` is always kept in the canonical range [0, mod).
// Arithmetic operators return canonical residues; division multiplies by the
// inverse obtained from the extended Euclidean algorithm, so it is only
// meaningful when the divisor is coprime to `mod`.
template <int mod>
struct mint {
  int x;  // canonical representative in [0, mod)

  mint() : x(0) {}

  // Normalises any 64-bit value into [0, mod).
  // `y % mod` lies in (-mod, mod), so adding `mod` cannot overflow and no
  // negation of `y` is needed; this keeps LLONG_MIN well-defined.
  mint(long long y) : x(static_cast<int>((y % mod + mod) % mod)) {}

  mint &operator+=(const mint &p) {
    if ((x += p.x) >= mod) x -= mod;
    return *this;
  }
  mint &operator-=(const mint &p) {
    // Add the additive inverse of p instead of subtracting, so x never goes
    // negative.
    if ((x += mod - p.x) >= mod) x -= mod;
    return *this;
  }
  mint &operator*=(const mint &p) {
    // Widen to 64 bits before multiplying to avoid int overflow.
    x = (int)(1ll * x * p.x % mod);
    return *this;
  }
  mint &operator/=(const mint &p) {
    *this *= p.inverse();
    return *this;
  }

  mint operator-() const { return mint(-x); }
  mint operator+(const mint &p) const { return mint(*this) += p; }
  mint operator-(const mint &p) const { return mint(*this) -= p; }
  mint operator*(const mint &p) const { return mint(*this) *= p; }
  mint operator/(const mint &p) const { return mint(*this) /= p; }
  bool operator==(const mint &p) const { return x == p.x; }
  bool operator!=(const mint &p) const { return x != p.x; }

  // Multiplicative inverse via the extended Euclidean algorithm.
  // `u` tracks the Bezout coefficient of the original x; it may end up
  // negative and is normalised by the constructor.
  mint inverse() const {
    int a = x, b = mod, u = 1, v = 0, t;
    while (b > 0) {
      t = a / b;
      a -= t * b;
      std::swap(a, b);
      u -= t * v;
      std::swap(u, v);
    }
    return mint(u);
  }

  // Writes the canonical representative.
  friend std::ostream &operator<<(std::ostream &os, const mint &p) {
    return os << p.x;
  }
  // Reads a 64-bit integer and reduces it modulo `mod`.
  friend std::istream &operator>>(std::istream &is, mint &a) {
    long long t;
    is >> t;
    a = mint<mod>(t);
    return is;
  }

  int get() const { return x; }
  static constexpr int get_mod() { return mod; }
};
// Problem inputs read in main(): n is passed to the sieve as its upper bound,
// k is the parameter read by f() when evaluating C(k+t-1, t).
i64 n, k;
// Returns every prime <= N in increasing order.
//
// Only candidates 5, 7, 11, 13, ... (steps alternating +2 / +4) are sieved;
// slot i of the table holds the flag for the i-th such candidate, which is
// also candidate / 3. The primes 2 and 3 are added by hand and trimmed again
// when N is too small to contain them.
std::vector<int> prime_en(int N) {
  std::vector<int> alive(N / 3 + 1, 1);
  const int root = std::sqrt(N);
  const int cells = alive.size();

  // Strike out multiples of every surviving candidate up to sqrt(N).
  int cand = 5;
  int gap = 4;
  int slot = 1;
  while (cand <= root) {
    if (alive[slot]) {
      // `hop` alternates between the two slot distances that separate
      // consecutive multiples of `cand` which are themselves candidates.
      int hop = gap * cand / 3 + (gap * cand % 3 == 2);
      const int twice = 2 * cand;
      int pos = cand * cand / 3;
      while (pos < cells) {
        alive[pos] = 0;
        hop = twice - hop;
        pos += hop;
      }
    }
    gap = 6 - gap;
    cand += gap;
    ++slot;
  }

  // Collect the survivors.
  std::vector<int> primes{2, 3};
  cand = 5;
  gap = 4;
  slot = 1;
  while (cand <= N) {
    if (alive[slot]) primes.push_back(cand);
    gap = 6 - gap;
    cand += gap;
    ++slot;
  }

  // Drop the seeded 2 / 3 when they exceed N.
  while (!primes.empty() && primes.back() > N) primes.pop_back();
  return primes;
}
// Modulus used for every residue in the solution.
constexpr int mod = 1e9 + 7;
// Residue type used throughout the solution.
using Z = mint< mod >;
// Largest index stored in `inv`; f() reads inv[1..t], so t must not exceed N.
constexpr int N = 60;
// Table of N + 1 residues; main() stores the modular inverse of i in inv[i]
// for i >= 1 before f() is used. inv[0] stays zero.
vector<Z> inv(N + 1);
// Value assigned to the prime power p^t: the binomial C(k + t - 1, t),
// computed as (k+t-1)(k+t-2)...(k) / t! in modular arithmetic.
//
// `p` is not used; the parameter exists so the function matches the
// `T (*)(i64, i64)` callback shape expected by `mf`. Reads the global `k`
// and inv[1..t], which must already hold the inverses of 1..t.
Z f(i64 p, i64 t) {
  Z binom = 1;
  for (i64 step = 1; step <= t; ++step) {
    // The step-th numerator factor counts down from k + t - 1; the matching
    // denominator factor is `step`.
    binom *= Z(k + t - step);
    binom *= inv[step];
  }
  return binom;
}
// Prefix sum of a multiplicative function over 1..M (the accompanying
// write-up calls the technique a min_25 sieve).
//
// Template parameters:
//   T - value type; the code constructs it from integers and uses + - * and,
//       in prime_sum_table(), T{2}.inverse().
//   f - callback f(p, t) giving the function value at the prime power p^t.
//
// Usage: build a table of the function summed over primes (e.g. from
// pi_table() or prime_sum_table(), scaled as needed) and pass it to run().
//
// Table layout shared by pi_table(), prime_sum_table() and `buf`:
//   index i      (1 <= i < hls)  -> entry for the value floor(M / i)
//   index s - v  (1 <= v <= sq)  -> entry for the value v
// where hls is the number of "large" slots and s = hls + sq. idx() maps a
// value to its index.
template < typename T, T (*f)(i64, i64) >
struct mf {
  // M: upper bound, sq: floor(sqrt(M)), s: total table size.
  // Default-initialised so they are never indeterminate when M == 0.
  i64 M = 0, sq = 0, s = 0;
  vector< int> p;  // primes <= sq
  int ps = 0;      // p.size()
  vector<T> buf;   // per-value sums of f over primes, installed by run()
  T ans;           // accumulator used by run() / dfs()

  mf(i64 m) : M(m) {
    // Integer square root: start from the floating-point estimate and
    // correct it in both directions.
    sq = sqrt(M);
    while (sq * sq > M) sq--;
    while ((sq + 1) * (sq + 1) <= M) sq++;
    if (M != 0) {
      // hls - 1 large slots are needed: those i with floor(M / i) > sq.
      i64 hls = md(M, sq);
      if (hls != 1 && md(M, hls - 1) == sq) hls--;
      s = hls + sq;
      p = prime_en(sq);
      ps = p.size();
      ans = T{};
    }
  }

  // Number of primes <= v for every tabulated value v, in the layout
  // described on the struct. Returns an empty vector when M == 0.
  vector<T> pi_table() {
    if (M == 0) return {};
    i64 hls = md(M, sq);
    if (hls != 1 && md(M, hls - 1) == sq) hls--;
    // hl[i]: count for floor(M / i); starts as "all integers in [2, M/i]".
    vector<i64> hl(hls);
    for (int i = 1; i < hls; ++i) hl[i] = md(M, i) - 1;
    // hs[v]: count for v; iota from -1 gives hs[v] = v - 1.
    vector<int> hs(sq + 1);
    iota(begin(hs), end(hs), -1);
    // pi: number of primes processed so far (primes smaller than x).
    int pi = 0;
    for (auto &x : p) {
      i64 x2 = i64(x) * x;
      // Only values >= x^2 lose numbers whose smallest prime factor is x.
      i64 imax = min<i64>(hls, md(M, x2) + 1);
      for (i64 i = 1, ix = x; i < imax; ++i, ix += x) {
        // floor(M / i / x) is a large slot while i * x < hls, else a small one.
        hl[i] -= (ix < hls ? hl[ix] : hs[md(M, ix)]) - pi;
      }
      // Descending order so hs[n / x] is still the pre-update value.
      for (int n = sq; n >= x2; --n) hs[n] -= hs[md(n, x)] - pi;
      pi++;
    }
    vector<T> res;
    res.reserve(2 * sq + 10);
    for (auto &x : hl) res.push_back(x);
    // Small values are appended from sq down to 1 so that v sits at s - v.
    for (int i = hs.size(); --i;) res.push_back(hs[i]);
    return res;
  }

  // Sum of the primes <= v for every tabulated value v, in the layout
  // described on the struct. Returns an empty vector when M == 0.
  vector<T> prime_sum_table() {
    if (M == 0) return {};
    i64 hls = md(M, sq);
    if (hls != 1 && md(M, hls - 1) == sq) hls--;
    vector<T> h(s);
    T inv2 = T{2}.inverse();
    // Start from 2 + 3 + ... + v = v (v + 1) / 2 - 1 for every value.
    for (int i = 1; i < hls; i++) {
      T x = md(M, i);
      h[i] = x * (x + 1) * inv2 - 1;
    }
    for (int i = 1; i <= sq; i++) {
      T x = i;
      // Reuse inv2 rather than performing a modular inversion per iteration.
      h[s - i] = x * (x + 1) * inv2 - 1;
    }
    for (auto &x : p) {
      T xt = x;
      // Entry for the value x - 1: the sum of the primes below x.
      T pi = h[s - x + 1];
      i64 x2 = i64(x) * x;
      i64 imax = min<i64>(hls, md(M, x2) + 1);
      i64 ix = x;
      for (i64 i = 1; i < imax; ++i, ix += x) {
        h[i] -= ((ix < hls ? h[ix] : h[s - md(M, ix)]) - pi) * xt;
      }
      for (int n = sq; n >= x2; n--) {
        h[s - n] -= (h[s - md(n, x)] - pi) * xt;
      }
    }
    return h;
  }

  // Enumerates numbers built on the current partial product.
  //   i    - index of the largest prime used so far
  //   c    - exponent of p[i] in `prod`
  //   prod - current product, including p[i]^c
  //   cur  - f-value of `prod` with the p[i]^c part excluded
  // Adds to `ans` the contributions of prod * p[i] and of prod * q for
  // primes q in (p[i], M / prod], then continues with larger primes.
  void dfs(int i, int c, i64 prod, T cur) {
    // Contribution of prod * p[i], where p[i] has exponent c + 1.
    ans += cur * f(p[i], c + 1);
    i64 lim = md(M, prod);
    // Raise the exponent of p[i] further while there is still room.
    if (lim >= 1ll * p[i] * p[i]) dfs(i, c + 1, p[i] * prod, cur);
    // From here on cur is the f-value of the whole of prod.
    cur *= f(p[i], c);
    // prod * q for a single larger prime q <= lim.
    ans += cur * (buf[idx(lim)] - buf[idx(p[i])]);
    int j = i + 1;
    // Primes whose cube still fits: recurse.
    for (; j < ps && 1ll * p[j] * p[j] * p[j] <= lim; j++) {
      dfs(j, 1, prod * p[j], cur);
    }
    // Primes whose square fits but cube does not: handled inline as
    // prod * p[j]^2 plus prod * p[j] * q for primes q in (p[j], lim / p[j]].
    for (; j < ps && 1ll * p[j] * p[j] <= lim; j++) {
      T sm = f(p[j], 2);
      int id1 = idx(md(lim, p[j])), id2 = idx(p[j]);
      sm += f(p[j], 1) * (buf[id1] - buf[id2]);
      ans += cur * sm;
    }
  }

  // Returns the sum of the multiplicative function over 1..M.
  // `fprime` must hold, in the struct's table layout, the function summed
  // over the primes <= v. It is swapped into `buf`, so the caller's vector is
  // left holding the previous contents of `buf`.
  T run(vector<T> &fprime) {
    if (M == 0) return {};
    set_buf(fprime);
    // Contribution of all primes, plus 1 for the number 1.
    ans = buf[idx(M)] + 1;
    for (int i = 0; i < ps; i++) dfs(i, 1, p[i], 1);
    return ans;
  }

  // Truncated quotient n / d computed through double division.
  i64 md(i64 n, i64 d) { return double(n) / d; }
  // Table index of the value n (see the layout on the struct).
  i64 idx(i64 n) { return n <= sq ? s - n : md(M, n); }
  // Installs the prime-sum table by swapping it with `buf`.
  void set_buf(vector<T> &_buf) { swap(buf, _buf); }
};
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
for (int i = 1; i <= 60; i++) {
inv[i] = Z(i).inverse();
}
cin >> n >> k;
mf<Z, f> solve(n);
auto q = solve.pi_table();
for (auto &qq : q) qq *= Z(k);
auto ans = solve.run(q);
cout << ans;
return 0;
}
近期评论