梯度消失与长短时记忆网络
1. RNN 的梯度消失回顾
RNN 反向传播时,梯度需要沿时间步传递:
每步都要乘以 。由于 最大值为 1,长序列下梯度指数衰减,早期时间步几乎得不到更新。
1.5 梯度爆炸:与梯度消失相反的危机
梯度消失是梯度太小,梯度爆炸则是梯度太大。当 的特征值大于 1 时,每步相乘后梯度指数级增长:
危害:参数更新幅度极大,loss 剧烈震荡甚至变成 NaN,训练直接崩溃。
解决方案:梯度裁剪(Gradient Clipping)
当梯度的范数超过阈值 时,等比例缩小:
类比:给油门加一个限速器——踩再猛也不会超速,但方向不变。
| 梯度消失 | 梯度爆炸 | |
|---|---|---|
| 原因 | ,梯度指数衰减 | ,梯度指数增长 |
| 症状 | 早期层参数不更新,模型学不到长程依赖 | loss 震荡或 NaN,训练崩溃 |
| 解决 | LSTM/GRU 门控机制 | 梯度裁剪 |
2. LSTM:用”门”控制记忆
类比:LSTM 就像一个有三道闸门的水库——遗忘门决定放掉多少旧水,输入门决定注入多少新水,输出门决定放出多少水供下游使用。水库里的水就是细胞状态 ,是跨时间步传递长期记忆的”主干道”。
2.1 LSTM 四个核心计算
细胞状态更新:
其中 1 表示逐元素乘法(Hadamard 积),2 为候选记忆。 核心数据3 三个“阀门”4 数学符号与参数5
2.2 为什么 LSTM 能缓解梯度消失?
细胞状态 的更新是加法而非乘法:
梯度通过加法路径传回时,关键梯度项为:
只要遗忘门 接近 1(网络选择”记住”),梯度就能无衰减地传回早期时间步,从而保持梯度流动。
3. 代码实现
import micropipawait micropip.install("numpy") # 仅适用于 Obsidian Code Emitter (Pyodide) 环境import numpy as np
def sigmoid(z): return 1 / (1 + np.exp(-z))
def lstm_step(x_t, h_prev, C_prev, params): Wf, Wi, WC, Wo, bf, bi, bC, bo = params concat = np.concatenate([h_prev, x_t]) # 拼接隐藏状态和输入
f = sigmoid(Wf @ concat + bf) # 遗忘门 i = sigmoid(Wi @ concat + bi) # 输入门 C_tilde = np.tanh(WC @ concat + bC) # 候选记忆 o = sigmoid(Wo @ concat + bo) # 输出门
C = f * C_prev + i * C_tilde # 更新细胞状态 h = o * np.tanh(C) # 更新隐藏状态 return h, C
# 初始化参数hidden, input_size = 16, 8concat_size = hidden + input_sizeparams = [ np.random.randn(hidden, concat_size) * 0.01, # Wf np.random.randn(hidden, concat_size) * 0.01, # Wi np.random.randn(hidden, concat_size) * 0.01, # WC np.random.randn(hidden, concat_size) * 0.01, # Wo np.zeros(hidden), np.zeros(hidden), # bf, bi np.zeros(hidden), np.zeros(hidden), # bC, bo]
h, C = np.zeros(hidden), np.zeros(hidden)for t in range(10): # 序列长度10 x_t = np.random.randn(input_size) h, C = lstm_step(x_t, h, C, params)
print("最终隐藏状态:", h.shape)print("最终细胞状态:", C.shape)4. LSTM vs GRU
GRU6 是 LSTM 的简化版,将遗忘门和输入门合并为更新门,参数更少,训练更快。
| 对比项 | LSTM | GRU |
|---|---|---|
| 门数量 | 3(遗忘、输入、输出) | 2(更新、重置) |
| 参数量 | 较多 | 较少(约 LSTM 的 3/4) |
| 性能 | 长序列略优 | 短序列相当甚至更好 |
| 计算速度 | 较慢 | 较快 |
实践中两者差异不大,优先尝试 GRU(更快),若效果不足再换 LSTM。
GRU 核心公式
更新门 同时控制”遗忘多少旧状态”和”写入多少新状态”,相当于 LSTM 遗忘门和输入门的合并。
相关笔记
Footnotes
-
(逐元素乘法):两个形状相同的向量/矩阵对应位置相乘,结果形状不变。例如 。区别于矩阵乘法 (改变形状)。 ↩
-
候选记忆 :由当前输入 和上一步隐藏状态 计算出的”备选新记忆”。之所以叫”候选”,是因为它不会直接写入细胞状态,而是由输入门 决定写入多少—— 接近 0 则忽略,接近 1 则完全采纳。 ↩
-
-
:当前时刻的输入(比如句子中的第 个单词)。
-
:上一步的短期记忆(隐藏状态)。前一个时刻传过来的局部信息。
-
:上一步的长期记忆(细胞状态)。也就是那条贯穿全局的“主干道”。
-
:当前步算出的短期记忆(准备传给下一个时刻,也是当前层的输出)。
-
:当前步更新后的长期记忆(准备传给下一个时刻)。
-
:候选记忆。也就是当前这一步_准备_写入长期记忆的新知识。
-
-
-
:遗忘门 (Forget Gate)。决定旧的长期记忆 有多少应该被扔掉。
-
:输入门 (Input Gate)。决定候选的新知识 有多少应该被写进长期记忆。
-
:输出门 (Output Gate)。决定更新后的长期记忆 中,有多少内容应该被提取出来,作为当前的短期记忆 输出。
-
-
-
:权重矩阵(网络需要学习的参数)。
-
:偏置项(网络需要学习的门槛参数)。
-
:拼接 (Concatenate)。把旧的短期记忆和新的输入绑在一起,拼成一个更长的大向量。
-
:Sigmoid 激活函数。把算出来的数值强行压缩到 之间,完美模拟“阀门”的开闭状态(0 = 关死,1 = 全开)。
-
:双曲正切函数。把数值压缩到 之间,用于生成包含正负方向的新信息。
-
:按元素相乘 (Element-wise multiplication)。就像两个矩阵对应位置的数字相乘,起到了“过滤”的作用。
-
-
GRU(门控循环单元):2014年提出的 LSTM 简化版。取消了独立的细胞状态 ,只保留隐藏状态 ,用更新门控制保留多少历史信息,用重置门控制如何融合新输入。参数更少,在许多任务上与 LSTM 效果相当。 ↩