当前位置: 首页 > article >正文

(已解决)torch.load的时候发生错误ModuleNotFoundError: No module named ‘models‘

文章目录

      • 背景
      • 原因
      • 解决方案

背景

很简单,我网上下载了一个模型文件,现在想读取这个模型,然后将这个模型用在我的数据集上。

import torch
model=torch.load("model.pyt")#这步直接报错了。
output=model(mydata)

报错了。

ModuleNotFoundError: No module named ‘models‘

原因

我现在项目的目录结构和model.pyt这个模型文件当初在保存torch.save的时候的项目目录不一致,导致导入load模型的时候有一些关键东西缺失。

啥意思呢?假设当初模型保存torch.save的文件长成这样。

import torch
from A import Model
model=Model()
torch.save(model,"model.pyt")

保存在model.pyt中的东西,大家都知道,有模型权重,模型结构等。但是大家想过这样一个问题没有,如果模型里面的一个函数引用了另外一个用户自定义的函数,在torch.save之后,这个自定义函数会被保存吗?答案是不会被保存。也就是说,对于上面的代码,

import torch
from A import Model

torch这个库不会被保存,A这个文件也不会被保存。那么自然,等我们torch.load的时候,A就会找不到,torch可以找到,因为我们本地肯定会导入torch

为啥Pytorch设计的时候不保存这些呢?很简单,就怕模型里面的一个函数引用了另外一个用户自定义的函数,然后这个自定义函数又引用另外一个,然后没玩没了。更怕的是,自定义函数里面还导入了一些非常大的数据,如果全部保存起来,model.pyt得多么大呀!

解决方案

一种方法当然就是把他的原项目下载下来,这包括了他的代码文件,而不能像我一样只下载模型文件。

其实,在load的时候Pytorch已经提示我们了,虽然只是一个warning。

FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don’t have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

官方不建议torch.save(model,"model.pyt")这种保存方式,比较推荐torch.save(model.state_dict(),"model.pyt"),也就是只保存模型权重,其他的一律不保存。这样,强制了你去下载代码文件,不然下面第一行代码Model()会报错。

model = Model()
state_dict = torch.load('model.pyt')
model.load_state_dict(state_dict)#载入训练好的权重。

http://www.kler.cn/a/321499.html

相关文章:

  • 给阿里云OSS绑定域名并启用SSL
  • Docker在CentOS上的安装与配置
  • Javascript高级—常见算法
  • SQL面试题——蚂蚁SQL面试题 会话分组问题
  • 【深度学习】学习率介绍(torch.optim.lr_scheduler学习率调度策略介绍)
  • ABAP关于PS模块CJ20N中项目物料的屏幕和字段增强CI_RSADD
  • kafka分区和副本的关系?
  • 深度学习:ResNet残差神经网络
  • 【OpenSSL】OpenSSL 教程
  • C++ 数据类型分类
  • Android12的netd分析
  • 解析Vue2源码中的diff算法
  • kafka下载配置
  • 深度学习自编码器 - 得益于深度的指数增益篇
  • 数据集-目标检测系列-口罩检测数据集 mask>> DataBall
  • 自动驾驶综述 | 定位、感知、规划常见算法汇总
  • 网络编程(5)——模拟伪闭包实现连接的安全回收
  • GitLab发送邮件功能详解:如何配置自动化?
  • Bytebase 2.23.0 - 支持 Entra (Azure AD) 用户/组同步
  • 基于Node.js+Express+MySQL+VUE实现的计算机毕业设计共享单车管理网站
  • KVM 安装 Windows11
  • 不同的浏览器、服务器和规范对 URL 长度的限制
  • 【Gitee自动化测试0】日程
  • Vue3 取消密码输入框在浏览器中自动回填
  • 微信小程序配置prettier+eslint
  • JAVA实现Word(doc)文件读写