问题描述
给定常系数线性齐次递推关系$$f_n=\sum_{i=1}^ka_if_{n-i},n\ge k$$ 以及初始值$$f_0=b_0,\cdots,f_{k-1}=b_{k-1}$$ 求$f_n$
题解
根据递推关系,构造转移矩阵$$A=\begin{pmatrix}0&1&0&\cdots&0\\0&0&1&\cdots&0\\\vdots&\vdots&\vdots&\ddots&\vdots\\0&0&0&\cdots&1\\a_k&a_{k-1}&a_{k-2}&\cdots&a_1\end{pmatrix}$$ 设列向量${\bm F}=\begin{pmatrix}b_0&b_1&\cdots&b_{k-1}\end{pmatrix}^{\text{T}}$.
答案即为$(A^n{\bm F})_0$,矩阵快速幂可以做到$\mathcal O(k^3\log n)$.
接下来,考虑矩阵$A$的特征多项式$$\det(\lambda I-A)=\begin{vmatrix}\lambda&-1&0&\cdots&0\\0&\lambda&-1&\cdots&0\\\vdots&\vdots&\vdots&\ddots&\vdots\\0&0&0&\cdots&-1\\-a_k&-a_{k-1}&-a_{k-2}&\cdots&\lambda-a_1\end{vmatrix}=\lambda^k-a_1\lambda^{k-1}-\cdots-a_k$$ 记$f(x)=x^k-a_1x^{k-1}-\cdots-a_k$,根据Cayley-Hamilton定理,有$f(A)=0$.
设多项式$x^n$对$f(x)$取模的结果为$r(x)$,即$x^n=q(x)f(x)+r(x)$.将$A$代入并利用$f(A)=0$,易知$A^n=r(A)$.
考虑到$r(x)$是关于$x$的次数不超过$k-1$的多项式,故设$[x^i]r(x)=c_i,i=0,\cdots,k-1$,那么有$$A^n=c_0I+c_1A+\cdots+c_{k-1}A^{k-1}$$$$\begin{aligned}(A^n{\bm F})_0&=(c_0{\bm F}+c_1A{\bm F}+\cdots+c_{k-1}A^{k-1}{\bm F})_0\\&=c_0{\bm F}_0+c_1(A{\bm F})_0+\cdots+c_{k-1}(A^{k-1}{\bm F})_0\\&=c_0b_0+c_1b_1+\cdots+c_{k-1}b_{k-1}\end{aligned}$$ 其中最后一步用到了$(A^i{\bm F})_0=f_i=b_i\ (0\le i<k)$. 多项式取模的部分可以快速幂以确保多项式的长度在$k$以内,时间复杂度$\mathcal O(k\log k\log n)$.
一些优化
因为多项式取模的常数实在太大,所以我第一遍交的时候TLE了.
参考了同机房某位同学的代码,发现了一个小优化,即多项式取模的时候,每次的除数都是一样的($f(x)$),所以$f(x)$系数翻转后的逆可以只算一遍.
代码
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
// Largest recurrence order k the static buffers below can hold.
const int N = 32000;
// NTT-friendly prime modulus (998244353 = 119 * 2^23 + 1); 3 is used as its primitive root.
const int mod = 998244353;
// Bit-reversal permutation table: written by poly::operator*, read by poly::NTT.
// Its capacity (4N + 1) bounds the transform length a single product may use.
int rev[N << 2 | 1];
// n: index of the term to compute, k: order of the recurrence (both read in main).
int n, k;
// Answer accumulated and printed by main.
int ans;
// a[1..k]: recurrence coefficients, f[0..k-1]: initial values (both filled in main).
int a[N + 1], f[N + 1];

// Returns a^b modulo `mod` using binary exponentiation.
// `a` is expected in [0, mod); a non-positive exponent yields 1.
int pow(int a, int b)
{
    int ret = 1;
    // `b > 0` (instead of `b != 0`) keeps a negative exponent from looping forever.
    for(; b > 0; b >>= 1, a = 1ll * a * a % mod)
        if(b & 1) ret = 1ll * ret * a % mod;
    return ret;
}

// Polynomial with coefficients in Z/mod, stored little-endian:
// a[i] is the coefficient of x^i and the zero polynomial is the empty vector.
// Trailing zero coefficients are permitted unless a function says otherwise.
struct poly
{
    std::vector<int> a;

