Skip to content

第 14 章 序列模型:RNN 和 LSTM

前面的全连接网络和卷积神经网络主要处理固定长度输入。序列模型处理的是有顺序的数据,例如文本、语音、时间序列和用户行为记录。

序列数据的特点是:当前元素的意义往往依赖前后文。例如“我喜欢这部电影”和“我不喜欢这部电影”只差一个“不”,但语义完全不同。

本章把普通 RNN 和 LSTM 放在一起讲。普通 RNN 用来理解序列递推思想,LSTM 用来解决普通 RNN 难以学习长期依赖的问题。

14.1 序列建模问题

普通监督学习通常把一个样本表示成一个固定向量:

y^=f(x)

序列建模的输入是一串有顺序的元素:

x1,x2,,xT

其中 T 表示序列长度。

不同任务的输入和输出形式不同:

形式例子
多对一情感分类、用户行为预测
一对多图像生成文字描述
多对多等长词性标注、语音帧分类
多对多不等长机器翻译、摘要生成

序列模型的关键是利用顺序信息,而不是把所有元素当成互相独立的特征。

14.2 RNN 的基本思想

RNN 的全称是循环神经网络。它在处理序列时,每一步读取一个输入,同时保留一个隐藏状态。

隐藏状态可以理解为模型对“到目前为止已经看过的信息”的总结。

t 个时间步:

ht=ϕ(Wxxt+Whht1+b)

输出可以写成:

y^t=g(Wyht+c)

其中:

符号含义
xtt 个输入
htt 个隐藏状态
ht1上一个隐藏状态
y^tt 步输出

这个式子的意思是:当前隐藏状态不仅由当前输入决定,也由上一时刻的隐藏状态决定。因此 RNN 能把历史信息带到后面的时间步。

14.3 参数共享

RNN 在不同时间步使用同一组参数。

处理 x1x2x3 时,使用的都是同一组 WxWhWy。这叫参数共享。

参数共享有两个好处:

  • 可以处理不同长度的序列。
  • 参数量不会随着序列长度增加而线性增加。

如果每个时间步都使用不同参数,长度稍微变化模型就不好处理,参数数量也会非常大。

14.4 展开后的 RNN

RNN 可以沿时间展开理解:

text
x1 -> h1 -> y1
      |
x2 -> h2 -> y2
      |
x3 -> h3 -> y3

更准确地说,h2 依赖 h1h3 依赖 h2。历史信息就是沿着隐藏状态一步步传递。

这个结构适合表达顺序关系,但也带来问题:很早的信息要经过很多次传递才能影响后面的输出,序列很长时容易被冲淡。

14.5 RNN 的训练:BPTT

RNN 的训练叫通过时间反向传播,简称 BPTT。

它的做法是先把 RNN 按时间展开,得到一个很深的计算图,然后从损失函数开始,沿着时间方向反向传播梯度。

如果序列长度是 T,梯度可能要穿过很多个时间步。每一步都会受到权重和激活函数导数的影响,所以普通 RNN 容易出现两个问题:

问题表现常见处理
梯度消失很难学到较早时间步的信息LSTM、GRU、注意力
梯度爆炸损失剧烈震荡,训练不稳定梯度裁剪

普通 RNN 的主要局限就在这里:它能处理序列,但不擅长保留长期信息。

14.6 LSTM 为什么出现

LSTM 是长短期记忆网络。它仍然属于 RNN,但结构比普通 RNN 多了门控机制。

普通 RNN 只有一个隐藏状态 ht,既要负责记忆历史,又要负责当前输出。序列变长以后,早期信息容易被覆盖。

LSTM 增加了一个细胞状态 ct。可以把它理解成一条更稳定的记忆通道。模型通过门控机制决定:

  • 哪些旧信息要保留。
  • 哪些新信息要写入。
  • 哪些记忆要用于当前输出。

14.7 LSTM 的三个门

LSTM 主要有三个门:遗忘门、输入门、输出门。

作用
遗忘门决定上一时刻的记忆保留多少
输入门决定当前新信息写入多少
输出门决定当前隐藏状态输出多少

门的输出通常经过 sigmoid,所以每个位置的值在 0 到 1 之间。接近 1 表示多保留,接近 0 表示少保留。

遗忘门:

ft=σ(Wf[ht1,xt]+bf)

输入门:

it=σ(Wi[ht1,xt]+bi)

候选记忆:

c~t=tanh(Wc[ht1,xt]+bc)

输出门:

ot=σ(Wo[ht1,xt]+bo)

这里的 [ht1,xt] 表示把上一隐藏状态和当前输入拼接起来。

14.8 LSTM 的状态更新

LSTM 的细胞状态更新为:

ct=ftct1+itc~t

其中 表示逐元素相乘。

这个式子分成两部分:

  • ftct1:保留旧记忆。
  • itc~t:写入新记忆。

隐藏状态为:

ht=ottanh(ct)

ct 更偏长期记忆,ht 更偏当前输出。

14.9 LSTM 如何缓解长期依赖

普通 RNN 每一步都会用新的非线性变换更新隐藏状态,旧信息很容易被覆盖。

LSTM 的细胞状态使用加法路径更新:

ct=ftct1+itc~t

如果遗忘门接近 1,重要信息可以沿着 ct 传递很多时间步。这样梯度也更容易沿这条路径往前传播。

LSTM 不是彻底解决所有长序列问题,但比普通 RNN 更适合学习较长依赖。

14.10 GRU 和 Transformer 的位置

GRU 是 LSTM 的简化版本。它没有单独的细胞状态,参数更少,训练更快。很多任务中,GRU 和 LSTM 的效果接近。

Transformer 不再按时间步递推,而是用注意力机制直接建立不同位置之间的关系。它更容易并行,也更擅长建模长距离依赖,因此成为现在自然语言处理和大模型的主流结构。

三者可以这样理解:

模型核心思想局限
RNN用隐藏状态递推处理序列长期依赖弱,难并行
LSTM用门控和细胞状态保存长期信息仍然按时间步串行
Transformer用注意力直接连接不同位置注意力计算成本较高

Powered by VitePress