当前位置: 首页 > article >正文

PyTorch快速入门教程【小土堆】之损失函数与反向传播

视频地址损失函数与反向传播_哔哩哔哩_bilibili

Loss两个作用

1,计算实际输出和目标之间的差距

2. 为我们更新输出提供一定的依据(反向传播)

import torch
from torch import nn
from torch.nn import L1Loss

inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)

inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))

loss = L1Loss(reduction='sum')
result = loss(inputs, targets)

loss_mse = nn.MSELoss()
result_mse = loss_mse(inputs, targets)

print(result)
print(result_mse)


x = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))
loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x, y)
print (result_cross)

上方代码举例了几种loss函数的使用

下方代码是loss函数在模型中如何使用

import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss(outputs, targets)
    result_loss.backward()


http://www.kler.cn/a/461669.html

相关文章:

  • C语言带参数的宏定义的相关知识汇总(最常用的形式、带标记分隔符##的形式...)
  • 初学STM32 ---高级定时器互补输出带死区控制
  • OpenNJet v3.2.0正式发布!
  • Dubbo扩展点加载机制
  • word文档中的文档网格——解决相同行间距当显示出不同行间距的情况
  • MATLAB画柱状图
  • 在 IntelliJ IDEA 中开发 GPT 自动补全插件
  • 【C语言程序设计——循环程序设计】求解最大公约数(头歌实践教学平台习题)【合集】
  • 【优选算法】Binary-Blade:二分查找的算法刃(上)
  • 动态规划五——回文串问题
  • Java后端常见问题 (一)jar:unknown was not found in alimaven
  • 一、Git与GitHub基础说明
  • 企业数字化转型的构念及实现路径
  • uniapp 打包apk
  • 基于深度学习的视觉检测小项目(一) 项目概况
  • 在Linux系统中配置邮件发送功能
  • SpringBoot使用TraceId日志链路追踪
  • 以EM算法为例介绍坐标上升(Coordinate Ascent)算法:中英双语
  • Elixir语言的函数实现
  • 打造汽车产线高效控制与降本增效新局面
  • C# 线程池的使用
  • 主线程,协程和互斥锁
  • java故障注入
  • 【机器人】机械臂:精度、重复精度、控制器分辨率、手腕、末端执行器
  • Jmeter的安装与使用
  • leetcode 热题100(131. 分割回文串)c++