Why do RNNs (recurrent neural networks) suffer from vanishing and exploding gradients?
1️⃣ Analysis
The forward pass of an RNN is defined as follows:
- $x_t$ is the input at time step $t$.
- $s_t$ is the memory (hidden state) at time step $t$: $s_t=f(U\cdot x_t+W\cdot s_{t-1})$, where $f$ is the activation function and $s_{t-1}$ is the memory at time step $t-1$.
- $o_t$ is the output at time step $t$: $o_t=\mathrm{softmax}(V\cdot s_t)$.
Cross-entropy is used as the loss function:
$$L=\sum_{t=1}^{T}-\bar{o}_t\log o_t$$
where $T$ is the number of time steps, $\bar{o}_t$ is the ground truth, and $o_t$ is the predicted output.
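The forward pass and the loss defined above can be sketched in NumPy. This is only an illustration under stated assumptions, not code from the original article: the activation $f$ is taken to be tanh, the targets are one-hot vectors, and the dimensions, random weights, and the function names `rnn_forward` and `cross_entropy` are made up for the example.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D vector of logits."""
    e = np.exp(z - np.max(z))
    return e / e.sum()

def rnn_forward(U, W, V, xs, s0):
    """Apply s_t = tanh(U x_t + W s_{t-1}) and o_t = softmax(V s_t) for every step."""
    s = s0
    states, outputs = [], []
    for x in xs:
        s = np.tanh(U @ x + W @ s)   # assumption: f = tanh
        states.append(s)
        outputs.append(softmax(V @ s))
    return states, outputs

def cross_entropy(outputs, targets):
    """L = sum_t -target_t . log(o_t), with one-hot targets."""
    return sum(-np.sum(y * np.log(o)) for o, y in zip(outputs, targets))

# Hypothetical sizes and random parameters, chosen only for the demo.
rng = np.random.default_rng(0)
input_dim, hidden_dim, num_classes, T = 3, 4, 2, 3
U = rng.standard_normal((hidden_dim, input_dim)) * 0.5
W = rng.standard_normal((hidden_dim, hidden_dim)) * 0.5
V = rng.standard_normal((num_classes, hidden_dim)) * 0.5
xs = [rng.standard_normal(input_dim) for _ in range(T)]
targets = [np.eye(num_classes)[t % num_classes] for t in range(T)]

states, outputs = rnn_forward(U, W, V, xs, s0=np.zeros(hidden_dim))
loss = cross_entropy(outputs, targets)
print(f"loss over {T} steps: {loss:.4f}")
```

The same `W` is reused at every step, which is what later makes the gradient with respect to `W` depend on all earlier states.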
Assume there are three time steps, $t=1,2,3$, and that the initial memory is $s_0=0$. Then the memory and output at $t=1$ are:
$$\begin{aligned}&s_1=f(Ux_1+Ws_0)\\&o_1=\mathrm{softmax}[V\cdot f(Ux_1+Ws_0)]\end{aligned}$$
At $t=2$, the memory and output are:
$$\begin{aligned}&s_2=f(Ux_2+Ws_1)\\&o_2=\mathrm{softmax}[V\cdot f(Ux_2+Ws_1)]=\mathrm{softmax}[V\cdot f(Ux_2+Wf(Ux_1+Ws_0))]\end{aligned}$$
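The nesting is hard to follow, so trace the dependency with arrows: $s_1 \rightarrow s_2 \rightarrow o_2$. In other words, $s_2$ is a function of $s_1$.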
At $t=3$, the memory and output are:
$$\begin{aligned}&s_3=f(Ux_3+Ws_2)\\&o_3=\mathrm{softmax}[V\cdot f(Ux_3+Ws_2)]=\mathrm{softmax}[V\cdot f(Ux_3+Wf(Ux_2+Ws_1))]\\&\phantom{o_3}=\mathrm{softmax}[V\cdot f(Ux_3+Wf(Ux_2+Wf(Ux_1+Ws_0)))]\end{aligned}$$
Tracing the arrows again: $s_1 \rightarrow s_2 \rightarrow s_3 \rightarrow o_3$. Here $s_3$ is a function of $s_2$, and $s_2$ is in turn a function of $s_1$, so $s_3$ depends on both $s_2$ and $s_1$.
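Next, consider backpropagation. BPTT (Back-Propagation Through Time) is the training algorithm for RNNs; at its core it is still backpropagation combined with gradient descent. The main parameters of an RNN are $U$, $W$, and $V$.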
Take $t=3$ as an example and compute the gradients with respect to $U$, $V$, and $W$:
$$\begin{aligned}
&\frac{\partial L_3}{\partial V}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial V}\quad\textcircled{3}\\
&\frac{\partial L_3}{\partial W}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial W}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial W}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial W}\quad\textcircled{4}\\
&\frac{\partial L_3}{\partial U}=\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial U}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial U}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial U}\quad\textcircled{5}
\end{aligned}$$
Formula ⑤ can be written compactly as:
$$\frac{\partial L_3}{\partial U}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial U}$$
Since $\frac{\partial s_3}{\partial s_k}$ itself requires the chain rule, for example $\frac{\partial s_3}{\partial s_1}=\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}$, the formula can be expanded further:
$$\frac{\partial L_3}{\partial U}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial U}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\left(\prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}}\right)\frac{\partial s_k}{\partial U}\quad\textcircled{6}$$
Similarly, formula ④ can be written as:
$$\frac{\partial L_3}{\partial W}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\left(\prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}}\right)\frac{\partial s_k}{\partial W}\quad\textcircled{7}$$
From formula ③, the partial derivative with respect to $V$ involves no dependence on earlier time steps.
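From formulas ④ and ⑤, the partial derivatives with respect to $W$ and $U$ involve long-term dependencies. The reason is that during the forward pass $s_t$ is carried forward through time, and $s_t$ is a function of $W$ and $U$.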
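The sum-of-products form of the gradient with respect to $W$ can be checked numerically. The sketch below assumes a scalar hidden state with $f=\tanh$ so that every factor is a plain number, reads $\partial s_k/\partial W$ as the local derivative with $s_{k-1}$ held fixed, and compares the result against a central finite difference. The names `forward`, `loss_last`, and `grad_w_formula` and all numeric values are hypothetical.

```python
import numpy as np

def forward(u, w, xs, s0=0.0):
    """Scalar RNN s_t = tanh(u*x_t + w*s_{t-1}); returns [s_0, s_1, ..., s_T]."""
    states = [s0]
    for x in xs:
        states.append(np.tanh(u * x + w * states[-1]))
    return states

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def loss_last(u, w, V, xs, y):
    """Cross-entropy of the last output o_T = softmax(V * s_T) for class index y."""
    s_T = forward(u, w, xs)[-1]
    return -np.log(softmax(V * s_T)[y])

def grad_w_formula(u, w, V, xs, y):
    """dL_T/dw = sum_k (dL/ds_T) * prod_{j=k+1}^{T} ds_j/ds_{j-1} * (local ds_k/dw)."""
    s = forward(u, w, xs)
    T = len(xs)
    one_hot = np.eye(len(V))[y]
    # (dL/do_T)(do_T/ds_T), using the softmax + cross-entropy identity
    # dL/dlogits = p - one_hot, followed by dlogits/ds_T = V.
    dL_ds_T = V @ (softmax(V * s[T]) - one_hot)
    total = 0.0
    for k in range(1, T + 1):
        prod = 1.0
        for j in range(k + 1, T + 1):
            prod *= (1.0 - s[j] ** 2) * w       # ds_j/ds_{j-1} = tanh'(.) * w
        local = (1.0 - s[k] ** 2) * s[k - 1]    # ds_k/dw with s_{k-1} held fixed
        total += dL_ds_T * prod * local
    return total

u, w = 0.5, 0.8
V = np.array([1.0, -0.5])
xs = [1.0, -0.3, 0.7]
y = 0

eps = 1e-5
grad_fd = (loss_last(u, w + eps, V, xs, y) - loss_last(u, w - eps, V, xs, y)) / (2 * eps)
grad_formula = grad_w_formula(u, w, V, xs, y)
print(grad_formula, grad_fd)
```

With $s_0=0$ the $k=1$ term contributes nothing to the gradient of $W$, because its local derivative is multiplied by $s_0$.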
Assume the activation function is tanh and extract the product term from ⑥ and ⑦:
$$\prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}}=\prod_{j=k+1}^3\tanh'\cdot W$$
For example, with $s_3=f(Ux_3+Ws_2)$:
$$\frac{\partial s_3}{\partial s_2}=\tanh'(Ux_3+Ws_2)\,W$$
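To make this factor concrete, the same derivative can be written out for a vector-valued state. This is a sketch under the assumption that tanh is applied element-wise; the symbol $z_3$ for the pre-activation is introduced here and does not appear in the original.

$$
\begin{aligned}
s_3 &= \tanh(z_3), \qquad z_3 = Ux_3 + Ws_2 \\
\tanh'(z) &= 1-\tanh^2(z) \\
\frac{\partial s_3}{\partial s_2} &= \operatorname{diag}\!\left(1-\tanh^2(z_3)\right)W = \operatorname{diag}\!\left(1-s_3\odot s_3\right)W
\end{aligned}
$$

Each diagonal entry lies in $(0, 1]$, so this factor rescales the rows of $W$ and can shrink them but never enlarge them.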
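The derivative of tanh is at most 1 (reached only when its input is 0) and is usually smaller than 1. Therefore, when $t$ is large, for example $t=100$, the product term $\prod_{j=k+1}^{100}\tanh'\cdot W$ in ⑥ and ⑦ tends to 0 for time steps $k$ far from $t$. The contribution of those early steps to the gradients of $W$ and $U$ at $t=100$ then tends to 0, which is the vanishing gradient problem.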
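A small numerical experiment shows how quickly this product shrinks. It is a sketch under stated assumptions: a scalar hidden state, made-up random inputs, and hypothetical values $u=0.5$ and $w=0.9$; the helper name `jacobian_factors` is not from the original.

```python
import numpy as np

def jacobian_factors(u, w, xs, s0=0.0):
    """Factors ds_j/ds_{j-1} = tanh'(u*x_j + w*s_{j-1}) * w of a scalar RNN."""
    factors, s = [], s0
    for x in xs:
        s = np.tanh(u * x + w * s)
        factors.append((1.0 - s ** 2) * w)   # tanh'(z) = 1 - tanh(z)^2
    return np.array(factors)

rng = np.random.default_rng(0)
xs = rng.standard_normal(100)                # 100 time steps of made-up inputs
factors = jacobian_factors(u=0.5, w=0.9, xs=xs)

# Product over j = 2..100, i.e. the term for k = 1 (factors[i] belongs to j = i + 1).
long_range = np.prod(factors[1:])
# Product over j = 99..100, i.e. the term for k = 98.
short_range = np.prod(factors[-2:])
print(f"k=1: {long_range:.3e}   k=98: {short_range:.3e}")
```

Because every factor here is positive and no larger than $w=0.9$, the long-range product is bounded by $0.9^{99}$, while the short-range product keeps a usable magnitude.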
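Having analyzed tanh, now consider $W$: if the values in $W$ are large enough that the factors $\tanh'\cdot W$ exceed 1 in magnitude, the product grows with the number of time steps, and the result is gradient explosion.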
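The effect of the size of $W$ can be isolated with a simplified experiment. The sketch below assumes the activation derivative is exactly 1 (the linear regime of tanh near zero), so the product reduces to repeated multiplication by $W$; with a saturated tanh the derivative would shrink and partly offset a large $W$. The matrices are a scaled random orthogonal matrix, chosen so that the norm of the product is known in closed form, and the name `repeated_product_norm` is hypothetical.

```python
import numpy as np

def repeated_product_norm(W, steps):
    """Spectral norm of W multiplied by itself `steps` times (activation derivative taken as 1)."""
    M = np.eye(W.shape[0])
    for _ in range(steps):
        M = M @ W
    return np.linalg.norm(M, 2)

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((4, 4)))   # orthogonal: all singular values are 1

large = repeated_product_norm(1.2 * Q, steps=50)   # all singular values 1.2 -> norm 1.2**50
small = repeated_product_norm(0.8 * Q, steps=50)   # all singular values 0.8 -> norm 0.8**50
print(f"scale 1.2: {large:.3e}   scale 0.8: {small:.3e}")
```

A scale only slightly above 1 is enough for the product to grow by several orders of magnitude over 50 steps, while a scale slightly below 1 makes it collapse toward 0.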
2️⃣ Summary
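- Why RNNs suffer from vanishing gradients: the hidden-layer output $s_t$ is carried forward through time, so backpropagation produces a product term made up of the gradient of the activation function and the parameter $W$. With tanh as the activation function, whose gradient is at most 1, the product term approaches 0 as the number of time steps grows, and the gradient vanishes.
- Why RNNs suffer from exploding gradients: if the parameter $W$ is too large, the product term keeps growing with the number of time steps, and the gradient explodes.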
3️⃣ References
RNN梯度消失与梯度爆炸的原因 (The causes of vanishing and exploding gradients in RNNs) - article by Hideonbush - Zhihu