当前位置: 首页 > article >正文

神经网络——网络模型

1.现有网络模型的使用及修改

1.1讲解VGG16网络模型:

import torchvision

vgg16_false=torchvision.models.vgg16(weights =None)
vgg16_true=torchvision.models.vgg16(weights='DEFAULT')
print(vgg16_true)

在这里插入图片描述

ImageNet数据集太大了,仅训练集就有147.9g。

1.2改变现有网络的参数:

  • 在现有模型中添加模型
vgg16_true.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)

在这里插入图片描述

  • 添加模型至classifier中
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)

在这里插入图片描述

  • 修改模型
print(vgg16_false)
vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

在这里插入图片描述

2.网络模型的保存与读取

2.1模型的保存与读取方法1:

  • torch.save(实例, 保存名称)——model_save.py
  • torch.load(实例, 保存名称)——model_load.py

方法1:保存了模型结构+模型参数

#保存方式1
torch.save(vgg16,"vgg16_method1.pth")
#读取方式1:加载模型
model=torch.load("vgg16_method1.pth")#D:\py_code5\XTD\XTDProject_3\vgg16_method1.pth
print(model)

2.2模型的保存与读取方法2:

  • torch.save(实例.state_dict(), 保存名称) ——model_save.py
  • torch.load(实例.state_dict(), 保存名称)——model_load.py

方法2保存的是:模型参数(官方推荐),vgg16的网络模型状态保存为字典格式。不保存结构。

#保存方式2
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
#读取方式2,加载模型
model=torch.load("vgg16_method2.pth")

在这里插入图片描述

是字典格式,需要还原:

#方式二,加载模型
vgg16=torchvision.models.vgg16(weights = None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model=torch.load("vgg16_method2.pth")
print(vgg16)

即可。

方式二的数据大小要小一点
在这里插入图片描述

2.3方法1的陷阱:

用方法1的时候一定要保证读取模型的文件里有定义该模型的类!

#陷阱——model_save.py
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1=nn.Conv2d(3,64,kernel_size=3)

    def forward(self, x):
        x=self.conv1(x)
        return x

tudui=Tudui()
torch.save(tudui,"tudui_method1.pth")
#陷阱——model_load.py
model=torch.load("tudui_method1.pth")
print(model)

在这里插入图片描述

需要让模型能访问到定义的class

  • 方法一:将class的定义放入model_load.py中
#陷阱
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1=nn.Conv2d(3,64,kernel_size=3)

    def forward(self, x):
        x=self.conv1(x)
        return x

model=torch.load("tudui_method1.pth")
print(model)
  • 方法二:引入model_save
from model_save import *

在这里插入图片描述


http://www.kler.cn/a/281010.html

相关文章:

  • 小琳AI课堂:损失函数
  • 奥斯汀玫瑰:独特起源、惊艳形态与深刻花语探秘
  • 安捷伦色谱仪器LabVIEW软件替换与禁运配件开发
  • three.js渲染中文的3D字体
  • SpringBoot集成kafka-监听器注解
  • C#实现数据采集系统-数据反写(2)消息内容处理和写入通信类队列
  • FL Studio24苹果mac电脑破解绿色版安装包下载
  • pyinstaller pyqt5 pytest打包后报错no module unittest.mock
  • polarctf靶场[WEB]Don‘t touch me、机器人、uploader、扫扫看
  • NLP从零开始------15.文本中阶序列处理之语言模型(3)
  • anaconda的power shell和prompt有什么区别?
  • 使用dx工具将jar和class打包成dex
  • 高级问题解决查询搜索网址
  • CMake构建学习笔记10-OsgQt库的构建
  • 《黑神话悟空》:国产3A游戏的崛起与AI绘画技术的融合
  • 【Linux】CodeServer:云IDE部署
  • 使用 ASP.NET Core 与 Entity Framework Core 进行数据库操作
  • 【图像】灰度图与RGB图像的窗宽、窗位的值范围二三问
  • 在VBA中,对Excel单元格的操作方法 (qo+op)
  • 报表融合大屏,做不一样的财务分析!