Gauss's blog

Gauss's blog

I’d go back to December turn around and make it alright.

「学习笔记」常系数线性齐次递推

posted on 2021-03-31 09:59:24 | under 线性代数 |

问题描述

给定常系数线性齐次递推关系$$f_n=\sum_{i=1}^ka_if_{n-i},n\ge k$$ 以及初始值$$f_0=b_0,\cdots,f_{k-1}=b_{k-1}$$ 求$f_n$

题解

根据递推关系,构造转移矩阵$$A=\begin{pmatrix}0&1&0&\cdots&0\\0&0&1&\cdots&0\\\vdots&\vdots&\vdots&\ddots&\vdots\\0&0&0&\cdots&1\\a_k&a_{k-1}&a_{k-2}&\cdots&a_1\end{pmatrix}$$ 设列向量${\bm F}=\begin{pmatrix}b_0&b_1&\cdots&b_{k-1}\end{pmatrix}^{\text{T}}$.

答案即为$(A^n{\bm F})_0$,矩阵快速幂可以做到$\mathcal O(k^3\log n)$.

接下来,考虑矩阵$A$的特征多项式$$\det(\lambda I-A)=\begin{vmatrix}\lambda&-1&0&\cdots&0\\0&\lambda&-1&\cdots&0\\\vdots&\vdots&\vdots&\ddots&\vdots\\0&0&0&\cdots&-1\\-a_k&-a_{k-1}&-a_{k-2}&\cdots&\lambda-a_1\end{vmatrix}=\lambda^k-a_1\lambda^{k-1}-\cdots-a_k$$ 记$f(x)=x^k-a_1x^{k-1}-\cdots-a_k$,根据Cayley-Hamilton定理,有$f(A)=0$.

设多项式$x^n$对$f(x)$取模的结果为$r(x)$,即$x^n=q(x)f(x)+r(x)$.代入$A$并利用$f(A)=0$,可得$A^n=q(A)f(A)+r(A)=r(A)$.

考虑到$r(x)$是关于$x$的次数不超过$k-1$的多项式,故设$[x^i]r(x)=c_i,i=0,\cdots,k-1$,那么有$$A^n=c_0I+c_1A+\cdots+c_{k-1}A^{k-1}$$ 又因为$A^i{\bm F}=\begin{pmatrix}f_i&f_{i+1}&\cdots&f_{i+k-1}\end{pmatrix}^{\text{T}}$,所以当$0\le i\le k-1$时$(A^i{\bm F})_0=f_i=b_i$,于是$$\begin{aligned}(A^n{\bm F})_0&=(c_0{\bm F}+c_1A{\bm F}+\cdots+c_{k-1}A^{k-1}{\bm F})_0\\&=c_0{\bm F}_0+c_1(A{\bm F})_0+\cdots+c_{k-1}(A^{k-1}{\bm F})_0\\&=c_0b_0+c_1b_1+\cdots+c_{k-1}b_{k-1}\end{aligned}$$ 计算$x^n\bmod f(x)$时使用快速幂,并在每次乘法之后立即对$f(x)$取模,这样多项式的长度始终在$k$以内,时间复杂度$\mathcal O(k\log k\log n)$.

一些优化

因为多项式取模的常数实在太大,所以我第一遍交的时候TLE了.

参考了同机房某位同学的代码,发现了一个小优化,即多项式取模的时候,每次的除数都是一样的($f(x)$),所以$f(x)$系数翻转后的逆可以只算一遍.

代码

#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>

using namespace std;

// Capacity bound for the fixed-size buffers below (largest supported order k).
const int N = 32000;
// Modulus applied to every arithmetic operation in this file.
const int mod = 998244353;

// Bit-reversal permutation table; poly's operator* fills it before each NTT.
int rev[N << 2 | 1];
// Index of the term to compute (first value read by main).
int n;
// Order of the recurrence (second value read by main).
int k;
// Accumulator for the final answer printed by main.
int ans;
// Recurrence coefficients a[1..k], stored modulo `mod`.
int a[N + 1];
// Initial values f[0..k-1], stored modulo `mod`.
int f[N + 1];

/**
 * Computes a^b modulo `mod` by binary exponentiation.
 *
 * @param a base
 * @param b exponent; the loop consumes it bit by bit from the lowest bit
 * @return the product of the selected squarings, reduced modulo `mod`
 */
