当前位置: 首页 > article >正文

网络模型的保存与读取

文章目录

    • 一、模型的保存
    • 二、文件的加载
    • 三、模型加载时容易犯的陷阱

一、模型的保存

方式1:torch.save(vgg16, “vgg16_method1.pth”)

import torch
import torchvision.models

vgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16, "vgg16_method1.pth")

如果运行报错:UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=None或者The parameter ‘pretrained‘ is deprecated since 0.13 and may be removed in the future

原因是在 PyTorch 的 torchvision 库中,从版本 0.13 开始,pretrained 参数已经被弃用,取而代之的是 weights 参数。这个改变是为了提供更丰富的预训练模型选择。当你尝试使用 vgg16(pretrained=False) 时,你收到了一个警告,告诉你 pretrained 参数已经不再被使用,并且建议你使用 weights 参数。

要解决这个问题,你应该使用 weights 参数来代替 pretrained。

修正代码:

import torch
import torchvision.models

vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
torch.save(vgg16, "vgg16_method1.pth")

运行代码:
在这里插入图片描述
可以看到多了一个新文件vgg16_method1.pth

该方式1保存的网路模型不仅保存了网络模型的一种结构,它也保存了模型当中的一些参数

方式2:把模型的参数保存成字典(dict)形式

import torch
import torchvision.models

vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

运行结果:
在这里插入图片描述
方式1与方式2对比:

方式1保存的是模型的结构+模型的参数,方式2保存的只是模型的参数(官方推荐的保存方式)

官方推荐的原因是当保存一个大的模型时候,方式2所用的空间更小

我们可以查看一下两种保存方式的文件大小:
在这里插入图片描述
因为vgg这个模型本身就不大,所以文件大小差距并不明显,但方式2足足小了7kb!要是在大模型下节省空间这点会尤其明显。

二、文件的加载

代码:

import torch

model = torch.load("vgg16_method1.pth")
print(model)

运行结果:
在这里插入图片描述

通过将save与load的文件debug运行:
在这里插入图片描述
能够发现两者都是一样的,说明被完整加载出来。

通过上述步骤可以看到模型中的参数也一同保存下来了。

加载方式2保存的模型:

import torch

model = torch.load("vgg16_method2.pth")
print(model)

运行结果:
在这里插入图片描述
可以看到方式2形式是一个个字典形式.

方式2从字典形式想要恢复网络模型结构则需要:

import torch
import torchvision

# 创建一个VGG16模型实例,参数pretrained=False表示不加载预训练的权重。
vgg16 = torchvision.models.vgg16(pretrained=False)

# 加载之前保存的模型权重,这些权重保存在名为"vgg16_method2.pth"的文件中。
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))

# 打印出模型的结构,这样可以看到模型的各个层和参数。
print(vgg16)

运行结果:
在这里插入图片描述
可以看到把模型参数成功加载出来了

三、模型加载时容易犯的陷阱

保存一个自己写的网络模型:

import torch
import torchvision.models
from torch import nn


class Sen(nn.Module):
    def __init__(self):
        super(Sen, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

sen = Sen()
torch.save(sen, "sen_method1.pth")

运行结果:
在这里插入图片描述

因为用方式1进行保存,故使用方式1的方法进行加载:

import torch
import torchvision

model = torch.load("sen_method1.pth")
print(model)

运行结果:
在这里插入图片描述
可以看到发生了报错,报错的意思是加载的时候没有找到Sen这个类

解决方法是将类复制到加载代码中:

import torch
import torchvision

class Sen(nn.Module):
    def __init__(self):
        super(Sen, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

model = torch.load("sen_method1.pth")
print(model)

注意不需要写sen = Sen()这一代码

运行代码:
在这里插入图片描述

也就是用自己写的网络模型不同于现有的网络模型,需要进行导入才能正常加载出来!

或者也可以用import的方法加载自己写的网络模型,那么就不需要老是复制粘贴

通过from model_save import *加载:

import torch
from model_save import *

model = torch.load("sen_method1.pth")
print(model)

运行结果:
在这里插入图片描述


http://www.kler.cn/news/321186.html

相关文章:

  • Testbench编写与Vivado Simulator的基本操作
  • 如何快速免费搭建自己的Docker私有镜像源来解决Docker无法拉取镜像的问题(搭建私有镜像源解决群晖Docker获取注册表失败的问题)
  • 解决SVN蓝色问号的问题
  • 线性基学习DAY2
  • Kafka 面试题
  • 一个证明-待验证
  • 平衡、软技能与持续学习
  • pdf编辑转换器怎么用?分享9个pdf编辑、转换方法(纯干货)
  • 基于深度学习的药品三期OCR字符识别
  • 生成式语言模型底层技术面试
  • 修改Docker默认存储路径,解决系统盘占用90%+问题(修改docker root dir)
  • 【笔记】数据结构|链表算法总结|快慢指针场景和解决方案|链表归并算法和插入算法|2012 42
  • 共享单车轨迹数据分析:以厦门市共享单车数据为例(八)
  • 爬虫过程 | 蜘蛛程序爬取数据流程(初学者适用)
  • P335_0334韩顺平Java_零钱通介绍
  • 华为NAT ALG技术的实现
  • AttributeError: ‘Sequential‘ object has no attribute ‘predict_classes‘如何解决
  • 【Python报错已解决】ModuleNotFoundError: No module named ‘psutil’
  • Android——运行时动态申请权限
  • [Redis][Hash]详细讲解
  • 828华为云征文 | 在华为云X实例上部署微服务架构的文物大数据管理平台的实践
  • linux命令:显示已安装在linux内核的模块的详细信息的工具modinfo详解
  • 物理学基础精解【7】
  • Docker 容器技术:颠覆传统,重塑软件世界的新势力
  • 【RAG研究1】导言-我打算如何对RAG进行全面且深入的研究
  • 【后端开发】JavaEE初阶——计算机是如何工作的???
  • 职业技能大赛-单元测试笔记(assertThat)分享
  • [SDX35]SDX35如何查看GPIO的Base值
  • 力扣随机一题——所有元音按顺序排序的最长字符串
  • Linux嵌入式驱动开发指南(速记版)---Linux基础篇