当前位置: 首页 > article >正文

跟着问题学15——GRU网络结构详解及代码实战

   

1 RNN的缺陷——长期依赖的问题 (The Problem of Long-Term Dependencies)

前面一节我们学习了RNN神经网络,它可以用来处理序列型的数据,比如一段文字,视频等等。RNN网络的基本单元如下图所示,可以将前面的状态作为当前状态的输入。

但也有一些情况,我们需要更“长期”的上下文信息。比如预测最后一个单词“我在中国长大……我说一口流利的**。”“短期”的信息显示,下一个单词很可能是一种语言的名字,但如果我们想缩小范围,我们需要更长期语境——“我在中国长大”,但这个相关信息与需要它的点之间的距离完全有可能变得非常大。

不幸的是,随着这种距离的扩大,RNN无法学会连接这些信息。

从理论上讲,RNN绝对有能力处理这种“长期依赖性”。人们可以为他们精心选择参数,以解决这种形式的问题。遗憾的是,在实践中,RNN似乎无法学习它们。

幸运的是,GRU也没有这个问题!

2、GRU

什么是GRU

GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。

GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢。

用论文中的话说,相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

2.1总体结构框架

前面我们讲到,神经网络的各种结构都是为了挖掘变换数据特征的,所以下面我们也将结合数据特征的维度来对比介绍一下RNN&&LSTM的网络结构。

多层感知机(线性连接层)结构

从特征角度考虑:

输入特征:是n*1的单维向量(这也是为什么卷积神经网络在linear层前要把所有特征层展平),

隐藏层:然后根据隐藏层神经元的数量m将前层输入的特征用m*1的单维向量进行表示(对特征进行了提取变换,隐藏层的数据特征),单个隐藏层的神经元数量就代表网络参数,可以设置多个隐藏层;

输出特征:最终根据输出层的神经元数量y输出y*1的单维向量。

卷积神经网络结构

 从特征角度考虑:

输入特征:是(batch)*channel*width*height的张量,

卷积层(等):然后根据输入通道channel的数量c_in和输出通道channel的数量c_out会有c_out*c_in*k*k个卷积核将前层输入的特征进行卷积(对特征进行了提取变换,k为卷积核尺寸),卷积核的大小和数量c_out*c_in*k*k就代表网络参数,可以设置多个卷积层;每一个channel都代表提取某方面的一种特征,该特征用width*height的二维张量表示,不同特征层之间是相互独立的(可以进行融合)。

输出特征:根据场景的需要设置后面的输出,可以是多分类的单维向量等等。

循环神经网络RNN系列结构

从特征角度考虑:

输入特征:是(batch)*T_seq*feature_size的张量(T_seq代表序列长度,注意不是batch_size).

我们来详细对比一下卷积神经网络的输入特征,

(batch)*T_seq*feature_size

(batch)*channel*width*height,

逐个进行分析,RNN系列的基础输入特征表示是feature_size*1的单维向量,比如一个单词的词向量,比如一个股票价格的影响因素向量,而CNN系列的基础输入特征是width*height的二维张量;

再来看一下序列T_seq和通道channel,RNN系列的序列T_seq是指一个连续的输入,比如一句话,一周的股票信息,而且这个序列是有时间先后顺序且互相关联的,而CNN系列的通道channel则是指不同角度的特征,比如彩色图像的RGB三色通道,过程中每个通道代表提取了每个方面的特征,不同通道之间是没有强相关性的,不过也可以进行融合。

最后就是batch,两者都有,在RNN系列,batch就是有多个句子,在CNN系列,就是有多张图片(每个图片可以有多个通道)

隐藏层:明确了输入特征之后,我们再来看看隐藏层代表着什么。隐藏层有T_seq个隐状态H_t(和输入序列长度相同),每个隐状态H_t类似于一个channel,对应着T_seq中的t时刻的输入特征;而每个隐状态H_t是用hidden_size*1的单维向量表示的,所以一个隐含层是T_seq*hidden_size的张量;对应时刻t的输入特征由feature_size*1变为hidden_size*1的向量。如图中所示,同一个隐含层不同时刻的参数W_ih和W_hh是共享的;隐藏层可以有num_layers个(图中只有1个)

以t时刻具体阐述一下:

X_t是t时刻的输入,是一个feature_size*1的向量

W_ih是输入层到隐藏层的权重矩阵

H_t是t时刻的隐藏层的值,是一个hidden_size*1的向量

W_hh是上一时刻的隐藏层的值传入到下一时刻的隐藏层时的权重矩阵

Ot是t时刻RNN网络的输出

从上右图中可以看出这个RNN网络在t时刻接受了输入Xt之后,隐藏层的值是St,输出的值是Ot。但是从结构图中我们可以发现St并不单单只是由Xt决定,还与t-1时刻的隐藏层的值St-1有关。

2.2 GRU的输入输出结构

GRU的输入输出结构与普通的RNN是一样的。有一个当前的输入xt,和上一个节点传递下来的隐状态(hidden state)ht-1 ,这个隐状态包含了之前节点的相关信息。结合xt和 ht-1,GRU会得到当前隐藏节点的输出yt 和传递给下一个节点的隐状态 ht。

图 GRU的输入输出结构

那么,GRU到底有什么特别之处呢?下面来对它的内部结构进行分析!

2.3 GRU的内部结构

不同于LSTM有3个门控,GRU仅有2个门控,

