当前位置: 首页 > article >正文

NLP 单双向RNN+LSTM+池化

# 单项RNN 从左往右顺序

# 双向RNN

import torch.nn as nn

import torch

# nn.RNN(3,3,1,batch_first=True,bidirectional=True)

# 双向单层RNN(Recurrent Neural Network)是一种特殊类型的循环神经网络,它能够在两个方向上处

# 理序列数据,即正向和反向。这使得网络在预测当前输出时,能够同时考虑到输入序列中当前元素之前

# 的信息和之后的信息。双向单层RNN由两个独立的单层RNN组成,一个负责处理正向序列(从开始到结

# 束),另一个负责处理反向序列(从结束到开始)。

# 每一层RNN由两个RNN构成 一个考虑输入一个考虑输出共同决定本层状态

# . RNN的训练方法——BPTT

# 随时间反向传播

# 本质还是BP算法,只不

# 过RNN处理时间序列数据,所以要基于时间反向传播

# RNN不善于处理远期依赖性任务  比如长句子

# 比较擅长 有序列的短句

# LSTM 长短期记忆网络

# 神经网络中不是单纯的一个层+激活函数了

# 而是多层+多个激活函数互相作用

# 核心是门控机制

# 基于 pytorch API 代码实现

# 在 LSTM 网络中,初始化隐藏状态 ( h0 ) 和细胞状态 ( c0 ) 是一个重要的步骤,确保模型在处理序列数据时有一个合理的起始状态。

def demo1():

    class Model(nn.Module):

        def __init__(self,input_size,hidden_size,output_size):

            super().__init__()

            self.output_size=output_size

            self.hidden_size=hidden_size

            self.lstm=nn.LSTM(input_size=input_size,hidden_size=self.hidden_size,batch_first=True)

            self.fc=nn.Linear(hidden_size,output_size)

           

        def forward(self,x):

            self.h0=torch.zeros(1,x.shape[0],self.hidden_size)

            self.c0=torch.zeros(1,x.shape[0],self.hidden_size)

            # lstm 的返回值为 out 和 (hn, cn) 。 out为隐藏状态 hn是最后一次的隐藏状态  cn为最后一次的细胞状态 output batch_first=True,形状为 (batch, seq_len, num_directions * hidden_size)

            x,_=self.lstm(x,(self.h0,self.c0))

            x=self.fc(x[-1][-1].view(x.shape[0],-1)) # 取最后一次的隐藏状态 当全连接的输出

            return x

   

    # 这里的假数据是词向量,不是序列下标 注意一下最后一个维度是input_size

    model1=Model(3,16,10)

    data=torch.rand(1,2,3)

    res=model1(data)

    print(res)

# 序列池化

# 在自然语言处理 (NLP) 中,序列池化(sequence pooling)是一种将变长序列转换为固定长度表示的方法。这个过程对于处理可变长度的输入(如句子或文档)特别有用,因为许多深度学习模型(如全连接层)需要固定长度的输入。

#  . 最大池化(Max Pooling):

# 一个批次中的每个字变为向量后 全部向量的对应维度 留下最大的那个

# 平均池化就是取平均值 nn.AdaptiveAvgPool1d

#

if __name__=="__main__":

    demo1()

    pass


http://www.kler.cn/a/512056.html

相关文章:

  • 农业农村大数据应用场景|珈和科技“数字乡村一张图”解决方案
  • Ubuntu 24.04 LTS 安装 Docker Desktop
  • 三电平空间矢量详解
  • (十五)WebGL中gl.texImage2D函数使用详解
  • 用户中心项目教程(二)---umi3的使用出现的错误
  • PCL K4PCS算法实现点云粗配准【2025最新版】
  • 苍穹外卖 项目记录 day07 商品缓存-购物车模块开发
  • [实战]Ubuntu使用工具和命令无法ssh,但使用另一台Ubuntu机器可以用命令ssh,非root用户。
  • 『 实战项目 』Cloud Backup System - 云备份
  • Kotlin 2.1.0 入门教程(五)
  • 【useImperativeHandle Hook】通过子组件暴露相应的属性和方法,实现在父组件中访问子组件的属性和方法
  • React 中hooks之useDeferredValue用法总结
  • 深度学习 | 基于 LSTM 模型的多电池健康状态对比及预测
  • 【柱状图】——18
  • 【玩转全栈】----Django制作部门管理页面
  • 基于SpringBoot的智能家居系统的设计与实现(源码+SQL脚本+LW+部署讲解等)
  • XAMPP运行没有创建桌面图标
  • ChatGPT被曝存在爬虫漏洞,OpenAI未公开承认
  • 2025年美国大学生数学建模竞赛赛前准备计划
  • 【技术杂谈】Arcgis调用天地图和卫星影像
  • Spring Web MVC 探秘
  • Nginx location 和 proxy_pass 配置详解
  • 后端开发流程学习笔记
  • Linux 高级路由与流量控制-用 tc qdisc 管理 Linux 网络带宽
  • 【通信协议】TCP通信
  • 记录一个简单小需求,大屏组件的收缩与打开,无脑写法