当前位置: 首页 > article >正文

解决pytorch问题:received an invalid combination of arguments - got

问题表现

今天跑模型时报了一个非常奇怪的错误:
在这里插入图片描述

意思是“你的lstm层输入的参数是无效的,要求接收参数的类型是(Tensor, tuple of (Tensor, Tensor), list of [Parameter, Parameter, Parameter, Parameter], float, int, float, bool, bool, bool),但是实际收到的是(Tensor, tuple of (Tensor, Tensor), list of [Parameter, Parameter, Parameter, Parameter], float, int, float, bool,”
为什么奇怪呢?奇怪在于我检查了我lstm的输入张量,size是正确的,并且我还检查了张量的值,确实都是float32类型:
在这里插入图片描述
代码如下

class StudentStatusCommunication(nn.Module):
    def __init__(self, embedding_size, add_size, hidden_size, layer_num, dropout):
        super(StudentStatusCommunication, self).__init__()

        self.linear_layer = nn.Linear(embedding_size+add_size, hidden_size)
        self.activate = nn.ReLU()
        self.loop_layer = nn.LSTM(hidden_size, hidden_size, layer_num, dropout, batch_first=True)

    def forward(self, embedding, add_embedding):
        e = torch.cat((embedding, add_embedding), dim=-1)
        print(f'cat: {e.shape}')
        e = self.linear_layer(e)
        print(f'linear: {e.shape}')
        e = self.activate(e)
        print(f'activate: {e.shape}')
        print(e)

        h_steps, (h_t, c_t) = self.loop_layer(e)
        # 隐藏状态列表, 最后时刻隐藏状态, 最后时刻细胞状态
        return h_t
    

解决方案

问题已经写在报错里了,是参数调用错误,那么参数调用错误在哪呢?
在这里插入图片描述
python在调用函数时,接收参数的方式有两种:1、顺序接收(layer_num);2、通过命名接收(batch_first)。
两种方式可以混用,但有两个前提:

  1. 命名接收方式之后不能再用顺序接收;
  2. 顺序接收前n个参数会依次传递给函数的前n个形参。

现在我们检查nn.LSTM接收参数的顺序:
在这里插入图片描述
可以发现前三个参数都进行了正确地传递,但是第四个参数dropout错误传递给了bias,这就是错误来源!没错,python不支持省略的命名传参方式,即:当实参和形参命名一样时,允许省略命名指向(dropout=dropout可以允许直接写成dropout)。好好好,又是被多语言思维惯性打败的一天
修正这个错误也非常简单,只要添加命名指向就行:

self.loop_layer = nn.LSTM(hidden_size, hidden_size, layer_num, dropout=dropout, batch_first=True)

http://www.kler.cn/a/376578.html

相关文章:

  • WPF+MVVM案例实战(十八)- 自定义字体图标按钮的封装与实现(ABD类)
  • iptables面试题
  • Wails不同平台打包
  • 人工智能基础-opencv-图像处理篇
  • Java实现图片转pdf
  • kafka相关面试题
  • MFC图形函数学习03——画直线段函数
  • 【系统架构】如何演变系统架构:从单体到微服务
  • 前端好用的网站分享——CSS(持续更新中)
  • Three.js 开源项目及入门教程分享
  • 【MySql】-0.1、Unbunt20.04二进制方式安装Mysql5.7和8.0
  • Python中os.mkdir() 和 os.makedirs()有什么不同
  • 3DDFA-V3——基于人脸分割几何信息指导下的三维人脸重建
  • WebSocket详解:从前端到后端的全栈理解
  • 【android12】【AHandler】【4.AHandler原理篇ALooper类方法全解】
  • 基于openEuler22.03的rpcapd抓包机安装
  • 如何为STM32的ADC外设编写中断服务程序
  • Linux权限管理和文件属性
  • Docker:技术架构的演进之路
  • 安卓应用自动化测试工具Appium实操分享
  • 【数据结构-邻项消除】力扣1003. 检查替换后的词是否有效
  • 笔记本电脑死机恢复按什么键恢复 电脑死机的解决方法
  • Python 淘宝数据挖掘与词云图制作全攻略
  • Redis特性和应用场景以及安装
  • 私有化视频平台EasyCVR海康大华宇视视频平台视频诊断技术是如何实时监测视频质量的?
  • 在 Windows 系统上设置 MySQL8.0以支持远程连接