当前位置: 首页 > article >正文

pytorch实现RNN网络

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape


#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module
    #初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)
    #vocab_size = len(vocab)
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        #继承父类的属性和方法
        super().__init__(**kwargs)
        self.rnn_layer = rnn_layer
        #词汇表的长度
        self.vocab_size =vocab_size
        self.num_hiddens = self.rnn_layer.hidden_size
        
        #判断是否为双向循环
        if not self.rnn_layer.bidirectional:
            self.num_directions = 1
            #nn.Linear用于定义线性层的类,一般用于全连接层
            self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)
    
    #定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)
    def forward(self, inputs, state):
        #one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)
        #。T转置之后,X.shape=(num_steps,batch_size)
        #one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        #将数据转化为tensor
        X = X.to(torch.float32)
        Y, state = self.rnn_layer(X, state)
        #此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)
        
        #输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,
        outputs = self.linear(Y.reshape(-1, Y.shape[-1]))
        return outputs, state
                              
   # 初始化隐藏状态
    def begin_state(self, device, batch_size=1):
        return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',
                    num_preds=10, 
                    net=rnn_net, 
                    vocab=vocab, 
                    device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, 
                  train_iter=train_iter, 
                  vocab=vocab, 
                  lr=lr, 
                  num_epochs=num_epochs, 
                  device=device)

 

8.个人知识点理解

 

 

 


http://www.kler.cn/news/312993.html

相关文章:

  • Vue使用qrcodejs2-fix生成网页二维码
  • 解决 GitLab CI/CD 中的 `413 Request Entity Too Large` 错误
  • 生信初学者教程(五):R语言基础
  • 【计算机网络篇】电路交换,报文交换,分组交换
  • BGP实验
  • Percona发布开源DBaaS平台;阿里云RDS发布全球多活数据库(GAD);Redshift支持自然语言生成SQL
  • Pyspark dataframe基本内置方法(4)
  • 【有啥问啥】弱监督学习新突破:格灵深瞳多标签聚类辨别(Multi-Label Clustering and Discrimination, MLCD)方法
  • QT 将文字矢量化,按照设置的宽和高绘制
  • 3657A/B/AM/BM矢量网络分析仪
  • CSS - 通用左边图片,右边内容,并且控制长度溢出处理模板(vue | uniapp | 微信小程序)
  • python画图|曲线分段设置颜色基础教程
  • 什么是3D展厅?有何优势?怎么制作3D展厅?
  • 蓝星多面体foc旋钮键盘复刻问题详解
  • JVM java主流的追踪式垃圾收集器
  • docker 镜像,导入导出,
  • 【数据结构入门】排序算法之三路划分与非比较排序
  • 基于OpenCV的YOLOv5图片检测
  • 寄存器二分频电路
  • Serverless架构
  • 【C/C++语言系列】实现单例模式
  • golang学习笔记23——golang微服务中服务间通信问题探讨
  • 【ShuQiHere】 探索 IEEE 754 浮点数标准:以 57.625 和 -57.625 为例
  • 【bugfix】-洽谈回填的图片消息无法显示
  • 0基础学习HTML(八)头部
  • PyCharm部分快捷键冲突问题
  • Pybullet 安装过程
  • 利士策分享,周末时光:一场自我充实的精致规划
  • python学习-10【模块】
  • C#开源的一个能利用Windows通知栏背单词的软件