当前位置: 首页 > article >正文

PyTorch快速入门教程【小土堆】之优化器

视频地址优化器(一)_哔哩哔哩_bilibili

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)


http://www.kler.cn/a/467974.html

相关文章:

  • GitLab集成Runner详细版--及注意事项汇总【最佳实践】
  • 练习(继承)
  • ceph集群配置
  • 【前端知识】手搓微信小程序
  • linux-centos-安装miniconda3
  • 软件工程期末大复习(六)面向对象分析
  • STM32 拓展 低功耗案例3:待机模式 (register)
  • SZY206-2016水资源监测数据传输规约 基础架构
  • 深入解析 Redisson 分布式限流器 RRateLimiter 的原理与实现
  • python识别outlook邮件并且将pdf文件作为附件发送邮件
  • 矩阵运算提速——玩转opencv::Mat
  • 电脑键盘打不了字是何原因,如何解决呢
  • 软件需求规格是什么
  • CSS——4. 行内样式和内部样式(即CSS引入方式)
  • 连接Milvus
  • Apache PDFBox添加maven依赖,pdf转成图片
  • 人工智能(AI)简史:推动新时代的科技力量
  • 详解MySQL SQL删除(超详,7K,含实例与分析)
  • PaperAssistant:使用Microsoft.Extensions.AI实现
  • Uniapp中使用`wxml-to-canvas`开发DOM生成图片功能
  • Traceroute 网络诊断工具实战详解
  • 中高级运维工程师运维面试题(九)之 Apache Pulsar
  • MySQL优化器估算SQL语句访问行数的深入分析
  • MIPI_DPU 综合(DPU+MIPI+Demosaic+VDMA 通路)
  • Django Admin中实现字段自动提交功能
  • 文献分享:跨模态的最邻近查询RoarGraph