    // Builds the constant polynomial `x`; x == 0 gives the (empty) zero polynomial.
    poly(int x = 0)
    {
        if(x) a.push_back(x);
    }

    // Number of stored coefficients (degree + 1 only when the polynomial is shrunk).
    int size() const
    {
        return static_cast<int>(a.size());
    }

    // Truncates or zero-pads to exactly `len` coefficients.
    void resize(int len)
    {
        a.resize(len);
    }

    // Drops trailing zero coefficients so that size() == degree + 1.
    void shrink()
    {
        while(!a.empty() && a.back() == 0)
            a.pop_back();
    }

    // Returns the coefficient-reversed polynomial x^(size-1) * P(1/x).
    // The result depends on size(), so padding zeros are significant here.
    poly rever() const
    {
        poly ret;
        ret.a.assign(a.rbegin(), a.rend());
        return ret;
    }

    // Prints all stored coefficients on one line, each followed by a space.
    void print() const
    {
        for(int i = 0; i < size(); i++)
            printf("%d ", a[i]);
        puts("");
    }

    // Bounds-checked read access: coefficients outside [0, size) read as 0.
    int operator[](const int& x) const
    {
        if(x < 0 || x >= size())
            return 0;
        return a[x];
    }

    // Formal derivative.
    poly diff() const
    {
        if(a.empty())
            return poly();
        poly ret;
        ret.resize(size() - 1);
        for(int i = 0; i < ret.size(); i++)
            ret.a[i] = 1ll * (i + 1) * a[i + 1] % mod;
        return ret;
    }

    // In-place number-theoretic transform of the stored coefficients.
    // Preconditions: size() is a power of two and `rev` already holds the
    // bit-reversal permutation for that size (operator* prepares it).
    // type == 1 is the forward transform, type == -1 the inverse.
    void NTT(int type = 1)
    {
        static const int g = 3;
        int n = size();
        if(n == 0)
            return; // nothing to transform; also keeps begin() + 1 below valid
        for(int i = 0; i < n; i++)
            if(i < rev[i]) std::swap(a[i], a[rev[i]]);
        for(int mid = 1; mid < n; mid <<= 1)
        {
            // Primitive (2 * mid)-th root of unity.
            int step = pow(g, (mod - 1) / (mid << 1));
            for(int i = 0; i < n; i += (mid << 1))
            {
                for(int j = 0, omega = 1; j < mid; j++, omega = 1ll * omega * step % mod)
                {
                    int x = a[i + j], y = 1ll * omega * a[i + j + mid] % mod;
                    a[i + j] = (x + y) % mod;
                    a[i + j + mid] = (x - y + mod) % mod;
                }
            }
        }
        if(type == -1)
        {
            // Inverse transform = forward transform, then reverse a[1..n-1] and divide by n.
            std::reverse(a.begin() + 1, a.end());
            int inv = pow(n, mod - 2);
            for(int i = 0; i < n; i++)
                a[i] = 1ll * a[i] * inv % mod;
        }
    }

    // Equality as polynomials (trailing zeros are ignored).
    friend inline bool operator == (poly f, poly g)
    {
        f.shrink(), g.shrink();
        return f.a == g.a;
    }

    friend inline bool operator != (poly f, poly g)
    {
        return !(f == g);
    }

    // Coefficient-wise sum modulo `mod`.
    friend inline poly operator + (poly f, poly g)
    {
        f.resize(std::max(f.size(), g.size()));
        for(int i = 0; i < f.size(); i++)
            f.a[i] = (f[i] + g[i]) % mod;
        return f;
    }

    // Coefficient-wise difference modulo `mod`.
    friend inline poly operator - (poly f, poly g)
    {
        f.resize(std::max(f.size(), g.size()));
        for(int i = 0; i < f.size(); i++)
            f.a[i] = (f[i] - g[i] + mod) % mod;
        return f;
    }

    // Product via NTT. The result is shrunk (no trailing zeros).
    // The transform length must fit in the global `rev` table.
    friend inline poly operator * (poly f, poly g)
    {
        if(!f.size() || !g.size())
            return poly();
        int lim = 1;
        while(lim < f.size() + g.size() - 1)
            lim <<= 1;
        for(int i = 0; i < lim; i++)
        {
            rev[i] = rev[i >> 1] >> 1;
            if(i & 1)
                rev[i] |= lim >> 1;
        }
        f.resize(lim), g.resize(lim);
        f.NTT(), g.NTT();
        for(int i = 0; i < lim; i++)
            f.a[i] = 1ll * f[i] * g[i] % mod;
        f.NTT(-1), f.shrink();
        return f;
    }

