当前位置: 首页 > article >正文

《深度学习》PyTorch 常用损失函数原理、用法解析

目录

一、常用损失函数

1、CrossEntropyLoss(交叉熵损失)

        1)原理

        2)流程

        3)用法示例

2、L1Loss(L1损失/平均绝对误差)

        1)原理

        2)用法示例

3、NLLLoss(负对数似然损失)

        1)原理

        2)用法示例

4、 MSELoss(均方误差损失)

        1)定义

        2)用法示例

5. BCELoss(二元交叉熵损失)

        1)定义

        2)用法示例

二、总结常用损失函数

        1、nn.CrossEntropyLoss:交叉熵损失函数

        2、nn.MSELoss:均方误差损失函数

        3、nn.L1Loss:平均绝对误差损失函数

        4、nn.BCELoss:二元交叉熵损失函数

        5、nn.NLLLoss:负对数似然损失函数


一、常用损失函数

1、CrossEntropyLoss(交叉熵损失)

        1)原理

                交叉熵损失是一种常用于分类问题的损失函数,它衡量的是模型输出的概率分布与真实标签分布之间的差异

                在多分类问题中,模型会输出每个类别的预测概率。交叉熵损失通过计算真实标签对应类别的负对数概率评估模型的性能。在实际应用中,nn.CrossEntropyLoss内部会对logits(即未经softmax的原始输出)应用softmax函数,将其转换为概率分布,然后计算交叉熵。

                例如:

                        假设有一个多类别分类任务,共有C个类别。对于每个样本,模型会输出一个包含C个元素的向量,其中每个元素表示该样本属于对应类别的概率。而真实标签是一个C维的向量,其中只有一个元素为1,其余元素均为0,表示样本的真实类别。

        2)流程

                首先,将模型输出的向量通过softmax函数进行归一化,将原始的概率值转换为概率分布。即对模型输出的每个元素进行指数运算,然后对所有元素求和,最后将每个元素除以总和,得到归一化后的概率分布。

                然后,将归一化后的概率分布与真实标签进行比较,计算两者之间的差异。交叉熵损失函数的计算公式为: -sum(y * log(p))  ,其中y是真实标签的概率分布,p是模型输出的归一化后的概率分布。该公式表示真实标签的概率分布与模型输出的归一化后的概率分布之间的交叉熵。

                最后,将每个样本的交叉熵损失值进行求和或平均,得到整个批次的损失值。

       

        3)用法示例
import torch  
import torch.nn as nn  
  
# 假设有一个模型输出的logits和一个真实的标签  
logits = torch.randn(10, 5, requires_grad=True)  # 10个样本,5个类别  
labels = torch.randint(0, 5, (10,))  # 真实标签,每个样本对应一个类别索引  
  
# 创建CrossEntropyLoss实例  
loss_fn = nn.CrossEntropyLoss()  
  
# 计算损失  
loss = loss_fn(logits, labels)  
  
# 反向传播  
loss.backward()

2、L1Loss(L1损失/平均绝对误差)

        1)原理

                L1损失,也称为平均绝对误差(MAE),计算的是预测值与真实值之差绝对值平均值

                L1损失对异常值(即远离平均值的点)的敏感度较低,因为它通过绝对值来度量误差,而绝对值函数在零点附近是线性的。

       

        2)用法示例
loss_fn = nn.L1Loss()  
predictions = torch.randn(3, 5, requires_grad=True)  # 预测值  
targets = torch.randn(3, 5)  # 真实值  
  
# 计算损失  
loss = loss_fn(predictions, targets)  
  
# 反向传播  
loss.backward()

3、NLLLoss(负对数似然损失)

        1)原理

                负对数似然损失(NLLLoss)通常与log_softmax一起使用,用于多分类问题。它计算的是目标类别负对数概率

                NLLLoss期望的输入是对数概率(即已经通过log_softmax处理过的输出),然后计算目标类别的负对数概率。

        2)用法示例
# 假设已经计算了logits  
logits = torch.randn(3, 5, requires_grad=True)  
  
# 应用log_softmax获取对数概率(在PyTorch中,通常直接使用CrossEntropyLoss)  
log_probs = torch.log_softmax(logits, dim=1)  
  
# 创建NLLLoss实例  
loss_fn = nn.NLLLoss()  
  
