当前位置: 首页 > article >正文

《深度学习》—— PyTorch的神经网络模块中常用的损失函数

文章目录

  • 前言
  • 一、回归模型中常用的损失函数
    • 1、平均绝对误差损失(L1Loss)
    • 2、均方误差损失(MSELoss也称L2Loss)
    • 3、SmoothL1Loss
  • 二、分类模型中常用的损失函数
    • 1、负对数似然损失(NLLLoss)
    • 2、二元交叉熵损失(BCELoss)
    • 3、交叉熵损失(CrossEntropyLoss)

前言

  • 在神经网络中,损失函数(Loss Function)扮演着至关重要的角色,它是衡量模型预测值实际值之间差异的关键工具

  • 损失函数,顾名思义,是用来量化模型预测结果“损失”或“误差”大小的函数。在神经网络训练过程中,损失函数通过将预测值与实际值进行比较,计算出一个数值(通常为非负实数),该数值反映了模型在当前状态下的预测性能。这个数值越小,表示模型的预测结果越接近真实值,即模型的性能越好。

  • 在PyTorch的nn(神经网络模块)中提供了很多的损失函数,如下:

    from torch import nn  # 导入神经网络模块
    __all__ = ['L1Loss', 'NLLLoss', 'NLLLoss2d', 'PoissonNLLLoss', 'GaussianNLLLoss', 'KLDivLoss',
               'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss',
               'SmoothL1Loss', 'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'MultiLabelSoftMarginLoss',
               'CosineEmbeddingLoss', 'MarginRankingLoss', 'MultiMarginLoss', 'TripletMarginLoss',
               'TripletMarginWithDistanceLoss', 'CTCLoss']
    # 调用方法 ——> nn.损失函数名
    

一、回归模型中常用的损失函数

1、平均绝对误差损失(L1Loss)

  • L1 Loss,也称为平均绝对误差(Mean Absolute Error, MAE),是深度学习中常用的一种损失函数,主要用于衡量模型预测结果与真实标签之间的平均绝对误差。以下是关于L1 Loss的详细解释:
  • 定义与公式
    • L1 Loss定义为模型预测值与真实值之间差的绝对值的平均。对于一个大小为 N 的样本集合,L1 Loss的公式如下:
      在这里插入图片描述
  • 特点与优势
    • 1.鲁棒性:与 L2 Loss(均方误差)相比,L1 Loss对于异常值(outliers)的容忍性更高,因此在某些需要考虑异常值的任务中,如目标检测、人脸识别等领域,L1 Loss被广泛应用。
    • 2.稳定梯度:L1 Loss的梯度在误差较小时较为稳定,不会因误差的微小变化而导致梯度的大幅波动,这有助于模型的稳定训练。
    • 3.稀疏性:在某些情况下,L1 Loss更容易得到稀疏解,即更多的预测值为0,这在特征选择等任务中可能具有优势。
  • 应用场景:L1 Loss主要用于回归任务中,是监督学习中衡量模型预测性能的重要指标之一

2、均方误差损失(MSELoss也称L2Loss)

  • 定义与公式

    • MSELoss通过计算预测值与真实值之间差异的平方的平均值来衡量模型的预测性能。其数学表达式为:
      在这里插入图片描述
  • 优点与局限性

  • 优点:

    • 1.简单直观:MSELoss的计算方式简单,易于理解和实现。
    • 2.优化友好:在梯度下降等优化算法中,MSELoss的梯度计算也相对简单,有助于模型的快速收敛。
  • 局限性:

    • 1.对异常值敏感:由于MSE Loss是平方误差,因此当数据中存在异常值时,这些异常值可能会对损失函数产生较大的影响,从而导致模型性能下降。
    • 2.不考虑误差方向:MSELoss只考虑误差的大小,而不考虑误差的方向,这可能导致在某些情况下模型的预测结果偏离真实值的方向。
  • 应用场景:MSELoss因其简单直观的特点,在多种回归问题中得到了广泛应用,例如:房价预测、股票价格预测、时间序列预测等

3、SmoothL1Loss

  • SmoothL1Loss是一种在深度学习中常用的损失函数,特别是在目标回归问题中。它结合了L1 Loss和L2 Loss的优点,提供了一种平滑的损失度量方式。
  • SmoothL1Loss的公式过于复杂,只需要记住它是结合了L1 Loss和L2 Loss两个损失函数
  • 特点与优势
    • 1.平滑性:SmoothL1Loss在差异较小时采用L2 Loss的平方形式,这使得损失函数在这一点附近更加平滑,有助于模型的稳定训练。
    • 2.鲁棒性:当差异较大时,SmoothL1Loss采用L1 Loss的形式,这有助于减少异常值(离群点)对模型训练的影响,使得模型更加鲁棒。
    • 3.平衡性:SmoothL1Loss通过结合L1 Loss和L2 Loss的优点,既能在差异较小时提供稳定的梯度(如L2 Loss),又能在差异较大时保持较小的梯度(如L1 Loss),从而在两者之间取得平衡。
  • 应用场景
    • SmoothL1Loss广泛应用于目标检测边界框回归等任务中。在这些任务中,预测值和真实值之间的差异可能很大,同时需要模型对异常值具有一定的鲁棒性。SmoothL1Loss正好满足了这些需求,因此在这些领域得到了广泛应用。