第一个是“重置门”(reset gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,用于将上一时刻隐状态ht-1重置为ht-1’,即ht-1’=ht-1*r。

再将ht-1’与输入xt进行拼接,再通过一个tanh激活函数来将数据放缩到-1~1的范围内。即得到如下图2-3所示的h’。

第一个是“更新门”(update gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,

最终的隐状态ht的更新表达式即为:

再次强调一下,门控信号(这里的z)的范围为0~1。门控信号越接近1,代表”记忆“下来的数据越多;而越接近0则代表”遗忘“的越多。

2.4 小结

GRU很聪明的一点就在于,使用了同一个门控z就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。与LSTM相比,GRU内部少了一个”门控“,参数比LSTM少,但是却也能够达到与LSTM相当的功能。考虑到硬件的计算能力时间成本,因而很多时候我们也就会选择更加”实用“的GRU。

3代码

import torch
import torch.nn as nn


def my_gru(input,initial_states,w_ih,w_hh,b_ih,b_hh):
    h_prev=initial_states
    batch_size,T_seq,feature_size=input.shape
    hidden_size=w_ih.shape[0]//3

    batch_w_ih=w_ih.unsqueeze(0).tile(batch_size,1,1)
    batch_w_hh=w_hh.unsqueeze(0).tile(batch_size,1,1)

    output=torch.zeros(batch_size,T_seq,hidden_size)

    for t in range(T_seq):
        x=input[:,t,:]
        w_times_x=torch.bmm(batch_w_ih,x.unsqueeze(-1))
        w_times_x=w_times_x.squeeze(-1)

       # print(batch_w_hh.shape,h_prev.shape)
        # 计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m)
        # 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,
        # 对于剩下的则不做要求,输出维度 (b,h,m)
        # batch_w_hh=batch_size*(3*hidden_size)*hidden_size
        # h_prev=batch_size*hidden_size*1
        # w_times_x=batch_size*hidden_size*1
        ##squeeze,在给定维度(维度值必须为1)上压缩维度,负数代表从后开始数
        w_times_h_prev=torch.bmm(batch_w_hh,h_prev.unsqueeze(-1))
        w_times_h_prev=w_times_h_prev.squeeze(-1)

        r_t=torch.sigmoid(w_times_x[:,:hidden_size]+w_times_h_prev[:,:hidden_size]+b_ih[:hidden_size]
                          +b_hh[:hidden_size])
        z_t=torch.sigmoid(w_times_x[:,hidden_size:2*hidden_size]+w_times_h_prev[:,hidden_size:2*hidden_size]
                          +b_ih[hidden_size:2*hidden_size]+b_hh[hidden_size:2*hidden_size])
        n_t=torch.tanh(w_times_x[:,2*hidden_size:3*hidden_size]+w_times_h_prev[:,2*hidden_size:3*hidden_size]
                          +b_ih[2*hidden_size:3*hidden_size]+b_hh[2*hidden_size:3*hidden_size])

        h_prev=(1-z_t)*n_t+z_t*h_prev
        output[:,t,:]=h_prev

    return output,h_prev


if __name__=="__main__":

    fc=nn.Linear(12,6)
   

    batch_size=2
    T_seq=5
    feature_size=4

    hidden_size=3
   # output_feature_size=3

    input=torch.randn(batch_size,T_seq,feature_size)
    h_prev=torch.randn(batch_size,hidden_size)

    gru_layer=nn.GRU(feature_size,hidden_size,batch_first=True)
    output,h_final=gru_layer(input,h_prev.unsqueeze(0))
    # for k,v in gru_layer.named_parameters():
    #     print(k,v.shape)
    # print(output,h_final)

    my_output, my_h_final=my_gru(input,h_prev,gru_layer.weight_ih_l0,gru_layer.weight_hh_l0,gru_layer.bias_ih_l0,gru_layer.bias_hh_l0)

    # print(my_output, my_h_final)
    # print(torch.allclose(output,my_output))

参考资料

https://zhuanlan.zhihu.com/p/32481747

https://speech.ee.ntu.edu.tw/~tlkagk/courses/MLDS_2018/Lecture/Seq%20(v2).pdf

https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788&vd_source=cf7630d31a6ad93edecfb6c5d361c659


http://www.kler.cn/a/428477.html

相关文章:

  • Flutter:搜索页,搜索bar封装
  • 【玩转全栈】----Django基本配置和介绍
  • STM32更新程序OTA
  • 备赛蓝桥杯之第十五届职业院校组省赛第二题:分享点滴
  • 【2024 博客之星评选】请继续保持Passion
  • FortiGate配置远程拨号VPN
  • 计算机毕业设计hadoop+spark+hive图书推荐系统 豆瓣图书数据分析可视化大屏 豆瓣图书爬虫 知识图谱 图书大数据 大数据毕业设计 机器学习
  • 【集群划分】含分布式光伏的配电网集群电压控制【33节点】
  • 入门Web自动化测试之元素定位进阶技巧
  • 用二维图像渲染3D场景视频
  • 《图神经网络编程实战:开启深度学习新领域》
  • Android显示系统(08)- OpenGL ES - 图片拉伸
  • 基于拼团社交与开源链动 2+1 模式 S2B2C 商城小程序源码的营销创新策略研究
  • TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters
  • Vant UI +Golang(gin) 上传文件
  • Connection对象,Statement对象和ResultSet对象的依赖关系 JDBC
  • 设计模式学习思路二
  • linux内核网络层学习
  • spark-operaotr
  • unity3d—demo(实现给出图集名字和图片名字生成对应的图片)
  • 腾讯云 AI 代码助手:AI Agent 构建企业新一代研发范式
  • 时频转换 | Matlab实小波变换Real wavelet transform一维数据转二维图像方法
  • 使用Jackson库美化JSON输出
  • Llama-3.1-405B-Instruct 开源体验|对比|亮点|使用|总结
  • PDF拆分之怎么对批量的PDF文件进行分割-免费PDF编辑工具分享
  • 【问题解决方案】项目路径更改后pycharm选定解释器无效