Introduction

我们熟知,深度学习架构的发展,最开始的代表性结构是循环神经网络/RNN。

RNN 占用内存少,但是它梯度消失问题十分严重,而且难以并行训练。

接下来出现了 Transformer 架构,它效果极佳,但是由于上下文依赖的问题,它的时间复杂度达到了 \(\Theta(n^2)\),这严重限制了他在长上下文的表现。

一种新的模型,即我们将要看到的 Receptance Weighted Key Value (RWKV)模型,则尝试合并了二者的优点。

RWKV

Overview

RWKV 架构提出了 4 个基本向量,他们内嵌于所有的块之中,即:

  1. R: The Receptance vector acts as the receiver of past information. 充当过去信息的“接收器” 。在公式里,它通常通过 \(sigmoid\) 函数变成一个门控信号。

  2. W: The Weight signifies the positional weight decay vector, a trainable parameter within the model. 这是一个可训练的参数,代表位置权重衰减向量。它决定了过去的信息随时间“遗忘”的速度。

  3. K: The Key vector performs a role analogous to K in traditional attention mechanisms. 其角色类似于传统 Transformer 注意力机制中的 \(K\)

  4. V : The Value vector functions similarly to V in conventional attention processes. 同样类似于传统注意力机制中的 \(V\)

Architecture

RWKV 模型由多个堆叠的残差块 (Stacked Residual Blocks) 组成 。每个残差块包含两个子模块:

  1. 时间混合 (Time-mixing) 子块:负责处理序列在时间轴上的依赖关系,利用上述四个要素来捕捉过去的信息

  2. 通道混合 (Channel-mixing) 子块:负责在单个时间步内处理不同通道特征的交互。

RWKV Block 很像于注意力机制。

Token Shift

在进入这四个向量的线性变换之前,RWKV 引入了一个非常简单的动作:它将当前步的输入 \(x_t\) 与上一步的输入 \(x_{t-1}\) 进行线性插值。

这种设计让模型在计算每一层时都能“瞥一眼”过去,从而实现标记位移。 具体的计算公式如下:

\[\begin{align} r_t &= W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})\\ k_t &= W_r \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})\\ v_t &= W_r \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1}) \end{align} \]

Process

我们完整梳理一下一个 RWKV Block 的流程。

Time mixing channel

  1. 归一化与位移(LayerNorm & Token Shift)

    • 输入向量 \(x_t\) 首先经过 LayerNorm。

    • 接着进行 Token Shift:将当前时间步的输入 \(x_t\) 与上一个时间步保存下来的输入 \(x_{t-1}\) 进行线性插值。这就像是让模型同时观察“现在”和“刚刚”发生的事情。

  2. 生成 R, K, V

    • 通过线性变换生成三个向量:Receptance (\(r_t\))、Key (\(k_t\)) 和 Value (\(v_t\)),公式就是上面给出的
  3. WKV 算子

    • 这是 RWKV 最特殊的地方。它通过一个复杂的加权公式,计算所有过去时间步的 \(v\) 的加权平均值。

    • 权重由两部分决定:当前步的 \(k_t\) 和随时间衰减的权重 \(w\)(Time Decay)

    • 它的数学表达如下: \[ wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} \odot v_i + e^{u+k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} + e^{u+k_t}} \]

  4. 输出门控(Output Gating)

    • 将 Receptance (\(r_t\)) 经过 \(sigmoid\) 激活后,作为“门控信号”去乘以 WKV 的结果。

    • 最后通过一个线性变换 \(W_o\) 输出,并与原始输入 \(x_t\) 进行残差相加。

