当前位置: 首页 > article >正文

PyTorch快速入门教程【小土堆】之网络模型的保存和读取

视频地址网络模型的保存与读取_哔哩哔哩_bilibili

模型的保存

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# #保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")


# 陷阱
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)

    def forward(self, x):
        x = self.conv1(x)
        return x


tudui = Tudui()
torch.save(tudui, "tudui_method1.pthl")

模型的读取

import torch
import torchvision
from torch import nn

# 方式1-》保存方式1,加载模型
model = torch.load("vgg16_method1.pth")
print(model)

# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
print(vgg16)


# 陷阱1
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x


# 必须写出模型才能读取,但不需要实现这个模型
model = torch.load('tudui_method1.pth')
print(model)

http://www.kler.cn/a/460863.html

相关文章:

  • AcWing练习题:差
  • Linux-掉电保护方案
  • driftingblues2
  • 国产编辑器EverEdit - 常用资源汇总
  • 目标检测入门指南:从原理到实践
  • 碰一碰拓客系统:创新引领智能拓客新纪元
  • MAC系统QT Creator的快捷键
  • 运维人员的Python详细学习路线
  • JVM之Class文件详解
  • 【前端】Node.js使用教程
  • 《Vue进阶教程》第三十一课:ref的初步实现
  • 2025元旦源码免费送
  • 探索数据之美,Plotly引领可视化新风尚
  • 代码随想录算法训练营DAY17
  • Rust日志库tklog0.2.9—支持混合时间文件大小备份模式
  • windows下VS release调试
  • Stm32小实验1
  • 【GIS教程】高程点制作DEM并使用ArcgisPro发布高程服务Elevation Layer
  • win32汇编环境下,双击窗口程序内生成的listview列表控件的某行,并提取其内容的示例程序
  • Nmap实用语法简介
  • 使用WebRTC进行视频通信
  • 基于SC-FDE单载波频域均衡MQAM通信链路matlab仿真,包括帧同步,定时同步,载波同步,MMSE信道估计等
  • 在windows上使用vscode和cmake编译c++ 过程记录
  • git 中 工作目录 和 暂存区 的区别理解
  • 网络安全的8个热门趋势和4个渐冷趋势
  • 2 、什么是Java中的不可变类