二、分类模型中常用的损失函数

1、负对数似然损失(NLLLoss)

  • 定义与公式
    • NLLLoss用于度量模型预测的类别概率分布与实际类别分布之间的差距。对于每个样本,NLLLoss计算预测的类别的概率与真实类别的概率的负对数似然值,最终将所有样本的损失值求平均得到总损失。其计算公式如下:
      在这里插入图片描述
  • 特点与优势
    • 直观性:NLLLoss直接反映了模型预测的概率分布与实际分布之间的差异,损失值越小,说明模型预测越准确。
    • 适用性:NLLLoss特别适用于多分类问题,能够很好地衡量模型在多个类别上的预测性能。
    • 灵活性:在PyTorch等深度学习框架中,NLLLoss可以通过设置不同的参数来适应不同的需求,如是否对损失值进行平均或求和等。
  • 应用场景:NLLLoss广泛应用于各种深度学习任务中,特别是多分类问题

2、二元交叉熵损失(BCELoss)

  • 定义与公式
    • BCELoss通过计算模型输出的概率分布与实际标签之间的交叉熵损失来评估模型的性能。对于每个样本,BCELoss的公式如下:
      在这里插入图片描述
  • 特点与优势
    • 适用于二分类问题:BCELoss特别适用于处理只有两个类别的分类问题,如文本情感分析(积极/消极)、病患诊断(患病/未患病)等。
    • 直接优化分类准确率:BCELoss直接衡量预测概率与真实标签之间的差异,通过最小化该差异来优化模型的分类性能。
    • 平滑且易于求导:BCELoss的函数曲线平滑,求导过程相对简单,有利于在训练过程中使用梯度下降等优化算法进行参数更新。
  • 应用场景:BCELoss广泛应用于各类二分类任务中,例如:
    • 文本分类:如情感分析、垃圾邮件检测等。
    • 图像识别:如医学影像分析中的病患诊断。
    • 生物信息学:如基因表达数据的分类等。

3、交叉熵损失(CrossEntropyLoss)

  • 定义与公式

    • 对于多分类问题,假设有C个类别,CrossEntropyLoss的公式可以表示为:
      在这里插入图片描述
  • 特点与优势

    • 1.直观性:交叉熵损失函数直接衡量了预测概率分布与真实概率分布之间的差异,具有直观性。
    • 2.优化友好:交叉熵损失函数的梯度计算相对简单,有助于在训练过程中使用梯度下降等优化算法进行参数更新。
    • 3.适用于多分类问题:交叉熵损失函数是处理多分类问题的标准方法,与softmax激活函数结合使用,能够有效地提高模型的分类性能。
  • 应用场景:CrossEntropyLoss广泛应用于各类分类任务中, 例如:

    • 图像分类:如手写数字识别、物体识别等。
    • 自然语言处理:如文本分类、情感分析等。
    • 生物信息学:如基因序列分类等。

http://www.kler.cn/news/309925.html

相关文章:

  • Unity 百度AI实现无绿幕拍照抠像功能(详解版)
  • Flask-JWT-Extended登录验证
  • 构建常态化安全防线:XDR的态势感知与自动化响应机制
  • python学习笔记目录
  • JS全选反选案例
  • 海杂波分级方法
  • springboot项目中 前端浏览器访问时遇到跨域请求问题CORS怎么解决?has been blocked by CORS policy
  • 【UEFI基础】BIOS模块执行的优先级
  • 集成网口连接器国产化替代--RJ45内置网络变压器网口生产工厂在行动
  • HarmonyOS学习(十一)——安全管理
  • 说说synchronized的锁升级过程
  • 请求转发和重定向的区别
  • Eureka原理与实践:构建高效的微服务架构
  • 宠物空气净化器该怎么选?希喂、352、霍尼韦尔哪款对吸附浮毛有效
  • Python协程详解
  • uniapp中使用uni.$emit和uni.$on在vue和nvue页面之间传值但是无法赋值的问题
  • HarmonyOS 实现自定义启动页
  • 鸿蒙开发协调布局CollapsibleLayout
  • Unity3d 以鼠标位置点为中心缩放视角(正交模式下)
  • 待办: 杂七杂八——大杂烩.....懒得整理了,我自己凑合看
  • 新手学习Python第七天-新手笔记
  • 基于STM32C8T6的CubeMX:HAL库点亮LED
  • Datawhale X 李宏毅苹果书 AI夏令营 《深度学习详解》第十九章 ChatGPT
  • Python 入门教程(3)基础知识 | 3.6、标准输入与输出
  • c++----模板(进阶)
  • 什么是VHDX文件?
  • 国科云域名解析课堂:一个域名可以解析到多个IP地址吗?
  • 高校能耗管控方案如何做到节能减排
  • 【Python123题库】#绘制温度曲线 #XRD谱图绘制 #态密度曲线绘制
  • 3个WebSocket的.Net开源项目