当前位置: 首页 > article >正文

神经网络——优化器

1.优化器介绍:

优化器集中在torch.optim中。

  • Constructing it

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
  • Taking an optimization step
for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

2.代码实战:

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset=torchvision.datasets.CIFAR10("data",train=False,transform=torchvision.transforms.ToTensor(),
                                     download=True)

#每个批次中加载的数据项数量
dataloader=DataLoader(dataset,batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()

        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self, x):
        x=self.model1(x)
        return x

loss=nn.CrossEntropyLoss()
tudui=Tudui()

optim=torch.optim.SGD(tudui.parameters(),lr=0.01)

for epoch in range(20):
    running_loss=0.0
    for data in dataloader:
        imgs,targets = data
        outputs =tudui(imgs)
        result_loss=loss(outputs,targets)
        #清零
        optim.zero_grad()
        result_loss.backward()
        #调优
        optim.step()
        running_loss=running_loss+result_loss
    print(running_loss)

在这里插入图片描述
后面loss又升高,为反向优化

3.总结:

优化器的基本使用

  • 如果要知道各个优化器的详细用法
  • 需要对其有一定了解
  • 注意要多训练几轮

http://www.kler.cn/a/281142.html

相关文章:

  • ffmpeg最新5.1.6版本源码安装
  • NLP从零开始------13.文本中阶序列处理之语言模型(1)
  • ODOO17文档打印(输出)方案 -- ODOO17 document printing (output) scheme
  • 前端算法题----任意子数组和的绝对值的最大值
  • 量化交易backtrader实践(四)_评价统计篇(3)_更多评价与可视化
  • openEuler安装Docker和踩坑分析
  • 单HTML文件集成Vue2+axios的使用
  • 解锁SQL的力量:SELECT COUNT()的计数艺术
  • Seata 的部署和集成
  • 服务降级的架构原理
  • intel cpu芯片的命名规则
  • 服务器远程管理
  • 使用Spring Cloud Consul实现微服务注册与发现的全面指南
  • 算法之二分查找法和双指针
  • [设计模式之抽象工厂模式—— 家具工厂]
  • 变压吸附制氧机在养殖产业的应用优势
  • 大模型学习应用 3: AutoDL 平台 transformers 环境搭建及模型部署使用(持续更新中)
  • 经验笔记:基于Token的身份认证及其安全性探讨
  • AIGC辅助办公
  • 儿童孤独症学校怎么选?