当前位置: 首页 > article >正文

Pytorch常用内置损失函数合集

       PyTorch 提供了多种内置的损失函数,适用于不同的任务和场景。这些损失函数通常已经优化并实现了常见的归约方式(如 meansum),并且可以直接用于训练模型。以下是常见的 PyTorch 内置损失函数及其适用场景:

1. 均方误差损失(Mean Squared Error, MSE)

  • 类名nn.MSELoss

  • 公式

    其中 N 是样本数量,yi是真实值,y^i是预测值;

  • 适用场景

    • 回归问题:当目标是预测连续值时,MSE 是最常见的损失函数。它衡量预测值与真实值之间的平方差,并对较大的误差施加更大的惩罚。
    • 时间序列预测:在时间序列预测任务中,MSE 也常用于衡量模型的预测性能。
  • 示例代码

    loss_fn = nn.MSELoss()

2. 二元交叉熵损失(Binary Cross-Entropy, BCE)

  • 类名nn.BCELoss

  • 公式

    其中 N是样本数量,yi 是真实标签(0 或 1),y^i 是预测的概率值(介于 0 和 1 之间)。

  • 适用场景

    • 二分类问题:当目标是将输入分为两个类别时,BCE 是常用的损失函数。它衡量预测概率与真实标签之间的差异。
    • 多标签分类:在多标签分类任务中,每个样本可以属于多个类别,BCE 可以用于每个标签的独立预测。
  • 注意事项

    • 预测值应为概率值(介于 0 和 1 之间)。如果你的模型输出是未经过激活函数的 logits,应该使用 nn.BCEWithLogitsLoss,它会自动应用 Sigmoid 激活函数。
  • 示例代码

loss_fn = nn.BCELoss()

 

3. 带逻辑斯蒂回归的二元交叉熵损失(BCE with Logits)

  • 类名nn.BCEWithLogitsLoss

  • 公式

    其中 σ 是 Sigmoid 函数,y^i 是模型输出的 logits(未经过 Sigmoid 激活的值)。

  • 适用场景

    • 二分类问题:类似于 nn.BCELoss,但它直接接受未经过 Sigmoid 激活的 logits,并在内部应用 Sigmoid 激活函数。这可以提高数值稳定性。
    • 多标签分类:同样适用于多标签分类任务。
  • 优点

    • 数值更稳定,因为 Sigmoid 和 BCE 的计算是在同一层完成的,避免了梯度消失或爆炸的问题。
  • 示例代码

    loss_fn = nn.BCEWithLogitsLoss()

     

4. 多分类交叉熵损失(Cross Entropy Loss)

  • 类名nn.CrossEntropyLoss

  • 公式

    其中 yi​ 是真实标签(整数表示类别),y^i是模型输出的 logits(未经过 Softmax 激活的值)。

  • 适用场景

    • 多分类问题:当目标是将输入分为多个类别时,Cross Entropy 是常用的损失函数。它结合了 Softmax 激活函数和负对数似然损失(NLL),适合处理多分类任务。
    • 图像分类:在图像分类任务中,Cross Entropy 是最常用的选择。
  • 注意事项

    • 预测值应为 logits(未经过 Softmax 激活的值)。nn.CrossEntropyLoss 会在内部自动应用 Softmax 激活函数。
    • 真实标签应为整数表示的类别索引,而不是 one-hot 编码。
  • 示例代码

loss_fn = nn.CrossEntropyLoss()

 

5. 负对数似然损失(Negative Log Likelihood, NLL)

  • 类名nn.NLLLoss

  • 公式

    其中 yi​ 是真实标签(整数表示类别),pi​ 是预测的概率分布(经过 Softmax 激活后的值)。

  • 适用场景

    • 多分类问题:类似于 nn.CrossEntropyLoss,但 nn.NLLLoss 需要输入已经是经过 Softmax 激活的概率分布。因此,通常与 nn.LogSoftmax 一起使用。
    • 自定义激活函数:如果你希望在损失函数之前应用自定义的激活函数(如温度缩放的 Softmax),可以使用 nn.NLLLoss
  • 示例代码

# 使用 LogSoftmax 和 NLLLoss
m = nn.LogSoftmax(dim=1)
loss_fn = nn.NLLLoss()
output = m(logits)
loss = loss_fn(output, target)

 

6. L1 损失(L1 Loss, Mean Absolute Error, MAE)

  • 类名nn.L1Loss

  • 公式

    其中 N 是样本数量,yi 是真实值,y^i 是预测值。

  • 适用场景

    • 回归问题:与 MSE 类似,L1 损失用于回归任务,但它对异常值(outliers)不太敏感,因为它使用绝对差而不是平方差。
    • 鲁棒性要求较高的任务:当你希望模型对异常值具有更好的鲁棒性时,L1 损失是一个不错的选择。
  • 示例代码

 

