当前位置: 首页 > article >正文

Pytorch学习--神经网络--网络模型的保存与读取

一、网络模型的保存与读取方式1

方法讲解

在这里插入图片描述
在这里插入图片描述

保存模型

import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")

读取模型

import torch
model = torch.load("save_method1.pth")
print(model)

输出:在这里插入图片描述

比较坑人的点

使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B的方式来解决),这里举一个例子来说明

保存模型

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

#保存模型和参数

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")

读取模型

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

model = torch.load("save_method1_question.pth")

print(model)

报错如下

在这里插入图片描述
说明我们还要把 Mary 这个框架复制到读取模型的.py文件中

重新更正后的读取模型代码

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x

model = torch.load("save_method1_question.pth")

print(model)
或者
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary   #这里仅举一个例子


model = torch.load("save_method1_question.pth")

print(model)

二、网络模型的保存与读取方式2

保存模型参数

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


vgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")

读取模型参数

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

vgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)

http://www.kler.cn/a/388229.html

相关文章:

  • 十四、Vue 混入(Mixins)详解
  • 《Vue3实战教程》19:Vue3组件 v-model
  • 欧科云链研究院:ChatGPT 眼中的 Web3
  • C语言:枚举类型
  • 计算机网络 (27)IP多播
  • 云架构Web端的工业MES系统设计之区分工业过程
  • Java毕业设计-----基于AIGC的智能客服系统
  • [LInux] 进程地址空间
  • Android 14 SPRD 下拉菜单中增加自动亮度调节按钮
  • 鸿蒙系统:智能设备新时代的技术驱动
  • MySQL:数据类型建表
  • system generator结合高版本matlab的使用
  • 【Linux】进程创建/等待/替换相关知识详细梳理
  • 查缺补漏----用户上网过程(HTTP,DNS与ARP)
  • 信息安全工程师(79)网络安全测评概况
  • 架构师备考-架构基本概念
  • Diving into the STM32 HAL-----DMA笔记
  • 【科普小白】LLM大语言模型的基本原理
  • 《Linux运维总结:基于银河麒麟V10+ARM64架构CPU部署redis 6.2.14 TLS/SSL哨兵集群》
  • Ubuntu学习笔记 - Day3
  • excel常用技能
  • C++ | 表示移动函数move()的基本用法
  • 【Golang】Go语言教程
  • 【leetcode练习·二叉树】用「分解问题」思维解题 I
  • mysql 配置文件 my.cnf 增加 lower_case_table_names = 1 服务启动不了
  • 【前端】JavaScript 方法速查大全-DOM、BOM、时间、处理JS原生问题(三)