七、传统循环神经网络(RNN)
传统循环神经网络 RNN
- 前言
- 一、RNN 是什么?
- 1.1 RNN 的结构
- 1.2 结构举例
- 二、RNN 模型的分类
- 2.1 按照 输入跟输出 的结构分类
- 2.2 按照 内部结构 分类
- 三、传统 RNN 模型
- 3.1 RNN内部结构图
- 3.2 内部计算公式
- 3.3 其中 tanh 激活函数的作用
- 3.4 传统RNN优缺点
- 四、代码演示
- 总结
前言
- 前面我们学习了卷积神经网络CNN,通过对图像做卷积运算来提取到图片的局部特征,但是在文本中,我们该怎么对文本进行张量转换,并且让机器学习到文本前后的联系呢,接下里我们将对文本领域的循环神经网络进行讲解。
一、RNN 是什么?
- RNN(Recurrent Neural Network),中文称作循环神经网络:它一般以序列数据为输入, 通过网络内部的结构设计有效捕捉序列之间的关系特征,一般也是以序列形式进行输出。
1.1 RNN 的结构
- 一般单层神经网络结构:
- RNN单层网络结构:
- 以时间步对RNN进行展开后的单层网络结构:
RNN 的循环机制能够使模型的隐藏层上一步产生的结果,作为当下时间步输入的一部分(当下时间步的输入除了正常输入之外,还包括上一步的隐层输出)对当下时间步的输出产生影响。
1.2 结构举例
- 我们前面说RNN能有效捕捉序列之间的关系特征,下边我们举个例子:
- 假如用户输入了一段话:“What time is it ?”,那么机器是怎么捕捉他们之间的序列关系的呢?
- 第一步:先对输入的 “What time is it ?” 进行分词,因为RNN是按照输入顺序来工作的,每次都接受一个单词进行处理。
- 第二步:首先将一个单词 “What” 输入进RNN,他将产生一个输出
O
1
O1
O1
- 第三步:继续将单词 “time” 输入到 RNN,此时 RNN 不仅利用 “time” 来产生
O
2
O2
O2,还会使用上次隐藏层的输出
O
1
O1
O1作为输入信息
- 第四步 :重复第三步,知道将所有单词输入
- 第五步:最后将隐藏层的输出
O
5
O5
O5 进行处理来理解用户意图
- 第一步:先对输入的 “What time is it ?” 进行分词,因为RNN是按照输入顺序来工作的,每次都接受一个单词进行处理。
二、RNN 模型的分类
2.1 按照 输入跟输出 的结构分类
- N vs N
- 这种结构是RNN最基础的机构形式,最大的特点就是:输入跟输出序列是等长的
- 由于这种限制的存在,使其适用范围较小,可以用于生成的等长的诗句或对联。
- N vs 1
- 当我们输入的问题是一个序列,而要求输出是单个的一个值而不是序列,这时候我们就要在最后一个隐藏层的输出上进行线性变化了。
- 大部分情况下,为了更好的明确结果,还要使用 sigmoid 或者 softmax 进行处理,这样的结构经常用于文本分类问题上。
- 1 vs N
- 如果输入的不是一个序列,而要求输出是一个序列,那我们就要让每次的输入都作用到每次的输出上
- 一般用来将图片生成文字任务、
- N vs M
- 这是一种不限输入输出长度的RNN结构,它由编码器和解码器两部分组成,两者的内部结构都是某类RNN,它也被称为 seq2seq 架构
- 输入数据首先通过编码器,最终输出一个隐含变量
c
c
c,之后最常用的做法是使用这个隐含变量
c
c
c 作用在解码器进行解码的每一步上,以保证输入信息被有效利用。
- seq2seq架构最早被提出应用于机器翻译,因为其输入输出不受限制,如今也是应用最广的RNN模型结构。
- 在机器翻译, 阅读理解, 文本摘要等众多领域都进行了非常多的应用实践。
2.2 按照 内部结构 分类
- 我们先介绍分为几种,对于其工作原理
- 在之后的章节中,我们再进行详细讨论
- 传统RNN
- LSTM
- Bi-LSTM
- GRU
- Bi-GRU
三、传统 RNN 模型
3.1 RNN内部结构图
- 解释:
- 隐藏层也就是循环层接收到的是当前时间步的输入 X t X_t Xt 和上个时间步的隐藏层的输出 h t − 1 h_{t-1} ht−1
- 这两个进入RNN结构体中,各自有跟权重矩阵进行运算以后,会融合到一起(也就是拼接到一起),形成新的张量 [ X t , h t − 1 ] [X_t , h_{t-1}] [Xt,ht−1]
- 之后这个张量经过一个全连接层(线性层),该层使用 tanh 作为激活函数,最终得到当前时间步的输出 h t h_t ht
- 最后,当前时间步的输出 h t h_t ht 将和 下一个时间步的输入 X t + 1 X_{t+1} Xt+1 一起进入结构体
3.2 内部计算公式
h
t
=
tanh
(
X
t
W
i
h
T
+
b
i
h
+
h
t
−
1
W
h
h
T
+
b
h
h
)
h_t = \tanh(X_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})
ht=tanh(XtWihT+bih+ht−1WhhT+bhh)
- W i h W_{ih} Wih 表示输入数据的权重
- b i h b_{ih} bih 表示输入数据的偏置
- W h h W_{hh} Whh 表示隐藏状态的权重
- b h h b_{hh} bhh 表示隐藏状态的偏置
3.3 其中 tanh 激活函数的作用
- 非线性映射:
- RNN中的线性层(也称为全连接层或仿射变换)仅仅是对输入进行线性组合,而tanh函数则引入了非线性特性。这使得RNN能够学习和表示更复杂的输入-输出关系,因为非线性映射能够捕捉数据中的非线性特征。
- 值域限制:
- tanh函数的输出值域为(-1, 1),这有助于将神经元的输出限制在一个合理的范围内。与sigmoid函数类似,tanh函数也能够在一定程度上缓解梯度消失的问题(尽管在非常深的网络中仍然可能存在),因为梯度在值域内不会趋于零。
- 中心化输出:
- tanh函数的输出是中心化的,即均值为0。这有助于在训练过程中保持数据的分布相对稳定,有助于加快收敛速度和提高模型的稳定性。
- 梯度传播:
- 在反向传播过程中,tanh函数的导数(即梯度)在输入接近0时最大,而在输入接近-1或1时接近0。这意味着当神经元的输出接近极端值时,梯度会变小,可能导致梯度消失问题。
3.4 传统RNN优缺点
- 优势
- 由于内部结构简单,对计算资源要求低,相比之后我们要学习的RNN变体:LSTM和GRU模型参数总量少了很多,在短序列任务上性能和效果都表现优异
- 缺点
- 传统RNN在解决长序列之间的关联时,通过实践,证明经典RNN表现很差,原因是在进行反向传播的时候,过长的序列导致梯度的计算异常,发生梯度消失或爆炸。
四、代码演示
演示代码 1 :
import torch
from torch import nn
def my_rnn_dm01():
'''
RNN 的三个参数的含义
第一个参数:input_size(输入张量 x 的维度)
第二个参数:hidden_size(隐藏层的维度,隐藏层的神经元个数)
第三个参数:num_layer(隐藏层的数量)
'''
rnn = nn.RNN(5, 6, 1)
'''
input 的三个参数的含义
第一个参数:sequence_length(输入序列的长度)
第二个参数:batch_size(批次的样本数量)
第三个参数:input_size(输入张量的维度)
'''
input = torch.randn(5, 3, 5)
'''
output 的三个参数的含义
第一个参数:num_layer * num_directions(层数*网络方向)
第二个参数:batch_size(批次的样本数)
第三个参数:hidden_size(隐藏层的维度, 隐藏层神经元的个数)
'''
# h0 = torch.randn(1, 5, 6)
# output, hn = rnn(input, h0) # h0 可以传也可以不传
output, hn = rnn(input)
print(output.shape) # torch.Size([5, 3, 6])
print(output)
# print(hn.shape)
# print(hn)
演示代码 2 :
def my_rnn_dm02():
'''
RNN 的三个参数的含义
第一个参数:input_size(输入张量 x 的维度)
第二个参数:hidden_size(隐藏层的维度,隐藏层的神经元个数)
第三个参数:num_layer(隐藏层的数量)
第四个参数:输入层可以把 batch_size参数 放在一个位置
'''
rnn = nn.RNN(5, 6, 1, batch_first=True)
'''
input 的三个参数的含义
第一个参数:batch_size(批次的样本数量)
第二个参数:sequence_length(输入序列的长度)
第三个参数:input_size(输入张量的维度)
'''
input = torch.randn(3, 20, 5)
'''
output 的三个参数的含义
第一个参数:num_layer * num_directions(层数*网络方向)
第二个参数:batch_size(批次的样本数)
第三个参数:hidden_size(隐藏层的维度, 隐藏层神经元的个数)
'''
output, hn = rnn(input)
print(output.shape) # torch.Size([3, 20, 6])
print(output)
总结
- 以上就是传统RNN的基本内容