当前位置: 首页 > article >正文

【深度学习】(8)--神经网络使用最优模型

文章目录

  • 使用最优模型
    • 直接使用最优模型的两种方法
      • 一、 定义模型结构
      • 二、 加载模型
        • 1. 加载模型状态字典
        • 2. 加载完整模型
      • 三、设置设备
      • 四、使用模型预测
  • 总结

使用最优模型

直接使用最优模型在多个方面都具有显著的好处,尤其是在深度学习和机器学习领域。以下是一些主要的好处:

  1. 节省时间和资源
    • 训练时间:训练一个深度学习模型可能需要数小时、数天甚至数周的时间,特别是当数据集很大或模型很复杂时。直接使用最优模型可以立即开始使用,无需等待长时间的训练过程。
    • 计算资源:训练模型需要大量的计算资源,包括高性能的GPU。直接使用最优模型可以显著减少对这些资源的需求。
  2. 提高性能
    • 更好的泛化能力:最优模型通常是在大型、多样化的数据集上训练的,因此它们能够更好地泛化到新的、未见过的数据上。
    • 调优:许多最优模型已经过仔细的调优,包括超参数调整和架构搜索,以确保它们在特定任务上表现最佳。
  3. 易于实现
    • 快速原型开发:对于研究人员或开发人员来说,使用最优模型可以快速实现原型,以测试想法或验证假设。
    • 减少复杂性:直接使用最优模型可以减少实现和调试新模型的复杂性,特别是在模型架构和数据预处理方面。

直接使用最优模型的两种方法

一、 定义模型结构

首先,你需要有定义模型结构的代码。这通常是一个继承自torch.nn.Module的类。如果你没有保存整个模型实例,而是只保存了模型的状态字典(即模型的参数和缓冲区),那么你需要重新定义模型结构。

注意:定义模型的将结果,务必要与使用的最优模型结构相同,否则参数不匹配。

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Sequential( # 将多个层组合在一起
            nn.Conv2d(         # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据
                in_channels=3, # 图像通道个数,1表示灰度图(确定卷积核 组中的个数)
                out_channels=16, # 要得到多少特征图,卷积核的个数
                kernel_size=5,  # 卷积核大小
                stride=1,   # 步长
                padding=2   # 边界填充大小
            ),
            nn.ReLU(), # relu层,不会改变特征图的大小
            nn.MaxPool2d(kernel_size=2) # 进行池化操作(2*2区域)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16,32,5,1,2),
            nn.ReLU(),
            nn.Conv2d(32,32,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32,128,5,1,2), # 输出(64,7,7)
            nn.ReLU()
        )
        self.out = nn.Linear(128*54*54,4)

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x) # 输出(64,7,7)
        x = x.view(x.size(0),-1) # flatten 操作,结果为:(batch_size,64*7*7)
        output = self.out(x)
        output = torch.sigmoid(output)
        return output

二、 加载模型

加载模型有两种方法:1.加载模型状态字典;2.加载整个模型实例

如果你保存的是整个模型实例(使用torch.save(model, ‘model_path.pt’)),你可以直接加载它。但是,更常见的是只保存和加载模型的状态字典(使用torch.save(model.state_dict(), ‘model_state_dict.pt’))。

1. 加载模型状态字典
# 1. 读取参数的方法:
model = CNN().to(device)
model.load_state_dict(torch.load("best.pth"))
model.eval()  # 设置为评估模式,固定模型参数和数据,防止后面被修改
2. 加载完整模型

虽然不需要与保存时完全相同的网络结构代码,但你的环境中必须存在与保存模型时相同的模型类定义。这是因为加载模型时,PyTorch需要知道如何将加载的数据(即模型的参数和结构)映射回Python中的类实例。

因此,该方法也需要提前定义模型。

# 2. 读取完整模型的方法,需要提前创建model:
model = CNN().to(device)
model = torch.load("best1.pt")
model.eval()    #固定模型参数和数据,防止后面被修改

三、设置设备

"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

四、使用模型预测

result = []#保存的预测的结果
lables = []#真实结果
def test_true(dataloader, model):
    with torch.no_grad():   #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)#预测之后的结果。
            result.append(pred.argmax(1).item())
            lables.append(y.item())

test_data = food_dataset(file_path='testda.txt',transform=data_transforms['valid'])
test_dataloader = DataLoader(test_data,batch_size=1,shuffle=True)

test_true(test_dataloader,model)
print(f"预测值 \t:{result}")
print(f"真实值 \t:{lables}")
  • 梯度计算:在评估模式下,使用**torch.no_grad()**上下文管理器可以减少内存消耗并加速计算,因为不需要存储用于反向传播的梯度。

总结

本篇介绍了:

  1. 如何使用保存好的最优模型
  2. 加载模型的两种方法:1.加载模型状态字典;2.加载整个模型实例
  3. 注意:两种方法都需要提前定义神经网络结构。

http://www.kler.cn/news/326569.html

相关文章:

  • js统计字符串中每个字符出现的次数
  • Python | Leetcode Python题解之第440题字典序的第K小数字
  • 【DP解密多重背包问题】:优化策略与实现
  • Iptables,ufw,firewalld的关系与区别
  • 大语言模型(LLM)的子模块拆拆分进行联邦学习;大语言模型按照多头(Multi-Head)拆分进行联邦学习
  • pdf转换成word有哪些方法?10种将PDF转成word的方法
  • 搜维尔科技:5DT数据手套超高的数据质量、较低的交叉关联、高数据频率
  • VUE 开发——AJAX学习(一)
  • 群晖安装Audiobookshelf(有声书)
  • YOLOv11改进 | Neck篇 | YOLOv11引入BiFPN双向特征金字塔网络
  • 项目管理专业资质认证ICB 3中关于项目经理素质的标准
  • FreeRTOS调度器与任务
  • HTML初认识 -- 第二课(全网最好的入门课)
  • el-cascader懒加载回显问题
  • 这 5 个自动化运维场景,可能用 Python 更香?
  • 【工程测试技术】第3章 测试装置的基本特性,静态特性和动态特性,一阶二阶系统的特性,负载效应,抗干扰性
  • Python知识点:如何使用Flink与Python进行实时数据处理
  • Docker快速搭建PostgreSQL15流复制集群
  • 端模一体,猎豹移动对大模型机器人发展路径清晰
  • 每天认识几个maven依赖(ant)
  • dea插件开发-自定义语言9-Rename Refactoring
  • 【以图搜图代码实现2】--faiss工具实现犬类以图搜图
  • mips指令系统简介
  • AI大模型面试大纲
  • 基于单片机的催眠电路控制系统
  • [云服务器15] 全网最全!手把手搭建discourse论坛,100%完成
  • 什么是 Apache Ingress
  • 钉钉H5微应用Springboot+Vue开发分享
  • win11下 keil报错Cannot load driver ‘D:\Keil_v5\ARM\Segger\JL2CM3.dll‘
  • WAF,全称Web Application Firewall,好用WAF推荐