    // Quotient of the Euclidean division f / g, computed as
    //   rev(q) = rev(f) * rev(g)^{-1}  (mod x^qlen),  qlen = size(f) - size(g) + 1.
    // The inverse of rev(g) is cached between calls because repeated reductions
    // by the same modulus dominate the running time; the cache is keyed on the
    // divisor, so a different divisor simply triggers a recomputation.
    // Dividing by the zero polynomial returns the zero polynomial.
    friend inline poly operator / (poly f, poly g)
    {
        g.shrink(); // rev(g) needs a non-zero constant term, i.e. g.size() == deg g + 1
        if(f.size() < g.size())
            return poly();
        const int qlen = f.size() - g.size() + 1;

        static poly cached_div, cached_inv;
        static int cached_len = 0; // number of terms cached_inv is accurate to
        if(cached_len < qlen || cached_div.a != g.a)
        {
            // size(g) terms cover every later reduction of a product of two
            // already-reduced polynomials, so this normally runs once.
            cached_len = std::max(qlen, g.size());
            cached_inv = g.rever().inv(cached_len);
            cached_div = g;
        }

        poly q = (f.rever().modx(qlen) * cached_inv.modx(qlen)).modx(qlen);
        // operator* strips trailing zeros, but rever() interprets size() as the
        // degree: without this padding a quotient divisible by x is shifted down.
        q.resize(qlen);
        q = q.rever();
        q.shrink();
        return q;
    }

    // Remainder of the Euclidean division f % g, returned with size(g) - 1
    // coefficients (possibly with trailing zeros) whenever a reduction happens.
    friend inline poly operator % (poly f, poly g)
    {
        g.shrink();
        if(f.size() < g.size())
            return f;
        poly q = f / g;
        const int rlen = g.size() - 1;
        // Only the low rlen coefficients of f - g * q can be non-zero.
        return (f.modx(rlen) - g.modx(rlen) * q.modx(rlen)).modx(rlen);
    }

    // Returns this polynomial modulo x^n (the first min(size, n) coefficients).
    poly modx(int n) const
    {
        const int len = std::max(0, std::min(size(), n));
        poly g;
        g.a.assign(a.begin(), a.begin() + len);
        return g;
    }

    // Multiplicative inverse modulo x^n by Newton iteration:
    //   f <- f * (2 - f * P)  (mod x^(2 * len)).
    // Returns the zero polynomial when the constant term is 0 (no inverse exists).
    poly inv(int n) const
    {
        // operator[] instead of a[0]: an empty polynomial must not index the vector.
        poly f(pow((*this)[0], mod - 2));
        for(int len = 1; len < n; len <<= 1)
            f = (f * (poly(2) - f * modx(len << 1))).modx(len << 1);
        return f.modx(n);
    }
};

// p: modulus polynomial, base: running power of x, ret: accumulated result (used by main).
poly p, base, ret;
// Reads the next decimal integer from stdin.
// Every character before the first digit is skipped; each '-' among the
// skipped characters flips the sign. The character following the last digit
// is consumed. Returns 0 if end of input is reached before any digit.
int read()
{
    int value = 0, sign = 1;
    // Keep getchar()'s result in an int: stored in a char, EOF becomes
    // indistinguishable from data and the skip loop below would never end.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-') sign = -sign;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
int main()
{
n = read(), k = read();
for(int i = 1; i <= k; i++)
{
a[i] = read();
if(a[i] < 0) a[i] += mod;
}
for(int i = 0; i < k; i++)
{
f[i] = read();
if(f[i] < 0) f[i] += mod;
}
p.resize(k + 1), p.a[k] = 1;
for(int i = 1; i <= k; i++)
{
if(a[i])
p.a[k - i] = mod - a[i];
}
if(n & 1)
{
for(int i = 0; i <= k; i++)
{
if(p[i])
p.a[i] = mod - p[i];
}
}
base.resize(2), base.a[1] = 1, ret.resize(1), ret.a[0] = 1;
for(; n; n >>= 1, base = base * base % p)
if(n & 1) ret = ret * base % p;
for(int i = 0; i < k; i++)
ans = (1ll * ans + 1ll * f[i] * 1ll * ret[i]) % mod;
printf("%d\n", ans);
return 0;
}