int pow(int a, int b)
{
    long long result = 1;
    long long square = a; // a^(2^i) for the bit currently being examined
    while(b)
    {
        if(b & 1)
            result = result * square % mod;
        square = square * square % mod;
        b >>= 1;
    }
    return static_cast<int>(result);
}
/**
 * Dense polynomial whose coefficients are stored modulo `mod`, lowest degree
 * first (a[i] is the coefficient of x^i).
 *
 * Multiplication runs an in-place NTT that relies on the file-level `rev`
 * table and the file-level `pow` helper. Inversion is a Newton iteration, and
 * division / remainder use the reversed-coefficient method.
 */
struct poly
{
    std::vector<int> a;

    // Constant polynomial x; the value 0 produces the empty (zero) polynomial.
    poly(int x = 0)
    {
        if(x) a.push_back(x);
    }
    // Number of stored coefficients; equals degree + 1 only after shrink().
    int size() const
    {
        return a.size();
    }
    void resize(int len)
    {
        a.resize(len);
    }
    // Drops zero coefficients from the high-degree end.
    void shrink()
    {
        while(!a.empty() && !a.back())
            a.pop_back();
    }
    // Returns the polynomial with its stored coefficients in reverse order.
    poly rever() const
    {
        poly ret;
        ret.a.assign(a.rbegin(), a.rend());
        return ret;
    }
    // Prints the stored coefficients separated by spaces, then a newline.
    void print() const
    {
        for(int i = 0; i < size(); i++)
            printf("%d ", a[i]);
        puts("");
    }
    // Bounds-checked read access: indices outside the stored range read as 0.
    int operator[](const int& x)const
    {
        if(x < 0 || x >= size())
            return 0;
        return a[x];
    }
    // Formal derivative.
    poly diff() const
    {
        if(a.empty())
            return poly();
        poly ret;
        ret.resize(size() - 1);
        for(int i = 0; i < ret.size(); i++)
            ret.a[i] = 1ll * (i + 1) * a[i + 1] % mod;
        return ret;
    }
    /**
     * In-place transform over the stored coefficients.
     *
     * The caller must have resized the polynomial to the transform length and
     * filled `rev` with the matching bit-reversal permutation (operator* does
     * both). With type == -1 the forward transform is followed by reversing
     * a[1..n-1] and scaling by n^{-1}, which yields the inverse transform.
     */
    void NTT(int type = 1)
    {
        static const int g = 3;
        int n = size();
        for(int i = 0; i < n; i++)
            if(i < rev[i]) std::swap(a[i], a[rev[i]]);
        for(int mid = 1; mid < n; mid <<= 1)
        {
            // Root of unity for butterflies of width 2 * mid.
            int step = pow(g, (mod - 1) / (mid << 1));
            for(int i = 0; i < n; i += (mid << 1))
            {
                int omega = 1;
                for(int j = 0; j < mid; j++)
                {
                    int x = a[i + j];
                    int y = 1ll * omega * a[i + j + mid] % mod;
                    a[i + j] = (x + y) % mod;
                    a[i + j + mid] = (x - y + mod) % mod;
                    omega = 1ll * omega * step % mod;
                }
            }
        }
        if(type == -1)
        {
            std::reverse(a.begin() + 1, a.end());
            int inv = pow(n, mod - 2);
            for(int i = 0; i < n; i++)
                a[i] = 1ll * a[i] * inv % mod;
        }
    }
    // Equality ignores zero coefficients at the high-degree end.
    friend inline bool operator == (poly f, poly g)
    {
        f.shrink(), g.shrink();
        return f.a == g.a;
    }
    friend inline bool operator != (poly f, poly g)
    {
        return !(f == g);
    }
    friend inline poly operator + (poly f, poly g)
    {
        f.resize(std::max(f.size(), g.size()));
        for(int i = 0; i < f.size(); i++)
            f.a[i] = (f[i] + g[i]) % mod;
        return f;
    }
    friend inline poly operator - (poly f, poly g)
    {
        f.resize(std::max(f.size(), g.size()));
        for(int i = 0; i < f.size(); i++)
            f.a[i] = (f[i] - g[i] + mod) % mod;
        return f;
    }
    /**
     * Product via NTT. Rebuilds `rev` for the chosen power-of-two length on
     * every call; the result is shrunk. The transform length must fit in
     * `rev` (not checked here).
     */
    friend inline poly operator * (poly f, poly g)
    {
        if(!f.size() || !g.size())
            return poly();
        int lim = 1;
        while(lim < f.size() + g.size() - 1)
            lim <<= 1;
        for(int i = 0; i < lim; i++)
        {
            rev[i] = rev[i >> 1] >> 1;
            if(i & 1)
                rev[i] |= lim >> 1;
        }
        f.resize(lim), g.resize(lim);
        f.NTT(), g.NTT();
        for(int i = 0; i < lim; i++)
            f.a[i] = 1ll * f[i] * g[i] % mod;
        f.NTT(-1), f.shrink();
        return f;
    }
    /**
     * Quotient of f by g (empty result when g is zero or f has fewer stored
     * coefficients than g).
     *
     * The inverse of the reversed divisor is cached between calls because
     * repeated reduction by the same modulus is the hot path. The cache is
     * rebuilt whenever the divisor changes or more terms are required, so
     * dividing by a different polynomial stays correct.
     */
    friend inline poly operator / (poly f, poly g)
    {
        static poly cached_divisor;
        static poly cached_inverse;
        static int cached_terms = 0;

        g.shrink();
        if(!g.size() || f.size() < g.size())
            return poly();
        const int terms = f.size() - g.size() + 1; // coefficients in the quotient
        if(cached_terms < terms || cached_divisor.a != g.a)
        {
            cached_terms = std::max(terms, g.size());
            cached_inverse = g.rever().inv(cached_terms);
            cached_divisor = g;
        }
        poly q = (f.rever().modx(terms) * cached_inverse.modx(terms)).modx(terms);
        // operator* strips zero coefficients from the top of the reversed
        // quotient; those are the LOW-order coefficients of the real quotient,
        // so restore the full length before reversing back.
        q.resize(terms);
        return q.rever();
    }
    /**
     * Remainder of f modulo g. Returns f unchanged when g is zero or when f
     * already has fewer stored coefficients than g. The result keeps
     * g.size() - 1 coefficients and is not shrunk.
     */
    friend inline poly operator % (poly f, poly g)
    {
        g.shrink();
        if(!g.size() || f.size() < g.size())
            return f;
        const int len = g.size() - 1;
        poly q = f / g;
        return (f.modx(len) - g.modx(len) * q.modx(len)).modx(len);
    }
    // Copy truncated to at most n coefficients (never lengthened).
    poly modx(int n) const
    {
        poly g = *this;
        g.resize(std::max(0, std::min(g.size(), n)));
        return g;
    }
    /**
     * Multiplicative inverse modulo x^n by Newton iteration, doubling the
     * precision each round. A zero constant term has no inverse; in that case
     * the result is the empty polynomial.
     */
    poly inv(int n) const
    {
        poly f(pow((*this)[0], mod - 2));
        for(int len = 1; len < n; len <<= 1)
            f = (f * (poly(2) - f * modx(len << 1))).modx(len << 1);
        return f.modx(n);
    }
}p, base, ret; // main uses these as: modulus polynomial, running power of x, accumulated result
/**
 * Reads the next decimal integer from stdin.
 *
 * Skips characters until a digit appears; every '-' seen while skipping flips
 * the sign. Consumes the digits plus the single character that terminates
 * them.
 *
 * @return the parsed value, or 0 if end of input is reached before any digit
 */
