RWKV 初探
Introduction
我们熟知,深度学习架构的发展,最开始的代表性结构是循环神经网络/RNN。
RNN 占用内存少,但是它梯度消失问题十分严重,而且难以并行训练。
接下来出现了 Transformer 架构,它效果极佳,但是由于上下文依赖的问题,它的时间复杂度达到了 \(\Theta(n^2)\),这严重限制了他在长上下文的表现。
一种新的模型,即我们将要看到的 Receptance Weighted Key Value (RWKV)模型,则尝试合并了二者的优点。
RWKV
Overview
RWKV 架构提出了 4 个基本向量,他们内嵌于所有的块之中,即:
R: The Receptance vector acts as the receiver of past information. 充当过去信息的“接收器” 。在公式里,它通常通过 \(sigmoid\) 函数变成一个门控信号。
W: The Weight signifies the positional weight decay vector, a trainable parameter within the model. 这是一个可训练的参数,代表位置权重衰减向量。它决定了过去的信息随时间“遗忘”的速度。
K: The Key vector performs a role analogous to K in traditional attention mechanisms. 其角色类似于传统 Transformer 注意力机制中的 \(K\)
V : The Value vector functions similarly to V in conventional attention processes. 同样类似于传统注意力机制中的 \(V\)
Architecture
RWKV 模型由多个堆叠的残差块 (Stacked Residual Blocks) 组成 。每个残差块包含两个子模块:
时间混合 (Time-mixing) 子块:负责处理序列在时间轴上的依赖关系,利用上述四个要素来捕捉过去的信息
通道混合 (Channel-mixing) 子块:负责在单个时间步内处理不同通道特征的交互。
RWKV Block 很像于注意力机制。
Token Shift
在进入这四个向量的线性变换之前,RWKV 引入了一个非常简单的动作:它将当前步的输入 \(x_t\) 与上一步的输入 \(x_{t-1}\) 进行线性插值。
这种设计让模型在计算每一层时都能“瞥一眼”过去,从而实现标记位移。 具体的计算公式如下:
\[\begin{align} r_t &= W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})\\ k_t &= W_r \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})\\ v_t &= W_r \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1}) \end{align} \]
Process
我们完整梳理一下一个 RWKV Block 的流程。
Time mixing channel
归一化与位移(LayerNorm & Token Shift)
输入向量 \(x_t\) 首先经过 LayerNorm。
接着进行 Token Shift:将当前时间步的输入 \(x_t\) 与上一个时间步保存下来的输入 \(x_{t-1}\) 进行线性插值。这就像是让模型同时观察“现在”和“刚刚”发生的事情。
生成 R, K, V
- 通过线性变换生成三个向量:Receptance (\(r_t\))、Key (\(k_t\)) 和 Value (\(v_t\)),公式就是上面给出的
WKV 算子
这是 RWKV 最特殊的地方。它通过一个复杂的加权公式,计算所有过去时间步的 \(v\) 的加权平均值。
权重由两部分决定:当前步的 \(k_t\) 和随时间衰减的权重 \(w\)(Time Decay)
它的数学表达如下: \[ wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} \odot v_i + e^{u+k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} + e^{u+k_t}} \]
输出门控(Output Gating)
将 Receptance (\(r_t\)) 经过 \(sigmoid\) 激活后,作为“门控信号”去乘以 WKV 的结果。
最后通过一个线性变换 \(W_o\) 输出,并与原始输入 \(x_t\) 进行残差相加。
Channel-mixing
再次位移与投影
同样先做 LayerNorm 和 Token Shift。
生成用于通道混合的 \(r'_t\) 和 \(k'_t\)
平方 ReLU 激活
- 对 \(k'_t\) 进行线性变换后,使用 Squared ReLU(即 \(V' = W'_v \cdot \max(k'_t, 0)^2\))进行激活。
最终门控输出
- 再次利用 \(sigmoid(r'_t)\) 作为门控,决定保留多少计算后的特征。结果与第一阶段的输出进行残差相加。
对比时间混合和通道混合,我们容易发现:
WKV 算子起到了“上下文”的作用
整体仍然保持“线性”的特点
Reflections
更概括性地说,RWKV 其实就是 Transformer 和 RNN 的结合:
\(r\) 向量起到了 gating 的作用
\(wkv\) 算子代替了昂贵的注意力计算,转而采用类似于线性递推的方式。\(wkv_t\) 其实就蕴含了从时间步 \(1\to t-1\) 的所有信息与 \(v_t\) 的融合