RWKV 初探

Introduction

我们熟知,深度学习架构的发展,最开始的代表性结构是循环神经网络/RNN。

RNN 占用内存少,但是它梯度消失问题十分严重,而且难以并行训练。

接下来出现了 Transformer 架构,它效果极佳,但是由于上下文依赖的问题,它的时间复杂度达到了 \(\Theta(n^2)\),这严重限制了他在长上下文的表现。

一种新的模型,即我们将要看到的 Receptance Weighted Key Value (RWKV)模型,则尝试合并了二者的优点。

RWKV

Overview

RWKV 架构提出了 4 个基本向量,他们内嵌于所有的块之中,即:

  1. R: The Receptance vector acts as the receiver of past information. 充当过去信息的“接收器” 。在公式里,它通常通过 \(sigmoid\) 函数变成一个门控信号。

  2. W: The Weight signifies the positional weight decay vector, a trainable parameter within the model. 这是一个可训练的参数,代表位置权重衰减向量。它决定了过去的信息随时间“遗忘”的速度。

  3. K: The Key vector performs a role analogous to K in traditional attention mechanisms. 其角色类似于传统 Transformer 注意力机制中的 \(K\)

  4. V : The Value vector functions similarly to V in conventional attention processes. 同样类似于传统注意力机制中的 \(V\)

Architecture

RWKV 模型由多个堆叠的残差块 (Stacked Residual Blocks) 组成 。每个残差块包含两个子模块:

  1. 时间混合 (Time-mixing) 子块:负责处理序列在时间轴上的依赖关系,利用上述四个要素来捕捉过去的信息

  2. 通道混合 (Channel-mixing) 子块:负责在单个时间步内处理不同通道特征的交互。

RWKV Block 很像于注意力机制。

Token Shift

在进入这四个向量的线性变换之前,RWKV 引入了一个非常简单的动作:它将当前步的输入 \(x_t\) 与上一步的输入 \(x_{t-1}\) 进行线性插值。

这种设计让模型在计算每一层时都能“瞥一眼”过去,从而实现标记位移。 具体的计算公式如下:

\[\begin{align} r_t &= W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})\\ k_t &= W_r \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})\\ v_t &= W_r \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1}) \end{align} \]

Process

我们完整梳理一下一个 RWKV Block 的流程。

Time mixing channel

  1. 归一化与位移(LayerNorm & Token Shift)

    • 输入向量 \(x_t\) 首先经过 LayerNorm。

    • 接着进行 Token Shift:将当前时间步的输入 \(x_t\) 与上一个时间步保存下来的输入 \(x_{t-1}\) 进行线性插值。这就像是让模型同时观察“现在”和“刚刚”发生的事情。

  2. 生成 R, K, V

    • 通过线性变换生成三个向量:Receptance (\(r_t\))、Key (\(k_t\)) 和 Value (\(v_t\)),公式就是上面给出的
  3. WKV 算子

    • 这是 RWKV 最特殊的地方。它通过一个复杂的加权公式,计算所有过去时间步的 \(v\) 的加权平均值。

    • 权重由两部分决定:当前步的 \(k_t\) 和随时间衰减的权重 \(w\)(Time Decay)

    • 它的数学表达如下: \[ wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} \odot v_i + e^{u+k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} + e^{u+k_t}} \]

  4. 输出门控(Output Gating)

    • 将 Receptance (\(r_t\)) 经过 \(sigmoid\) 激活后,作为“门控信号”去乘以 WKV 的结果。

    • 最后通过一个线性变换 \(W_o\) 输出,并与原始输入 \(x_t\) 进行残差相加。

Channel-mixing

  1. 再次位移与投影

    • 同样先做 LayerNorm 和 Token Shift。

    • 生成用于通道混合的 \(r'_t\)\(k'_t\)

  2. 平方 ReLU 激活

    • \(k'_t\) 进行线性变换后,使用 Squared ReLU(即 \(V' = W'_v \cdot \max(k'_t, 0)^2\))进行激活。
  3. 最终门控输出

    • 再次利用 \(sigmoid(r'_t)\) 作为门控,决定保留多少计算后的特征。结果与第一阶段的输出进行残差相加。

对比时间混合和通道混合,我们容易发现:

  • WKV 算子起到了“上下文”的作用

  • 整体仍然保持“线性”的特点

Reflections

更概括性地说,RWKV 其实就是 Transformer 和 RNN 的结合:

  • \(r\) 向量起到了 gating 的作用

  • \(wkv\) 算子代替了昂贵的注意力计算,转而采用类似于线性递推的方式。\(wkv_t\) 其实就蕴含了从时间步 \(1\to t-1\) 的所有信息与 \(v_t\) 的融合