Channel-mixing

  1. 再次位移与投影

    • 同样先做 LayerNorm 和 Token Shift。

    • 生成用于通道混合的 \(r'_t\)\(k'_t\)

  2. 平方 ReLU 激活

    • \(k'_t\) 进行线性变换后,使用 Squared ReLU(即 \(V' = W'_v \cdot \max(k'_t, 0)^2\))进行激活。
  3. 最终门控输出

    • 再次利用 \(sigmoid(r'_t)\) 作为门控,决定保留多少计算后的特征。结果与第一阶段的输出进行残差相加。

对比时间混合和通道混合,我们容易发现:

  • WKV 算子起到了“上下文”的作用

  • 整体仍然保持“线性”的特点

Reflections

更概括性地说,RWKV 其实就是 Transformer 和 RNN 的结合:

  • \(r\) 向量起到了 gating 的作用

  • \(wkv\) 算子代替了昂贵的注意力计算,转而采用类似于线性递推的方式。\(wkv_t\) 其实就蕴含了从时间步 \(1\to t-1\) 的所有信息与 \(v_t\) 的融合

简单记录一下我做下来的一些小坑。

Stage 2

正确的端口

事实上,handle_data_packet(self, packet, in_port) 注释说明了是 :param in_port: the port from which the packet arrived.,但这个其实没有作用。我们关心的是我们当前这个 router 要把收到的 packet 向哪个端口转发,所以其实是:

1
2
3
4
dst = packet.dst
...
dst_port = self.table[dst].port
latency = self.table[dst].latency

Stage 4

注意 latency 计算完整

我们在接收到邻居的广播后,注意他发送过来的 latency 只是他到目标 host 的 latency,而我们如果要通过它中转,还需要加上我们到它的 latency:

1
2
3
4
5
latency = route_latency + self.ports.get_latency(port)
new_entry = TableEntry(dst=route_dst, port=port,
latency=latency,
expire_time=api.current_time()+self.ROUTE_TTL)

Stage 5

字典删除问题

由于 self.table 是一个字典类型,而直接在遍历中删除字典元素会引发错误:

1
RuntimeError: dictionary changed size during iteration

比较推荐的做法是先收集一个应当删除的 list:

1
2
3
4
5
6
7
8
9
10
11
12
expire_list = []
for dst, entry in self.table.items():
if entry.expire_time <= api.current_time():
expire_list.append(dst)

for dst in expire_list:
if self.POISON_EXPIRED:
new_entry = TableEntry(dst=dst, port=self.table[dst].port,
latency=INFINITY, expire_time=api.current_time()+self.ROUTE_TTL)
self.table[dst] = new_entry
else:
self.table.pop(dst)

Stage 6&7

关于处理 Split Horizon 和 Poison Reverse 有一个更聪明的写法,我们在遇到 self.SPLIT_HORIZON 为 True 的时候选择直接 continue,这样就避免了大量繁琐的 if 语句。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
for port in ports_update:
for dst in self.table.keys():
latency = self.table[dst].latency

if port == self.table[dst].port:
if self.SPLIT_HORIZON:
# Split Horizon: don't send anything at all
continue
elif self.POISON_REVERSE:
# Poison Reverse: send it as INFINITY
latency = INFINITY

if latency > INFINITY:
latency = INFINITY

route = RoutePacket(destination=dst, latency=latency)
if force or self.history[port].get(dst, None) != latency:
self.send_route(port, dst, latency=latency)
self.history[port][dst] = latency
else:
pass

Stage 10A

题干提示了要使用一个 self.history 的数据结构来记录每个端口向每个目的地上一次转发的 packet 的信息。我一开始使用了一个元组 (port, dst),然后使用 self.key: {(port, dst): RoutePacket} 这样的键值对来存储。

但这样有几个问题:

  1. 写起来很繁琐

  2. 有冗余的信息,我们不需要保存完整的 packet,因为我们的 key 已经有 dst 了(RoutePacket 本身有 latency 和 dst 两个信息)

综上,我们采用一个更良好的写法,即一个二重字典。

1
2
3
4
5
6
7
8
9
10
def __init__(self, ...):
from collections import defaultdict
self.history = defaultdict(dict)

# in send_routes
if force or self.history[port].get(dst, None) != latency:
self.send_route(port, dst, latency=latency)
self.history[port][dst] = latency
else:
pass

defaultdict

对于一个普通的 Python 字典, 如果你访问一个不存在的键,它会直接报错 KeyError。 但 defaultdict(dict) 自带默认值:

当尝试访问 self.history[port] 时,如果这个 port 还没出现过,它不会报错,而是自动创建一个空的字典 {} 填在那里。

配套的,使用 dict.get(key, None) 可以安全的取出对应的值。

Overview

Demultiplexing with Ports

IP 协议只能看到 IP。如果一台主机上同时有两个应用在和一个 server 通信,那么二者在 IP 层面是完全一样的(有相同的 source/destination IP)

我们该如何区分这两个应用,将数据正确的交给他们?

解决方案就是 port,每个应用走不同的 port 出去。

Because the transport layer is implemented in the operating system, these ports (sometimes called logical ports) are the attachment point where the application connects to the operating system’s network stack.

TCP

What is Routing?

End Host and Router

我们先说明 End Host 和 Router 的区别。

前者是网络中一条线路的端点,他只负责接受、发送自己的信息,而不负责“传送”任务。 后者则相反,仅仅起到中介的作用。

Packets

我们之前说过,Packets 是 Internet Layer 传送的基本单元。

为了正确地在互联网中传送数据,我们显然需要指明 Source 和 Destination。 他们需要一个唯一的标识符。

这个标识符是怎么来的,姑且先不考虑。在有了这个标识符之后,我们便在 packet 的头部加上 Source Address 和 Destination Address 两个元素。

Routing States

What is Rouing State

Routing States 最简单的说,是一种策略。即当 router 接收到一个 packet 的时候,根据这些策略,便能成功地把信息传送到目的地。

其中,最重要的、也是最泛用的策略被称作转发表/Forward Table

Forwarding Table

理论上而言,我们可以建立一张表:这张表直白的告诉我们,如果我想把一个 packet 送到一个地方,我的 next hop 应该是什么。

在现实中,next hop 往往不是指一个抽象的地点,而是指物理上的一个 port

Routing vs. Forwarding

我们现在区分两个概念:routing 本身指的是填充 forwarding table 的过程,而 forwarding 指的才是转发 packet 的过程。

因此,forwarding 本身并不需要知道 forwarding table 是如何计算的,故它是一个局部的动作。

相反,routing 本身要求我们必须对于整个网络结构有一个认知,因此这是一个全局的过程。

Distance Vector Protocols

Distance Vector Protocols 的主要内容如下:

For each destination:

  • If you hear an advertisement for that destination, update the table and reset the TTL if:
    • The destination isn’t in the table.
    • The advertised cost, plus the link cost to the neighbor, is better than the best-known cost.
    • The advertisement is from the current next-hop. Includes poison advertisements.
  • Advertise to all your neighbors when the table updates, and periodically (advertisement interval).
    • But don’t advertise back to the next-hop.
    • …Or, advertise poison back to the next-hop.
    • Any cost greater than or equal to 16 is advertised as infinity.
  • If a table entry expires, make the entry poison and advertise it.

我们分条来看。

Update Rules

我们如果从毗邻的节点接收到一条 (from, to, cost, TTL) 的信息,我们进行如下的判定来决定要不要更新我们的 forwarding table:

  1. 最明显的,如果我根本没有到 to 的路径,那我肯定要加上这条
  2. 这条路径比我现在手上有的要更快,我肯定选择它更好
  3. 这条路径是从我当前的最优的 next hop 发来的,这说明 next hop 有更新,它到 to 的距离发生了变化,有可能不再是最优解。为了保证正确性,我现在先更新这条,等后面别人广播的时候自然会更新成正确答案。
  4. 或者,如果是 poison 的,也就是 “我无法到达某个目的地” 的信息,那么我收到的 Cost 就是无穷大,类似 3,我也要更新。

我们通过不停广播自己的信息来使得这个网络达到收敛态/converge

  1. 首先,我们必须定时广播信息,告诉我的邻居们我还在线,以及我到某个目的地的最小 cost 是多少
  2. 但是注意,我们不能反向告诉 next hop,否则,如果 next hop 失效(不再能够到达目的地),它会误以为通过我还能抵达目的地,造成死循环(这被称作 Split Horizon 原则)
  3. 或者,我们定时广播时,反向告诉 next hop,我不可以到达 A(即,我是从 next hop 来的,我不能把 packet 再返回给你)。这被称作 Poison Reverse,往往是比 Split Horizon 更有效的广播方式
  4. 如果我们已经有一个死循环了,我们该怎么办?方法很简单,给定一个上限,如果 cost 超过了这个上限,我们就认为他是无穷大/不可达的。

Overview

Link-state protocols in one sentence: Every router learns the full network graph, and then runs shortest-paths on the graph to populate the forwarding table.

直白地说,Link-state 协议通过让所有路由器都有全局信息,使得他们并行计算(通过 Bellman-Ford/Dijkstra 等最短路算法)出自己的next hop 但这显然引起了一个问题,如果 A 计算的最短路径和 B 计算出的路径不一致怎么办呢?

因此,我们做出以下 4 条约定:

  1. All routers have to agree on the network topology. Suppose a link failed, but only one router knows about it. Then different routers are computing paths on totally different graphs, and might produce inconsistent results.

  2. All routers are finding least-cost paths through the path. If one router preferred more expensive paths for some reason, we would get inconsistent results.

  3. All costs are positive. Negative costs could produce negative-weight cycles.

  4. All routers use the same tiebreaking rules. If we assumed shortest paths are unique, then the previous two conditions are sufficient to ensure everybody picks the same path. This condition additionally ensures that if there are multiple paths tied as the shortest, everyone chooses the same one.

Learning About Graph Topology

  1. To discover neighbors, every router sends a hello message to all of its neighbors. For example, in this network, R2 sends to both of its neighbors: “Hello, I’m R2.” Now, R1 knows that it’s connected to R2, and R3 also knows that it’s connected to R2. Similarly, R1 says hello to R2, so now R2 knows about R1. Likewise, R3 says hello to R2, so R2 also knows about R3.

    As a result, everybody now knows who their immediate neighbors are. Note that R1 does not know about R3, because R1 and R3 are not neighbors.

  1. Now that we know about our neighbors, we should announce that fact to everybody. To make a global announcement, we send the announcement to all of our neighbors. Also, if we ever receive an announcement, we should send it to all of our neighbors as well. This ensures that every message gets propagated throughout the network. This is known as flooding information across the network. If any information changes (e.g. a neighbor disappears), we should flood that information as well.
  1. avoid infinite flooding: When we see a message for the first time, send that message to all neighbors, and write down that we’ve seen that message. (We have to write down this message anyway, since we’re trying to use this information to build up the network graph.) Then, if we ever see that same message again, don’t send it a second time.

概览

我们将网络架构自顶向下地分成 5 层。每一层只需要做好自己的工作就可以了。

由于历史原因,最顶层被叫做 "Layer 7"

Layers of the Internet

Physical Layer

In the Internet, we’re looking for a way to signal bits (1s and 0s) across space.

The technology could be voltages on an electrical wire, wireless radio waves, light pulses along optical fiber cables, among others.

Physical Layer 数据的基本单位是bit

Link Layer cares about how to send the data using Physical Layer.

Link Layer 只负责相邻节点之间的通信,其单位是frame/帧

Internet Layer

Internet Layer 是路由器工作的最高层,负责决定 next hop,即把数据转发给谁。其基本单位是packet/包

IP 协议就是在这一层工作的。

Transport Layer

Transport Layer 负责决定数据传输的可靠性,也就是 TCP/UDP 协议,以及“端到端通信”。

也就是说,他通过端口区分应用的数据的来/去。

其基本单位是段/segment

Application Layer

只负责考虑如何使用网络。它的基本数据单位被称作 message

Headers

每一层都会把从更低层拿来的数据加上对应的 Header。

Header 是只给这一层看的信息(同时,每一层也只能看自己层级的 header),它指明了信息传输的各种附加信息(采取的协议、收件人/发件人等等)。

在传输过程中,Header 不断被 peel off,然后又被加上新的。

由于 Header 的存在,每一层相当于仅仅和自己同层的 peer 进行通信,因此也要求相同层级采用的协议必须相同。

Resource Sharing

网络的总容量是有限度的,因此如何分配资源是一个重要的问题。

我们先考虑两个问题:

  1. 要保证网络正常工作,我们至少要多大的容量?

  2. 我们如何分配我们的容量?

第一个问题的解决方案被称作 statistical multiplexing,其原理非常直观:

第二个问题有两种解决办法,一种被称作 best effort,通俗而言就是所有人都只管发送信息,并且 "hope for the best",其对应的策略被称作 packet switching

另一种被称作 reservation,也就是在通信前会先在网络中预留出容量,其对应的策略被称作 Circuit Switching

The bandwidth of a link tells us how many bits we can send on the link per unit time. Intuitively, this is the speed of the link. If you think of a link as a pipe carrying water, the bandwidth is the width of the pipe. A wider pipe lets us feed more water into the pipe per second. We usually measure bandwidth in bits per second (e.g. 5 Gbps = 5 billion bits per second).

The propagation delay of a link tells us how long it takes for a bit to travel along the link. In the pipe analogy, this is the length of the link. A shorter pipe means that water spends less time in the pipe before arriving at the other end. Propagation delay is measured in time (e.g. nanoseconds, milliseconds).

If we multiply the bandwidth and the propagation delay, we get the bandwidth-delay product (BDP). Intuitively, this is the capacity of the link, or the number of bits that exist on the link at any given instant. In the pipe analogy, if we fill up the pipe and freeze time, the capacity of the pipe is how much water is in the pipe in that instant.

注意,bandwidth/带宽是指“一秒内发射的 bit 数量”,而 bit 必须是一个一个发射出去的。所以一个 bit 的发射用时是 \(\frac{1}{\text{bandwidth}}\)

注意区分 delay(或者用全称,propagation delay) 和这个发射用时 (也就是上文的 transmission delay) 的区别。前者是 bit 在物理链路里面传播(propagation)的时间,后者是终端发送一个 packet 所有 bit 的时间。这从我们下文的例子就可以看出来。

Timing Diagram

Suppose we have a link with bandwidth 1 Mbps = 1 million bits per second, and propagation delay of 1 ms = 0.001 seconds.

We want to send a 100 byte = 800 bit packet along this link. How long does it take to send this packet, from the time the first bit is sent, to the time the last bit is received?

To answer this question, we can draw a timing diagram. The left bar is the sender, and the right bar is the recipient. Time starts at 0 and increases as we move down the diagram.

Let’s focus on the first bit. We can put 1,000,000 bits on the link per second (bandwidth), so it takes 1/1,000,000 = 0.000001 seconds to put a single bit on the link. At time 0.000001 seconds, the link has a single bit on it, at the sender end.

It then takes 0.001 seconds for this bit to travel across the link (propagation delay), so at time 0.000001 + 0.001 seconds, the very first bit arrives at the recipient.

Now let’s think about the last bit. From before, it takes 0.000001 to put a bit on the link. We have 800 bits to send, so the last bit is placed on the link at time \(800*0.000001=0.0008\) seconds.

It then takes 0.001 seconds for the last bit to travel across the link, so at time 0.0008 + 0.001 seconds, the very last bit arrives at the recipient. This is the time when we can say the packet has arrived at the recipient.

Packet Delay

从上文例子可以看出,一个完整的 packet 发送的延迟就是 \(\text{transmission delay + propagation delay}\)

Pipe Diagram

To draw the link, we can imagine the link is a pipe (similar to the water analogy) and draw the pipe as a rectangle, where the width is the propagation delay, and the height is the bandwidth. The area of the pipe is the capacity of the link.

Pipe diagrams can be useful for comparing different links. Let’s look at the exact same packets traveling through three different links.

In the long term, we have enough capacity to send all the outgoing packets, but at this very instant in time, we have two packets arriving simultaneously, and we can only send out one. This is called transient overload, and it’s extremely common at switches in the Internet.

To cope with transient overload, the switch maintains a queue of packets. If two packets arrive simultaneously, the switch queues one of them and sends out the other one.

At any given time, the switch could choose to send a packet from one of the incoming links, or send a packet from the queue. This choice is determined by a packet scheduling algorithm, and there are lots of different designs that we’ll look at.

Now that we have a notion of queuing, we need to go back and update our packet delay formula. Now, packet delay is the sum of transmission delay, propagation delay, and queuing delay.

2026 就这么突然来了。

我已经很久没有写什么非技术性的东西了,只能感叹这一年过的比我体感要快得多。

2025 的一月仿佛还历历在目。在鼓楼图书馆里看 FSF,百无聊赖地复习着 CPL 和微积分;南京清早熙熙攘攘的人声,冬日清朗的空气,教超的蛋糕和布丁仿佛还是几天前的事情。

是因为长大一点了吗,对于过去的事情记得更加清楚,但时间也感觉流动的更快了。

去年寒假看了一半的 CSAPP,到现在还没看完。学了一点皮毛的数学分析,只有在看花书的时候派上了点用场。CS61A 倒是学完了,但是感觉这门课属于会的人不需要听,不会的人一开始也难解其妙。

先争取活过这个无比劳碌的期末周吧。

寒假里最重要的还是在学校里老老实实搬砖,争取今年能把第一篇文章投中(能中 TPAMI 什么的就最好了)。

然后完成一些一直想学,但是也没时间学的东西吧:

  1. UCB CS186 Database System:正好学学 Java,万一将来不去做 AI 去做互联网码农也能用得上。 其实我一直对软工系的东西很感兴趣,但苦于课程安排一直没什么机会。

  2. UCB CS168/Stanford CS144 Computer Networking:其实早在一年前就听完了 USTC 的计网,但是当时也就听个响,现在脑海里也就一点粗略的印象。 看网上的建议,听 UCB 的课,然后把两个的实验都做了。 还是得动手写点代码:(

  3. 把拖了一年的 CSAPP 听完,然后把实验做了。其实我很想做计科的 ICSPA,但是不知道有没有时间,估计要做的话只有过年的时间能砸进去了

  4. 操作系统。这个还没决定是 MIT 的还是 UCB 的,先把上面的七七八八做完再考虑这个吧。

  5. 学学 JavaScript

  6. 把《统计学习方法》之前跳过的部分读完。这学期学了凸优化,应该算是补上了 AI 需要的数学中的最后一块了。

这个学期都没写什么代码。事实上,整个一年都没写什么代码,但是理论貌似学的也不怎么样。

还是希望能提升一下自己的动手能力,亲自做一个大一点的工程吧。

再见,2025;干杯,2026!

Mamba 是一种 State Space Model/SSM,即状态空间模型。这是用以处理序列问题的一种非常优秀的模型。

其本质可以理解为“隐空间”/“状态空间”中的递推。

问题

给定时间信号输入序列 \(x(t)\),以及输出序列 \(y(t)\),尝试给出预测模型。 我们考虑一个隐状态空间 \(h(t)\in \mathbb{R}^N\),它由 \(x(t)\) 而来,控制了 \(y(t)\) 的输出。 从简化的角度考虑,给出线性模型:

\[\begin{align} h'(t) &= \mathbf{A}h(t)+\mathbf{B}x(t) \\ y(t) &= \mathbf{C}h'(t) \end{align} \]

其中,\(\mathbf{A}\)\(\mathbf{B}\)\(\mathbf{C}\) 三个矩阵分别代表 evolution parameter 和 projection parameter。第一个的作用是将隐空间中的状态逐步演化的最终态,后二者的作用是进行隐空间和实际输入/输出空间的映射。

从连续时间到离散时间

上面的模型显然是一个连续时间模型,也就意味着是一个微分方程。

但是在实际工程和应用中,我们不可能去解一个微分方程,因此我们考虑去将他离散化。其中,最重要的一种离散化方式被称作 "Zero-Order Hold/ZOH" 规则。

ZOH

ZOH 的核心假设非常易于理解:

Zero-Order Hold

在每个采样区间中,输入保持常数:

\[ u(t)=u_k\quad t\in [t_k,t_{k+1}) \]

其中,采样步长 \(\Delta>0\)\(t_k=k\Delta\)

我们希望得到离散递推公式: \[ x_{k+1}=\bar{A}x_k+\bar{B}u_k \]

所以,也就是去解这个微分方程(和普通的微分方程一样,先不考虑是矩阵的情况): \[ \frac{\mathrm{d}x}{\mathrm{d}t}=Ax+Bu \]

注意,这里 \(x\) 是状态,\(u\) 是输入。

将系统重写为 \[ \dot{x}(t)-Ax(t)=Bu(t) \]

解之即有: \[ \bar{A}=e^{\Delta A},\bar{B}=\int_{t_k}^{t_{k+1}}e^{A(t_{k+1}-\tau)}B\mathrm{d}\tau \]

\(A\) 可逆,则 \(B\) 有闭式解 \(A^{-1}(e^{A\Delta}-I)B\)

矩阵指数

\(B=e^{A}\) 并不是 \(B_{ij}=\exp\{a_{ij}\}\),而是采用泰勒展开: \[ e^A=\sum_{i=0}^{\infin} \frac{A^k}{k!} \]

对于对角矩阵,退化为上面的直觉的形式: \[ A=\text{diag}(a_1,\dots,a_n) \Rightarrow e^A=\text{diag}(e^{a_1},\dots,e^{a_n}) \]

类似,对于可对角化矩阵,也有: \[ A=P \text{diag}(a_1,\dots,a_n) P^{-1} \Rightarrow e^A=P\text{diag}(e^{a_1},\dots,e^{a_n})P^{-1} \]

递推

基于 ZOH 假设,我们可以快速写出以下式子(\(M\) 是序列长度): \[ \begin{align} \bar{K} &= (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B},\dots, C\bar{A}^{M-1}\bar{B})\\ y &= x * \bar{K} \end{align} \]

一个 Mamba Block

一个 Mamba Block 做一次完整的我们上述的递推操作,也就是对于每个位置提取一次上下文理解。

和 Transformer 类似,我们也可以把多个 Mamba Block 叠起来,进行更深度的信息提取,也就是逐层抽象,越靠后的 Mamba Block 蕴含了越丰富的、越抽象的信息。

\[ x^{(0)}\xrightarrow{\text{Mamba Block 1}} x^{(1)}\xrightarrow{\text{Mamba Block 2}}\dots \xrightarrow{\text{Mamba Block M}} x^{(M)} \]

具体实现

Algorithm 1:Parameters Function

根据我们上文的描述,从给定的序列 \(\{x\}\) 中生成 \(\bar{A}, \bar{B}, C\)

\[ \begin{aligned} \textbf{Algorithm 1: Parameters Function}\\ \textbf{Require: } x' \in \mathbb{R}^{(B,N,P)}\\ \textbf{Ensure: } \bar A \in \mathbb{R}^{(B,N,P,K)},\ \bar B \in \mathbb{R}^{(B,N,P,K)},\ C \in \mathbb{R}^{(B,N,K)}\\[4pt] \begin{array}{ll} 1: & B \in \mathbb{R}^{(B,N,K)} \leftarrow \mathrm{Linear}^{B}(x')\\ 2: & C \in \mathbb{R}^{(B,N,K)} \leftarrow \mathrm{Linear}^{C}(x')\\ 3: & \Delta \in \mathbb{R}^{(B,N,P)} \leftarrow \log\!\Big(1+\exp(\mathrm{Linear}^{\Delta}(x')+\mathrm{Parameter}^{\Delta})\Big)\\ 4: & \text{// } \mathrm{Parameter}^{A} \in \mathbb{R}^{(P,K)}\\ 5: & \bar A \in \mathbb{R}^{(B,N,P,K)} \leftarrow \Delta \otimes \mathrm{Parameter}^{A}\\ 6: & \bar B \in \mathbb{R}^{(B,N,P,K)} \leftarrow \Delta \otimes B\\ \textbf{Return:} & \bar A,\ \bar B,\ C \end{array} \end{aligned} \]

\[ \begin{aligned} \textbf{Algorithm 2: Mamba Block}\\ \textbf{Require: } T_{l-1} \in \mathbb{R}^{(B,N,C)}\\ \textbf{Ensure: } T_{l} \in \mathbb{R}^{(B,N,C)}\\[4pt] \begin{array}{ll} 1: & \text{// Apply layer normalization to } T_{l-1}\\ 2: & T'_{l-1} \in \mathbb{R}^{(B,N,C)} \leftarrow \mathrm{Norm}(T_{l-1})\\ 3: & x \in \mathbb{R}^{(B,N,P)} \leftarrow \mathrm{Linear}^{x}(T'_{l-1})\\ 4: & z \in \mathbb{R}^{(B,N,P)} \leftarrow \mathrm{Linear}^{z}(T'_{l-1})\\ 5: & \text{// Process the input sequence}\\ 6: & x' \in \mathbb{R}^{(B,N,P)} \leftarrow \mathrm{SiLU}(\mathrm{Conv1d}(x))\\ 7: & \bar A,\ \bar B,\ C \leftarrow \mathrm{ParametersFunction}(x')\\ 8: & y \in \mathbb{R}^{(B,N,P)} \leftarrow \mathrm{SSM}(\bar A,\bar B,C)(x')\\ 9: & \text{// Obtain gated } y\\ 10: & y' \in \mathbb{R}^{(B,N,P)} \leftarrow y \odot \mathrm{SiLU}(z)\\ 11: & \text{// Residual connection}\\ 12: & T_{l} \in \mathbb{R}^{(B,N,C)} \leftarrow \mathrm{Linear}^{T}(y') + T_{l-1}\\ \textbf{Return:} & T_{l} \end{array} \end{aligned} \]

EM 算法的目标

先考虑著名的三硬币问题:

假设三枚硬币 A, B, C,其正面朝上的概率分别为 \(\pi,p,q\)

先抛 A,若正面朝上抛 B,反之则抛 C。

只记录抛 B,C 的结果,得到一个 01 序列。

要求给出此模型的参数 \(\theta=(\pi,p,q)\)

抽象出问题实质:

给出观测数据 \(Y=[Y_1,\dots,Y_n]^T\) 以及隐变量 \(Z=[Z_1,\dots,Z_n]^T\),则给出似然估计: \[ P(Y|\theta) =\sum_Z P(Y|Z,\theta) P(Z|\theta) \]

此问题中可以化简为:

\[ P(Y|\theta) =\prod_{j=1}^n \left[\pi p^{y_j} (1-p)^{1-y_j} +(1-\pi) q^{y_j} (1-q)^{1-y_j} \right] \] 希望给出

\[ \hat{\theta}=\arg \max_{\theta} \log(P(Y|\theta)) \]

下面给出一个用语的约定:

  • 观测数据/不完全数据 \(Y=[Y_1,\dots,Y_n]^T\)

  • 隐变量/未观测数据/ \(\text{Latent Variables}\) \(Z=[Z_1,\dots,Z_n]^T\)

  • 完全数据:\(P(Y,Z|\theta)\)

  • 完全数据的对数似然函数:\(\log P(Y,Z|\theta)\)

EM 算法的导出

\[ \begin{align} L(\theta) &=P(Y|\theta)=\sum_Z P(Y|Z,\theta) P(Z|\theta) \\ &=\log \left( \sum_Z P(Y|Z,\theta) P(Z|\theta) \right) \end{align} \]

我们逐步迭代去求出最优化的 \(\theta\),设第 \(i\) 次求出的解是 \(\theta^{(i)}\),则: \[ \begin{align} L(\theta)-L(\theta^{(i)})&=\log \left( \sum_Z P(Y|Z,\theta) P(Z|\theta) \right)-\log P(Y|\theta^{(i)}) \\ &= \log \left(\sum_Z P(Z|Y,\theta^{(i)}) \dfrac{ P(Y|Z,\theta) P(Z|\theta) } {P(Z|Y,\theta^{(i)})} \right)-\log P(Y|\theta^{(i)})\\ &\geq \sum_Z P(Z|Y,\theta^{(i)})\log \left( \dfrac{ P(Y|Z,\theta) P(Z|\theta) } {P(Z|Y,\theta^{(i)})} \right)-\log P(Y|\theta^{(i)})\\ &= \sum_Z P(Z|Y,\theta^{(i)})\log \left( \dfrac{ P(Y|Z,\theta) P(Z|\theta) } {P(Z|Y,\theta^{(i)}) P(Y|\theta^{(i)})} \right)\\ \end{align} \]

不等号是Jensen不等式给出的

\(B(\theta,\theta^{(i)})\hat{=}L(\theta^{(i)})+\sum_Z P(Z|Y,\theta^{(i)})\log \left( \dfrac{ P(Y|Z,\theta) P(Z|\theta) }{P(Z|Y,\theta^{(i)}) P(Y|\theta^{(i)})} \right)\)

显然有 \(L(\theta)\geq B(\theta,\theta^{(i)})\),即 \(B\) 为原本函数的一个下界,且 \(\theta=\theta^{(i)}\) 时取等。 故可以考虑最大化 \(B\),从而使得 \(L\) 最大。

即: \[ \begin{align} \theta^{(i+1)} &= \arg \max _{\theta} B(\theta,\theta^{(i)}) \\ &=\arg \max_{\theta} \left( L(\theta^{(i)})+\sum_Z P(Z|Y,\theta^{(i)})\log \left( \dfrac{ P(Y|Z,\theta) P(Z|\theta) }{P(Z|Y,\theta^{(i)}) P(Y|\theta^{(i)})} \right) \right) \\ &=\arg \max_{\theta} \left( \sum_Z P(Z|Y,\theta^{(i)})\log \left( P(Y|Z,\theta) P(Z|\theta) \right) \right)\\ &=\arg \max_{\theta} \sum_Z P(Z|Y,\theta^{(i)})\log P(Y,Z|\theta) \end{align} \] 上面的等号都是去掉了常数项(上面的变量只有 \(\theta\)\(\theta^{(i)}\)固定) 令 \(Q(\theta,\theta^{(i)})=\sum_Z P(Z|Y,\theta^{(i)})\log P(Y,Z|\theta)\),则 Q 函数即为:

完全似然函数 \(\log P(Y,Z|\theta)\) 关于在给定观测数据 \(Y\) 和当前参数估计\(\theta^{(i)}\) 下,对于未观测数据 \(Z\) 的条件概率分布 \(P(Z|Y,\theta^{(i)})\) 的期望。

EM 算法实质上就是找一个凸函数 \(B/Q\),让它在 \(\theta^{(i)}\) 处严格等于对数似然函数 \(\log P(Y,Z|\theta)\),但始终在似然函数下方。我们通过不断增大 B 函数,迫使似然函数上升。

关于 Q 函数更直观的认知,见 Qwen

EM 算法的流程

  1. 初始化 \(\theta^{(0)}\)
  2. E(Expectation): 求出 Q 函数
  3. M(Maximization): 最大化 Q 函数,求出下一轮的 \(\theta^{(i+1)}\)
  4. 重复 E-M 步,直到收敛为止

GMM

GMM,即 Gaussian Mixture Model,刻画了如下的问题:

有 K 个 Gaussian,分别服从 \(\theta_k=(\mu_k,\sigma_k)\);第 \(k\) 个模型被选择的概率是 \(\alpha_k\)

\(P(y)=\sum_{k=1}^K \alpha_k \phi(\mu_k,\sigma_k)\)

现在你既不知道各个Gaussian的具体参数,也不知道选择的概率 \(\alpha_k\),任务就是根据观测的序列 \(y=(y_1,\dots,y_N)\) 求出 \(\theta=(\alpha_1,\dots,\alpha_K;\theta_1,\dots,\theta_K)\)

使用 EM 算法完成这个任务,逐步分析:

明确隐变量

\[ \gamma_{jk}= \begin{cases} 1 & 第 j 个观测来自第 k 个模型\\ 0 & 反之 \end{cases} \]

有了观测数据 \(y_j\) 之后,那么完全数据就是

\[ (y_j,\gamma_{j1},\dots,\gamma_{jk}) \]

后面整体算作上文的 \(Z\)

写出似然函数:

$$ \[\begin{align} P(y,\gamma|\theta) &= \prod_{j=1}^N P(y_j,\gamma_{j1}\dots,\gamma_{jk})\\ &= \prod_{j=1}^N \prod_{k=1}^K [\alpha_k \phi(y_j|\theta_k)]^{\gamma_{jk}}\\ \end{align}\] $$

0%