int read()
{
    // Keep the character as int so EOF stays distinguishable from real input.
    int ch = getchar();
    int neg = 1;
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-') neg = -neg;
        ch = getchar();
    }
    int ans = 0;
    while(ch >= '0' && ch <= '9')
    {
        ans = ans * 10 + (ch - '0');
        ch = getchar();
    }
    return ans * neg;
}
int main()
{
    n = read(), k = read();
    for(int i = 1; i <= k; i++)
    {
        a[i] = read();
        if(a[i] < 0) a[i] += mod;
    }
    for(int i = 0; i < k; i++)
    {
        f[i] = read();
        if(f[i] < 0) f[i] += mod;
    }
    p.resize(k + 1), p.a[k] = 1;
    for(int i = 1; i <= k; i++)
    {
        if(a[i])
            p.a[k - i] = mod - a[i];
    }
    if(n & 1)
    {
        for(int i = 0; i <= k; i++)
        {
            if(p[i])
                p.a[i] = mod - p[i];
        }
    }
    base.resize(2), base.a[1] = 1, ret.resize(1), ret.a[0] = 1;
    for(; n; n >>= 1, base = base * base % p)
        if(n & 1) ret = ret * base % p;
    for(int i = 0; i < k; i++)
        ans = (1ll * ans + 1ll * f[i] * 1ll * ret[i]) % mod;
    printf("%d\n", ans);
    return 0;
}