从线性变换的角度看待 FWT

Gauss0320

2021-08-20 22:36:27

Solution

**upd on 2121/8/22:** 修改了一些符号/语言错误。 # Abstract 今天早上听本校高二的同学讲课,深受启发,下去自学了 FWT,决定以不同于大部分博客所提供的视角来介绍这个算法。 # About FWT FWT,即 Fast Walsh–Hadamard Transform,快速沃尔什变换。在 OI 中一般被用来处理位运算卷积问题。 # Analysis 设 $\mathbf a=(a_0,a_1,\cdots,a_{n-1})^{\text{T}}$ 是一个序列,考虑一个线性变换 $\mathscr P$,记这个线性变换在标准正交基 $\epsilon_1,\epsilon_2,\cdots,\epsilon_n$ 下的矩阵为 $\mathbf P$,$\mathbf a$ 在 $\mathscr P$ 的作用下得到的新的序列为 $\hat\mathbf a=(\hat a_0,\hat a_1,\cdots,\hat a_{n-1})^{\text{T}}$,i.e. $\hat \mathbf a=\mathbf P\mathbf a$。 现在定义两个序列的位运算卷积 $${\mathbf c}=\mathbf a*\mathbf b:\iff c_k=\sum_{i\oplus j=k}a_{i}b_j$$ 其中 $\oplus$ 代表某种位运算。 我们希望构造这样的**可逆**线性变换 $\mathscr P$,使得 $$\mathbf c=\mathbf a*\mathbf b\iff \hat c_i=\hat a_i\cdot\hat b_i,i=0,1,\cdots,n-1$$ 这样,我们就得到一条求位运算卷积的思路,对 $\mathbf a,\mathbf b$ 分别做变换 $\mathscr P$,对应位相乘得到 $\hat\mathbf c$,再做变换 $\mathscr P^{-1}$ 得到 $\mathbf c$,剩下的无非是加速这个变换的过程。 我们现在来分析一下这个线性变换 $\mathscr P$ 的性质,令 $\mathscr P$ 可逆 $$ \begin{aligned} &\hat c_k=\hat a_k\cdot\hat b_k\\ \iff &\sum_{u=0}^{n-1}P(k,u)c_u=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}P(k,i)P(k,j)a_ib_j\\ \iff &\sum_{u=0}^{n-1}\sum_{i\oplus j=u}P(k,u)a_ib_j=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}P(k,i)P(k,j)a_ib_j\\ \iff&\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}P(k,i\oplus j)a_ib_j=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}P(k,i)P(k,j)a_ib_j\\ \iff &P(k,i\oplus j)=P(k,i)P(k,j) \end{aligned} $$ 我们推出了这样一条性质,下面我们将揭秘在各种卷积问题中这个线性变换 $\mathscr P$ 究竟长什么样子。 ## FFT 注意到 $x^{i+j}=x^{i}x^j$,所以范德蒙矩阵是符合条件的,考虑到单位根的一些神奇性质,我们可以选取所有 $n$ 次单位根确定的范德蒙矩阵。因为 $\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}$ 两两不同,所以这个矩阵是可逆的,其逆矩阵是 $$\begin{bmatrix}1&1&\cdots&1\\1&\omega_n^1&\cdots&\omega_n^{n-1}\\1&\vdots&\ddots&\vdots\\1&\omega_n^{n-1}&\cdots&\omega_n^{(n-1)^2}\end{bmatrix}^{-1}=\frac{1}{n}\begin{bmatrix}1&1&\cdots&1\\1&\omega_n^{-1}&\cdots&\omega_n^{-(n-1)}\\1&\vdots&\ddots&\vdots\\1&\omega_n^{-(n-1)}&\cdots&\omega_n^{-(n-1)^2}\end{bmatrix} $$ ## | 运算卷积 容易构造出 $P(i,j)=[i|j=i]$,那么 $$\hat c_k=\sum_{i=0}^{n-1}P(k,i)c_i=\sum_{k|i=k}c_i$$ 我们发现这实际上就是高维前缀和,其逆变换就是高维前缀差分。 代码如下 ```cpp void FWT_or(int *a) { for(int i = 0; i < n; i++) { for(int j = 0; j < (1 << n); j++) { if(j >> i & 1) a[j] = (a[j] + a[j - (1 << i)]) % mod; } } } void IFWT_or(int *a) { for(int i = 0; i < n; i++) { for(int j = 0; j < (1 << n); j++) { if(j >> i & 1) a[j] = (a[j] - a[j - (1 << i)] + mod) % mod; } } } ``` ## & 运算卷积 容易构造出 $P(i,j)=[i\text{\&}j=i]$,那么 $$\hat c_k=\sum_{i=0}^{n-1}P(k,i)c_i=\sum_{k\&i=k}c_i$$ 我们发现这实际上就是高维后缀和,其逆变换就是高维后缀差分。 代码如下 ```cpp void FWT_and(int *a) { for(int i = 0; i < n; i++) { for(int j = 0; j < (1 << n); j++) { if(!(j >> i & 1)) a[j] = (a[j] + a[j + (1 << i)]) % mod; } } } void IFWT_and(int *a) { for(int i = 0; i < n; i++) { for(int j = 0; j < (1 << n); j++) { if(!(j >> i & 1)) a[j] = (a[j] - a[j + (1 << i)] + mod) % mod; } } } ``` ## xor 运算卷积 xor 与前面提到的 & 和 | 略有不同,所以我们需要寻找另外的方式。 注意到我们有等式 $\text{popcount}(i)+\text{popcount}(j)\equiv \text{popcount}(i\text{ xor }j)\pmod{2}$。 而将其转化为乘法的关系仅需联想到 $-1$ 的幂,i.e. $(-1)^{\text{popcount}(i)}(-1)^{\text{popcount}(j)}=(-1)^{\text{popcount}(i\text{ xor } j)}$。 由于一些原因我们希望 $P(i,j)$ 和 $i,j$ 都有关系,而 xor 运算又分配于 & 运算,所以很容易构造出 $P(i,j)=(-1)^{\text{popcount}(i\&j)}$,那么 $$\hat c_k=\sum_{i=0}^{n-1}(-1)^{\text{popcount}(i\&k)}c_i$$ 考虑分治,设 $\mathbf c_0=(c_0,c_1,\cdots,c_{n/2-1}),\mathbf c_1=(c_{n/2},c_{n/2+1},\cdots,c_{n-1})$。 令 $k<\dfrac{n}{2}$,则有 $$\begin{aligned}\hat c_k&=\sum_{i=0}^{n/2-1}(-1)^{\text{popcount}(i\&k)}c_i+\sum_{i=0}^{n/2-1}(-1)^{\text{popcount}(i\&k)}c_{i+n/2}=\hat c_{0k}+\hat c_{1k}\\\hat c_{k+n/2}&=\sum_{i=0}^{n/2-1}(-1)^{\text{popcount}(i\&k)}c_i-\sum_{i=0}^{n/2-1}(-1)^{\text{popcount}(i\&k)}c_{i+n/2}=\hat c_{0k}-\hat c_{1k}\end{aligned}$$ 从而 $$\hat\mathbf c=\text{merge}(\hat\mathbf c_0+\hat\mathbf c_1,\hat\mathbf c_0-\hat\mathbf c_1)$$ 其中 $\text{merge}$ 表示将两个序列拼接起来。 为了得到它的逆变换,我们对矩阵 $\mathbf P$ 求逆,发现 $$\mathbf P^{-1}=\frac{1}{n}\mathbf P$$ 所以我们仅需做一次正变换,再将每个元素除以 $n$ 即可。 看到这里有没有觉得 FWT 和 FFT 很像?事实上,FWT 就是高维的循环卷积。 这部分的代码如下 ```cpp void FWT_xor(int *a) { for(int mid = 1; mid < (1 << n); mid <<= 1) { for(int i = 0; i < (1 << n); i += (mid << 1)) { for(int j = 0; j < mid; j++) { int x = a[i + j], y = a[i + j + mid]; a[i + j] = (x + y) % mod; a[i + j + mid] = (x - y + mod) % mod; } } } } void IFWT_xor(int *a) { for(int mid = 1; mid < (1 << n); mid <<= 1) { for(int i = 0; i < (1 << n); i += (mid << 1)) { for(int j = 0; j < mid; j++) { int x = a[i + j], y = a[i + j + mid]; a[i + j] = 1ll * (x + y) * 1ll * inv2 % mod; a[i + j + mid] = 1ll * (x - y + mod) * 1ll * inv2 % mod; } } } } ```