图解RNN中的梯度消失与爆炸问题
图解RNN中的梯度消失与爆炸问题
经典的RNN结构如下图所示:
假设我们的时间序列只有三段,
S
0
S_{0}
S0 为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:
S 1 = W x X 1 + W s S 0 + b 1 , O 1 = W 0 S 1 + b 2 S_{1} = W_{x} X_{1} + W_{s}S_{0} + b_{1},O_{1} = W_{0} S_{1} + b_{2} S1=WxX1+WsS0+b1,O1=W0S1+b2
S 2 = W x X 2 + W s S 1 + b 1 , O 2 = W 0 S 2 + b 2 S_{2} = W_{x} X_{2} + W_{s}S_{1} + b_{1},O_{2} = W_{0} S_{2} + b_{2} S2=WxX2+WsS1+b1,O2=W0S2+b2
S 3 = W x X 3 + W s S 2 + b 1 , O 3 = W 0 S 3 + b 2 S_{3} = W_{x} X_{3} + W_{s}S_{2} + b_{1},O_{3} = W_{0} S_{3} + b_{2} S3=WxX3+WsS2+b1,O3=W0S3+b2
假设在 t = 3 t=3 t=3时刻,损失函数为 L 3 = 1 2 ( Y 3 − O 3 ) 2 L_3 = \frac{1}{2}(Y_3 - O_3)^2 L3=21(Y3−O3)2 。则对于一次训练任务的损失函数为 L = ∑ t = 0 T L t L = \sum_{t=0}^{T} L_t L=∑t=0TLt ,即每一时刻损失值的累加。
使用随机梯度下降法训练RNN其实就是对 W x W_x Wx 、 W s W_s Ws 、 W o W_o Wo 以及 b 1 、 b 2 b_1 、 b_2 b1、b2 求偏导,并不断调整它们以使 L L L尽可能达到最小的过程。
现在假设我们我们的时间序列只有三段:t1,t2,t3。我们只对 t 3 t3 t3时刻的 W x W_x Wx、 W s W_s Ws、 W o W_o Wo 求偏导(其他时刻类似):
∂ L 3 ∂ W 0 = ∂ L 3 ∂ O 3 ∂ O 3 ∂ W o = ∂ L 3 ∂ O 3 S 3 \frac{\partial L_3}{\partial W_0} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial W_o} = \frac{\partial L_3}{\partial O_3} S_3 ∂W0∂L3=∂O3∂L3∂Wo∂O3=∂O3∂L3S3
∂ L 3 ∂ W x = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W x + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ W x + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W x \frac{\partial L_3}{\partial W_x} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_x} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial W_x} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_x} ∂Wx∂L3=∂O3∂L3∂S3∂O3∂Wx∂S3+∂O3∂L3∂S3∂O3∂S2∂S3∂Wx∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂Wx∂S1
∂ L 3 ∂ W x = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W x + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ W x + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W x = ∂ L 3 ∂ O 3 W 0 ( X 3 + S 2 W s + S 1 W s 2 ) \frac{\partial L_3}{\partial W_x} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_x} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial W_x} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_x} = \frac{\partial L_3}{\partial O_3} W_0 (X_3 + S_2 W_s + S_1 W_s^2) ∂Wx∂L3=∂O3∂L3∂S3∂O3∂Wx∂S3+∂O3∂L3∂S3∂O3∂S2∂S3∂Wx∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂Wx∂S1=∂O3∂L3W0(X3+S2Ws+S1Ws2)
∂ L 3 ∂ W s = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W s + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ W s + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W s = ∂ L 3 ∂ O 3 W 0 ( S 2 + S 1 W s + S 0 W s 2 ) \frac{\partial L_3}{\partial W_s} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_s} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial W_s} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_s} = \frac{\partial L_3}{\partial O_3} W_0 (S_2 + S_1 W_s + S_0 W_s^2) ∂Ws∂L3=∂O3∂L3∂S3∂O3∂Ws∂S3+∂O3∂L3∂S3∂O3∂S2∂S3∂Ws∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂Ws∂S1=∂O3∂L3W0(S2+S1Ws+S0Ws2)
关于上面这个多元复合函数链式求导过程,通过如下对变量层级树的遍历可以更加直观理解这一点:
可以看出对于
W
o
W_o
Wo 求偏导并没有长期依赖,但是对于
W
x
W_x
Wx、
W
s
W_s
Ws 求偏导,会随着时间序列产生长期依赖。因为
S
t
S_t
St 随着时间序列向前传播,而
S
t
S_t
St 又是
W
x
W_x
Wx、
W
s
W_s
Ws 的函数。
根据上述求偏导的过程,我们可以得出任意时刻对 W x W_x Wx、 W s W_s Ws 求偏导的公式:
∂ L t ∂ W x = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ S t ( ∏ j = k + 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ W x \frac{\partial L_t}{\partial W_x} = \sum_{k=0}^{t} \frac{\partial L_t}{\partial O_t} \frac{\partial O_t}{\partial S_t} \left(\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}}\right) \frac{\partial S_k}{\partial W_x} ∂Wx∂Lt=k=0∑t∂Ot∂Lt∂St∂Ot j=k+1∏t∂Sj−1∂Sj ∂Wx∂Sk
任意时刻对 W s W_s Ws 求偏导的公式同上。
如果加上激活函数: S j = tanh ( W x X j + W s S j − 1 + b 1 ) S_j = \tanh(W_x X_j + W_s S_{j-1} + b_1) Sj=tanh(WxXj+WsSj−1+b1)
则 ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t tanh ′ W s \prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} = \prod_{j=k+1}^{t} \tanh' W_s j=k+1∏t∂Sj−1∂Sj=j=k+1∏ttanh′Ws
加上激活函数tanh复合后的多元链式求导过程如下图所示:
激活函数tanh和它的导数图像如下。
由上图可以看出 tanh ′ ≤ 1 \tanh' \leq 1 tanh′≤1,对于训练过程大部分情况下tanh的导数是小于1的,因为很少情况下会出现 W x X j + W s S j − 1 + b 1 = 0 W_x X_j + W_s S_{j-1} + b_1 = 0 WxXj+WsSj−1+b1=0,如果 W s W_s Ws 也是一个大于0小于1的值,则当t很大时 ∏ j = k + 1 t tanh ′ W s \prod_{j=k+1}^{t} \tanh' W_s ∏j=k+1ttanh′Ws,就会趋近于0,和 0.0 1 50 0.01^{50} 0.0150 趋近于0是一个道理。同理当 W s W_s Ws 很大时 ∏ j = k + 1 t tanh ′ W s \prod_{j=k+1}^{t} \tanh' W_s ∏j=k+1ttanh′Ws 就会趋近于无穷,这就是RNN中梯度消失和爆炸的原因。
至于怎么避免这种现象,再看看 ∂ L t ∂ W x = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ S t ( ∏ j = k + 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ W x \frac{\partial L_t}{\partial W_x} = \sum_{k=0}^{t} \frac{\partial L_t}{\partial O_t} \frac{\partial O_t}{\partial S_t} \left(\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}}\right) \frac{\partial S_k}{\partial W_x} ∂Wx∂Lt=k=0∑t∂Ot∂Lt∂St∂Ot j=k+1∏t∂Sj−1∂Sj ∂Wx∂Sk 梯度消失和爆炸的根本原因就是 ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} ∏j=k+1t∂Sj−1∂Sj 这一坨,要消除这种情况就需要把这一坨在求偏导的过程中去掉,至于怎么去掉,一种办法就是使 ∂ S j ∂ S j − 1 ≈ 1 \frac{\partial S_j}{\partial S_{j-1}} \approx 1 ∂Sj−1∂Sj≈1 另一种办法就是使 ∂ S j ∂ S j − 1 ≈ 0 \frac{\partial S_j}{\partial S_{j-1}} \approx 0 ∂Sj−1∂Sj≈0。其实这就是LSTM做的事情。
总结:
-
RNN 的梯度计算涉及到对激活函数的导数以及权重矩阵的连乘
- 以 sigmoid 函数为例,其导数的值域在 0 到 0.25 之间,当进行多次连乘时,这些较小的值相乘会导致梯度迅速变小。
- 如果权重矩阵的特征值也小于 1,那么在多个时间步的传递过程中,梯度就会呈指数级下降,导致越靠前的时间步,梯度回传的值越少。
-
由于梯度消失,靠前时间步的参数更新幅度会非常小,甚至几乎不更新。这使得模型难以学习到序列数据中长距离的依赖关系,对于较早时间步的信息利用不足,从而影响模型的整体性能和对序列数据的建模能力。
这里RNN的梯度消失,指的不是 ∂ L t ∂ W x \frac{\partial L_t}{\partial W_x} ∂Wx∂Lt梯度值接近于0,而是靠前时间步的梯度 ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W x \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_x} ∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂Wx∂S1值算出来很小,也就是靠前时间步计算出来的结果对序列最后一个预测词的生成影响很小,也就是常说的RNN难以去建模长距离的依赖关系的原因;这并不是因为序列靠前的词对最后一个词的预测输出不重要,而是由于损失函数在把有用的梯度更新信息反向回传的过程中,被若干小于0的偏导连乘给一点点削减掉了。