loss_fn = nn.L1Loss()

7. Smooth L1 损失(Huber Loss)

  • 类名nn.SmoothL1Loss

  • 公式

    其中 x=yi−y^i​ 是预测值与真实值之间的差异。

  • 适用场景

    • 回归问题:Smooth L1 损失结合了 MSE 和 L1 损失的优点。对于小误差,它使用平方差(类似于 MSE),而对于大误差,它使用绝对差(类似于 L1)。这使得它对异常值具有一定的鲁棒性,同时保持了 MSE 的平滑性。
    • 目标检测:在目标检测任务中,Smooth L1 损失常用于回归边界框的坐标。
  • 示例代码

loss_fn = nn.SmoothL1Loss()

8. Kullback-Leibler 散度损失(KL Divergence)

  • 类名nn.KLDivLoss

  • 公式

    其中 P 是真实分布,Q 是预测分布;

  • 适用场景

    • 分布匹配:当目标是使预测分布尽可能接近真实分布时,KL 散度是一个常用的损失函数。它衡量两个分布之间的差异。
    • 生成对抗网络(GANs):在 GAN 中,KL 散度常用于衡量生成分布与真实分布之间的差异。
    • 变分自编码器(VAEs):在 VAE 中,KL 散度用于正则化潜在变量的分布,使其接近标准正态分布。
  • 注意事项

    • 输入应为对数概率分布(即经过 nn.LogSoftmax 处理的值),而目标应为概率分布。
  • 示例代码

loss_fn = nn.KLDivLoss(reduction='batchmean')

9. Hinge 损失(Hinge Loss)

  • 类名nn.HingeEmbeddingLoss

  • 公式

    其中 y 是真实标签(1 或 -1),y^​ 是预测值。

  • 适用场景

    • 二分类问题:Hinge 损失常用于支持向量机(SVM)中,尤其是在二分类任务中。它鼓励模型将正类和负类之间的间隔最大化。
    • 度量学习:在度量学习任务中,Hinge 损失用于鼓励相似样本之间的距离最小化,而不相似样本之间的距离最大化。
  • 示例代码

loss_fn = nn.HingeEmbeddingLoss()

10. Cosine 相似度损失(Cosine Embedding Loss)

  • 类名nn.CosineEmbeddingLoss

  • 公式

    其中 x1​ 和 x2​ 是两个输入向量,y 是标签(1 表示相似,-1 表示不相似);

  • 适用场景

    • 度量学习:Cosine Embedding Loss 用于度量学习任务,鼓励相似样本之间的余弦相似度最大化,而不相似样本之间的余弦相似度最小化。
    • 对比学习:在对比学习任务中,Cosine Embedding Loss 用于拉近正样本对的距离,推远负样本对的距离。
  • 示例代码

loss_fn = nn.CosineEmbeddingLoss(margin=0.5)

 

 

 


http://www.kler.cn/a/445121.html

相关文章:

  • 百度二面,MySQL 怎么做权重搜索?
  • JVM学习-内存结构(二)
  • 为什么深度学习和神经网络要使用 GPU?
  • 基于顺序表实现队列循环队列的处理
  • SQL 基础教程
  • 拉链表,流⽔表以及快照表的含义和特点
  • 【Elasticsearch03】企业级日志分析系统ELK之Elasticsearch访问与优化
  • BI 工具与 NoETL 自动化指标平台在自助数据分析的差异
  • element table 表头header-cell-style设置的表头不生效
  • 移动魔百盒中的 OpenWrt作为旁路由 安装Tailscale并配置子网路由实现在外面通过家里的局域网ip访问内网设备
  • 每日十题八股-2024年12月18日
  • 亚马逊云科技 re:Invent 2024重磅发布!Amazon Bedrock Data Automation 预览版震撼登场
  • 深度学习0-前置知识
  • 道路运输企业安全生产管理人员安全考核试题
  • 【网络安全设备系列】7、流量监控设备
  • 华为云联合中国信通院发布首个云计算智能化可观测性能力成熟度模型标准
  • Group FLUX - Beta Sprint Summary Essay
  • Vue中Axios二次封装
  • 主曲率为常数时曲面分类
  • uniApp使用腾讯地图提示未添加maps模块
  • 设计模式--单例模式【创建型模式】
  • uniapp图片数据流���� JFIF ��C 转化base64
  • Ubuntu将深度学习环境配置移植到新电脑
  • 分布式锁介绍
  • Spark 运行时对哪些数据会做缓存?
  • 怎样衡量电阻负载的好坏