当前位置: 首页 > article >正文

解决 TypeError: Expected state_dict to be dict-like, , got <class ‘*‘>.

这是一个简洁的错误复现和解决文章

文章目录

    • 错误原因
    • 错误重现
    • 正确加载演示
    • 拓展阅读

错误原因

一般是因为混合使用不同的保存和加载方式,问题出在你用 load_state_dict() 去加载别人使用torch.save(model) 保存的整个模型。

错误重现

下面我们来复现它,看是不是和你的操作一致:

  1. 错误地保存整个 model 而不是其 state_dict
    import torch
    import torch.nn as nn
    
    # 定义一个线性模型进行演示
    class LinearModel(nn.Module):
        def __init__(self, input_size, output_size):
            super(LinearModel, self).__init__()
            self.linear = nn.Linear(input_size, output_size)
    
        def forward(self, x):
            return self.linear(x)
    
    # 创建模型实例
    model = LinearModel(input_size=10, output_size=1)
    
    # 打印模型结构
    print("Model:", model)
    
    # 保存模型的 state_dict
    torch.save(model.state_dict(), './linear_model_state_dict.pth')
    
  2. 加载时传入 model 对象:
    # 创建一个新的模型实例
    new_model = LinearModel(input_size=10, output_size=1)
    
    # 加载 state_dict 到新模型
    new_model.load_state_dict(torch.load('./linear_model_state_dict.pth'))
    
    # 打印加载后的新模型结构
    print("Model loaded with state_dict:", new_model)
    
    输出
    Error: Expected state_dict to be dict-like, got <class '__main__.LinearModel'>.
    

正确加载演示

下面是两种保存和加载的方法,任选其一即可。

import torch
import torch.nn as nn

# 定义一个线性模型
class LinearModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearModel, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = LinearModel(input_size=10, output_size=1)
print("Model:", model)

# 方法 1:保存和加载 state_dict
# 保存模型的 state_dict
torch.save(model.state_dict(), './linear_model_state_dict.pth')

# 创建一个新的模型实例
new_model = LinearModel(input_size=10, output_size=1)

# 加载 state_dict 到新模型
new_model.load_state_dict(torch.load('./linear_model_state_dict.pth'))

# 方法 2:保存和加载整个模型
# 保存整个模型
torch.save(model, './linear_model.pth')

# 加载整个模型
loaded_model = torch.load('./linear_model.pth')

拓展阅读

PyTorch 模型保存与加载的三种常用方式


http://www.kler.cn/news/326147.html

相关文章:

  • Acwing 最小生成树
  • 每日OJ题_牛客_NC40链表相加(二)_链表+高精度加法_C++_Java
  • 《黑神话:悟空》天命人速通法宝 | 北通鲲鹏20智控游戏手柄评测
  • linux打开桌面软件(wps)、获取已打开的文件名(wps)
  • Ini文件读写配置工具类 - C#小函数类推荐
  • 汽车免拆诊断案例 | 2016 款宾利GT车仪表盘上的多个故障灯点亮
  • 使用TensorFlow实现一个简单的神经网络:从入门到精通
  • 动手学深度学习(李沐)PyTorch 第 3 章 线性神经网络
  • TiDB 性能测试的几个优化点
  • Leetcode热题100-438 找出字符串中所有字母异位数
  • R语言非参数回归预测摩托车事故、收入数据:局部回归、核回归、LOESS可视化...
  • 408算法题leetcode--第19天
  • java通过webhook给飞书发送群消息
  • PTA L1-080 乘法口诀数列
  • C语言线程编程深度解析
  • Elasticsearch UNASSIGNED 怎么修复
  • OJ在线评测系统 后端 用策略模式优化判题机架构
  • MySQL基础篇 - 约束
  • Eclipse Memory Analyzer (MAT)提示No java virtual machine was found ...解决办法
  • Altium Designer脚本的执行方式
  • 【漏洞复现】VEXUS多语言货币交易所存在未授权访问漏洞
  • centos已安装python3.7环境,还行单独安装python3.10环境,如何安装,具体步骤
  • 进程、线程、协程详解:并发编程的三大武器
  • websocket初识
  • 数据集-目标检测系列-兔子检测数据集 rabbit >> DataBall
  • 中国资产“超级星期四”之后,腰部中概股或成增长“黑马”
  • Linux云计算 |【第四阶段】PROJECT2-DAY1
  • 如何使用开发者工具捕获鼠标右键点击事件
  • Tensorflow2.0
  • Spring Boot 进阶-深入了解SpringBoot条件注解