# 真实标签  
labels = torch.tensor([1, 0, 4], dtype=torch.long)  
  
# 计算损失  
loss = loss_fn(log_probs, labels)  
  
# 反向传播  
loss.backward()

                在实际应用中,直接使用CrossEntropyLoss更为常见,因为它内部集成了softmax和NLLLoss的计算。

4、 MSELoss(均方误差损失)

        1)定义

                均方误差损失(MSE)计算的是预测值与真实值之差的平方的平均值

                MSE通过平方误差来放大较大的误差,从而给予模型更大的惩罚。它是回归问题中最常用的损失函数之一。

        2)用法示例
loss_fn = nn.MSELoss()  
predictions = torch.randn(3, 5, requires_grad=True)  # 预测值  
targets = torch.randn(3, 5)  # 真实值  
  
# 计算损失  
loss = loss_fn(predictions, targets)  
  
# 反向传播  
loss.backward()

5.BCELoss(二元交叉熵损失)

        1)定义

                二元交叉熵损失(BCE)用于二分类问题,计算的是预测概率与真实标签(0或1)之间的交叉熵

                BCE通过计算真实标签对应类别的负对数概率来评估模型的性能。它适用于输出概率的模型,但并不要求输入必须经过sigmoid函数(尽管在实践中很常见)。

        2)用法示例
loss_fn = nn.BCELoss()  
  
# 假设预测值已经通过sigmoid函数(虽然不是必需的)  
predictions = torch.sigmoid(torch.randn(3, requires_grad=True))  
  
# 真实标签  
targets = torch.empty(3).random_(2).float()  # 生成0或1的随机值  
  
# 计算损失  
loss = loss_fn(predictions, targets)  
  
# 反向传播  
loss.backward()

二、总结常用损失函数

        1、nn.CrossEntropyLoss:交叉熵损失函数

                主要用于多分类问题。它将模型的输出(logits)与真实标签进行比较,并计算损失。

        2、nn.MSELoss:均方误差损失函数

                用于回归问题。它计算模型输出与真实标签之间的差异的平方,并返回平均值。

        3、nn.L1Loss:平均绝对误差损失函数

                也称为L1损失。类似于MSELoss,但是它计算模型输出与真实标签之间的差异的绝对值,并返回平均值。

        4、nn.BCELoss:二元交叉熵损失函数

                用于二分类问题。它计算二分类问题中的模型输出与真实标签之间的差异,并返回损失。

        5、nn.NLLLoss:负对数似然损失函数

                主要用于多分类问题。它首先应用log_softmax函数(log_softmax(x) = log(softmax(x)))将模型输出转化为对数概率,然后计算模型输出与真实标签之间的差异。


http://www.kler.cn/a/314993.html

相关文章:

  • 数据仓库中的指标体系模型介绍
  • PHP Array:精通数组操作
  • 关于 webservice 日志中 源IP是node IP的问题,是否能解决换成 真实的客户端IP呢
  • js -音频变音(听不出说话的人是谁)
  • 图像分割基础:使用Python和scikit-image库
  • ThreadPoolExecutor keepAliveTime 含义
  • 【电力系统】基于遗传算法的33节点电力系统无功优化及MATLAB实现
  • LeetCode337. 打家劫舍III
  • springbootKPL比赛网上售票系统
  • Maven 项目无法下载某个依赖
  • 论 JAVA 集合框架中 接口与类的关系
  • 注册信息安全专业人员(CISP)和网络安全的联系与区别
  • FLStudio21Mac版flstudio v21.2.1.3430简体中文版下载(含Win/Mac)
  • windows cuda12.1 pytorch gpu环境配置
  • js之遍历方法
  • windows@文件系统链接@快捷方式@快捷键方式和符号链接及其对比
  • 本地提权【笔记总结】
  • 《AI:开启未来的无限可能》
  • 【django】局域网访问django启动的项目
  • MongoDB解说
  • 机器人速度雅可比矩阵(机器人动力学)
  • 自动化立体仓库与堆垛机单元的技术参数
  • 设计模式之结构型模式例题
  • 简单题35-搜索插入位置(Java and Python)20240919
  • 如何使用 C# 解决 Cloudflare Turnstile CAPTCHA 挑战
  • Flyway 基本概念