Mamba Architecture
Mamba 是一种 State Space Model/SSM,即状态空间模型。这是用以处理序列问题的一种非常优秀的模型。
其本质可以理解为“隐空间”/“状态空间”中的递推。
问题
给定时间信号输入序列 \(x(t)\),以及输出序列 \(y(t)\),尝试给出预测模型。 我们考虑一个隐状态空间 \(h(t)\in \mathbb{R}^N\),它由 \(x(t)\) 而来,控制了 \(y(t)\) 的输出。 从简化的角度考虑,给出线性模型:
\[\begin{align} h'(t) &= \mathbf{A}h(t)+\mathbf{B}x(t) \\ y(t) &= \mathbf{C}h'(t) \end{align} \]
其中,\(\mathbf{A}\),\(\mathbf{B}\),\(\mathbf{C}\) 三个矩阵分别代表 evolution parameter 和 projection parameter。第一个的作用是将隐空间中的状态逐步演化的最终态,后二者的作用是进行隐空间和实际输入/输出空间的映射。
从连续时间到离散时间
上面的模型显然是一个连续时间模型,也就意味着是一个微分方程。
但是在实际工程和应用中,我们不可能去解一个微分方程,因此我们考虑去将他离散化。其中,最重要的一种离散化方式被称作 "Zero-Order Hold/ZOH" 规则。
ZOH
ZOH 的核心假设非常易于理解:
Zero-Order Hold
在每个采样区间中,输入保持常数:
\[ u(t)=u_k\quad t\in [t_k,t_{k+1}) \]
其中,采样步长 \(\Delta>0\),\(t_k=k\Delta\)
我们希望得到离散递推公式: \[ x_{k+1}=\bar{A}x_k+\bar{B}u_k \]
所以,也就是去解这个微分方程(和普通的微分方程一样,先不考虑是矩阵的情况): \[ \frac{\mathrm{d}x}{\mathrm{d}t}=Ax+Bu \]
注意,这里 \(x\) 是状态,\(u\) 是输入。
将系统重写为 \[ \dot{x}(t)-Ax(t)=Bu(t) \]
解之即有: \[ \bar{A}=e^{\Delta A},\bar{B}=\int_{t_k}^{t_{k+1}}e^{A(t_{k+1}-\tau)}B\mathrm{d}\tau \]
若 \(A\) 可逆,则 \(B\) 有闭式解 \(A^{-1}(e^{A\Delta}-I)B\)
矩阵指数
\(B=e^{A}\) 并不是 \(B_{ij}=\exp\{a_{ij}\}\),而是采用泰勒展开: \[ e^A=\sum_{i=0}^{\infin} \frac{A^k}{k!} \]
对于对角矩阵,退化为上面的直觉的形式: \[ A=\text{diag}(a_1,\dots,a_n) \Rightarrow e^A=\text{diag}(e^{a_1},\dots,e^{a_n}) \]
类似,对于可对角化矩阵,也有: \[ A=P \text{diag}(a_1,\dots,a_n) P^{-1} \Rightarrow e^A=P\text{diag}(e^{a_1},\dots,e^{a_n})P^{-1} \]
递推
基于 ZOH 假设,我们可以快速写出以下式子(\(M\) 是序列长度): \[ \begin{align} \bar{K} &= (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B},\dots, C\bar{A}^{M-1}\bar{B})\\ y &= x * \bar{K} \end{align} \]
一个 Mamba Block
一个 Mamba Block 做一次完整的我们上述的递推操作,也就是对于每个位置提取一次上下文理解。
和 Transformer 类似,我们也可以把多个 Mamba Block 叠起来,进行更深度的信息提取,也就是逐层抽象,越靠后的 Mamba Block 蕴含了越丰富的、越抽象的信息。
\[ x^{(0)}\xrightarrow{\text{Mamba Block 1}} x^{(1)}\xrightarrow{\text{Mamba Block 2}}\dots \xrightarrow{\text{Mamba Block M}} x^{(M)} \]
具体实现
Algorithm 1:Parameters Function
根据我们上文的描述,从给定的序列 \(\{x\}\) 中生成 \(\bar{A}, \bar{B}, C\)。
\[ \begin{aligned} \textbf{Algorithm 1: Parameters Function}\\ \textbf{Require: } x' \in \mathbb{R}^{(B,N,P)}\\ \textbf{Ensure: } \bar A \in \mathbb{R}^{(B,N,P,K)},\ \bar B \in \mathbb{R}^{(B,N,P,K)},\ C \in \mathbb{R}^{(B,N,K)}\\[4pt] \begin{array}{ll} 1: & B \in \mathbb{R}^{(B,N,K)} \leftarrow \mathrm{Linear}^{B}(x')\\ 2: & C \in \mathbb{R}^{(B,N,K)} \leftarrow \mathrm{Linear}^{C}(x')\\ 3: & \Delta \in \mathbb{R}^{(B,N,P)} \leftarrow \log\!\Big(1+\exp(\mathrm{Linear}^{\Delta}(x')+\mathrm{Parameter}^{\Delta})\Big)\\ 4: & \text{// } \mathrm{Parameter}^{A} \in \mathbb{R}^{(P,K)}\\ 5: & \bar A \in \mathbb{R}^{(B,N,P,K)} \leftarrow \Delta \otimes \mathrm{Parameter}^{A}\\ 6: & \bar B \in \mathbb{R}^{(B,N,P,K)} \leftarrow \Delta \otimes B\\ \textbf{Return:} & \bar A,\ \bar B,\ C \end{array} \end{aligned} \]
\[ \begin{aligned} \textbf{Algorithm 2: Mamba Block}\\ \textbf{Require: } T_{l-1} \in \mathbb{R}^{(B,N,C)}\\ \textbf{Ensure: } T_{l} \in \mathbb{R}^{(B,N,C)}\\[4pt] \begin{array}{ll} 1: & \text{// Apply layer normalization to } T_{l-1}\\ 2: & T'_{l-1} \in \mathbb{R}^{(B,N,C)} \leftarrow \mathrm{Norm}(T_{l-1})\\ 3: & x \in \mathbb{R}^{(B,N,P)} \leftarrow \mathrm{Linear}^{x}(T'_{l-1})\\ 4: & z \in \mathbb{R}^{(B,N,P)} \leftarrow \mathrm{Linear}^{z}(T'_{l-1})\\ 5: & \text{// Process the input sequence}\\ 6: & x' \in \mathbb{R}^{(B,N,P)} \leftarrow \mathrm{SiLU}(\mathrm{Conv1d}(x))\\ 7: & \bar A,\ \bar B,\ C \leftarrow \mathrm{ParametersFunction}(x')\\ 8: & y \in \mathbb{R}^{(B,N,P)} \leftarrow \mathrm{SSM}(\bar A,\bar B,C)(x')\\ 9: & \text{// Obtain gated } y\\ 10: & y' \in \mathbb{R}^{(B,N,P)} \leftarrow y \odot \mathrm{SiLU}(z)\\ 11: & \text{// Residual connection}\\ 12: & T_{l} \in \mathbb{R}^{(B,N,C)} \leftarrow \mathrm{Linear}^{T}(y') + T_{l-1}\\ \textbf{Return:} & T_{l} \end{array} \end{aligned} \]