当前位置: 首页 > article >正文

PyTorch 损失函数解惑:为什么 nn.CrossEntropyLoss 和 nn.BCELoss 的公式看起来一样?

PyTorch 损失函数解惑:为什么 nn.CrossEntropyLossnn.BCELoss 的公式看起来一样?

在使用 PyTorch 时,我们经常会用到 nn.CrossEntropyLoss(交叉熵损失)和 nn.BCELoss / nn.BCEWithLogitsLoss(二元交叉熵损失)。如果你仔细看它们的公式,会发现它们长得几乎一模一样:

  • 交叉熵损失(CrossEntropyLoss): (二分类时)
    Loss = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] Loss=N1i=1N[yilog(y^i)+(1yi)log(1y^i)]

  • 二元交叉熵损失(BCELoss)
    BCE = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] \text{BCE} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] BCE=N1i=1N[yilog(y^i)+(1yi)log(1y^i)]

这不就是同一个公式吗?为什么 PyTorch 要分成两个不同的类呢?今天我们就来解开这个谜团,从数学原理到实现细节,彻底搞清楚它们的联系与区别。

1. 数学上的联系:交叉熵的本质

首先,我们得明白交叉熵(Cross-Entropy)的定义。交叉熵是信息论中的一个概念,用来衡量两个概率分布之间的差异。对于两个分布 ( p p p )(真实分布)和 ( q q q )(预测分布),交叉熵定义为:

H ( p , q ) = − ∑ x p ( x ) log ⁡ q ( x ) H(p, q) = -\sum_{x} p(x) \log q(x) H(p,q)=xp(x)logq(x)

  • 二分类任务中:
    • 真实分布 ( p p p ) 是 0 或 1(比如 ( y i = 1 y_i = 1 yi=1 ) 表示正类,( y i = 0 y_i = 0 yi=0 ) 表示负类)。
    • 预测分布 ( q q q ) 是模型输出的概率 ( y ^ i \hat{y}_i y^i )(介于 0 和 1 之间)。
    • 交叉熵简化为:
      ( H ( p , q ) = − [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] H(p, q) = -[y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] H(p,q)=[yilog(y^i)+(1yi)log(1y^i)] )

这正是二元交叉熵(Binary Cross-Entropy)的公式!所以,从数学上看,二元交叉熵就是交叉熵在二分类问题上的具体形式。

  • 多分类任务中:
    • 真实分布 ( p p p ) 是 one-hot 编码(比如 ( [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0] ) 表示第 1 类)。
    • 预测分布 ( q q q ) 是模型输出的概率向量(比如 ( [ 0.1 , 0.7 , 0.2 ] [0.1, 0.7, 0.2] [0.1,0.7,0.2] ))。
    • 交叉熵变成:
      ( H ( p , q ) = − ∑ c = 1 C y i , c log ⁡ ( y ^ i , c ) H(p, q) = -\sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c}) H(p,q)=c=1Cyi,clog(y^i,c) )
      对于 one-hot 编码,只有正确类别的 ( y i , c = 1 y_{i,c} = 1 yi,c=1 ),其他为 0,所以简化为 ( − log ⁡ ( y ^ i , correct ) -\log(\hat{y}_{i,\text{correct}}) log(y^i,correct) )。

这说明:交叉熵是一个通用概念,二元交叉熵是它的特例,多分类交叉熵是它的扩展。

2. PyTorch 的实现差异

虽然数学上交叉熵和二元交叉熵有紧密联系,但 PyTorch 的 nn.CrossEntropyLossnn.BCELoss 在设计和使用上有明显区别,主要体现在以下几点:

(1) 输入的处理方式
  • nn.CrossEntropyLoss

    • 输入:未归一化的 logits(原始分数,比如 [1.0, 2.0, 0.5])。
    • 内部操作:自动应用 LogSoftmax(将 logits 转为概率)+ NLLLoss(负对数似然损失)。
    • 目标:类别索引(比如 1 表示第 1 类)。
    • 适用:多分类任务(类别数 ( C ≥ 2 C \geq 2 C2 ))。
  • nn.BCELoss

    • 输入:归一化后的概率(介于 0 和 1,比如 [0.7, 0.2]),需要手动加 Sigmoid。
    • 内部操作:直接计算二元交叉熵。
    • 目标:0 或 1 的浮点数。
    • 适用:二分类任务。
  • nn.BCEWithLogitsLoss

    • 输入:未归一化的 logits。
    • 内部操作:自动应用 Sigmoid + 二元交叉熵,比 nn.BCELoss 更稳定。
    • 目标:0 或 1。
    • 适用:二分类任务。

关键区别nn.CrossEntropyLoss 针对多分类,包含 Softmax 操作;nn.BCELossnn.BCEWithLogitsLoss 针对二分类,分别处理概率和 logits。

(2) 使用场景的差异
  • 如果你做一个手写数字识别(10 类),用 nn.CrossEntropyLoss,因为它是多分类任务。
  • 如果你判断图片里有没有猫(0 或 1),用 nn.BCEWithLogitsLoss,因为它是二分类任务。
(3) 公式“看起来一样”的原因

在二分类情况下,nn.CrossEntropyLossnn.BCEWithLogitsLoss 的数学形式确实可以等价:

  • 假设有两个类别(0 和 1),nn.CrossEntropyLoss 的输入是 [logit_0, logit_1],经过 Softmax 后:
    • ( y ^ 1 = e logit 1 e logit 0 + e logit 1 \hat{y}_1 = \frac{e^{\text{logit}_1}}{e^{\text{logit}_0} + e^{\text{logit}_1}} y^1=elogit0+elogit1elogit1 )
    • ( y ^ 0 = 1 − y ^ 1 \hat{y}_0 = 1 - \hat{y}_1 y^0=1y^1 )
    • 损失为 ( − log ⁡ ( y ^ correct ) -\log(\hat{y}_\text{correct}) log(y^correct) )。
  • nn.BCEWithLogitsLoss 输入是单个 logit,经过 Sigmoid:
    • ( y ^ = 1 1 + e − logit \hat{y} = \frac{1}{1 + e^{-\text{logit}}} y^=1+elogit1 )
    • 损失为 ( − [ y log ⁡ ( y ^ ) + ( 1 − y ) log ⁡ ( 1 − y ^ ) ] -[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})] [ylog(y^)+(1y)log(1y^)] )。

当类别数为 2 时,Softmax 和 Sigmoid 的结果可以互相转换,数学上损失值一致。但 PyTorch 实现时,nn.CrossEntropyLoss 处理多维 logits,nn.BCEWithLogitsLoss 针对标量二分类,接口和语义不同。

3. 代码对比:直观感受区别

来看一个简单的二分类例子:

import torch
import torch.nn as nn

# 数据
logits = torch.tensor([[1.0, 2.0], [0.5, -0.5]])  # [batch_size, 2]
target = torch.tensor([1, 0])  # 类别索引

# nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())

# nn.BCEWithLogitsLoss
bce_logits_fn = nn.BCEWithLogitsLoss()
logits_for_bce = logits[:, 1] - logits[:, 0]  # 转换为单值 logit
target_for_bce = target.float()
bce_loss = bce_logits_fn(logits_for_bce, target_for_bce)
print("BCEWithLogitsLoss:", bce_loss.item())
  • nn.CrossEntropyLoss 直接用 [logit_0, logit_1],目标是索引。
  • nn.BCEWithLogitsLoss 需要把 logits 转换为单值(比如 ( logit 1 − logit 0 \text{logit}_1 - \text{logit}_0 logit1logit0 )),目标是 0 或 1。

输出值可能略有不同,但数学上它们在二分类时等价。

4. 为什么分开设计?

既然二分类是多分类的特例,为什么 PyTorch 不统一用 nn.CrossEntropyLoss

  • 语义清晰:二分类和多分类的任务需求不同,分开设计更直观。
  • 计算效率:二分类用 Sigmoid 比 Softmax 更简单,节省计算量。
  • 使用习惯:二分类任务常输出概率,nn.BCELoss 符合传统机器学习习惯。
5. 小结:公式相同,设计不同
  • 数学上nn.CrossEntropyLossnn.BCELoss 的公式在二分类时一致,因为它们都源于交叉熵。
  • 实现上
    • nn.CrossEntropyLoss 针对多分类,处理 logits + Softmax。
    • nn.BCELossnn.BCEWithLogitsLoss 针对二分类,分别处理概率和 logits。
  • 选择建议
    • 多分类(( C > 2 )):用 nn.CrossEntropyLoss
    • 二分类:用 nn.BCEWithLogitsLoss(更稳定)或 nn.BCELoss(手动 Sigmoid)。

后记

2025年2月28日17点05分于上海,在grok3 大模型辅助下完成。


http://www.kler.cn/a/567587.html

相关文章:

  • fluent-ffmpeg 依赖详解
  • oracle使用PLSQL导出表数据
  • 【FL0087】基于SSM和微信小程序的民宿短租系统
  • Spring Boot 3 集成 RabbitMQ 实践指南
  • AnyDesk 远程桌面控制软件 v9.0.2
  • 数据结构之八大排序算法详解
  • QT基础十、表格组件:QTableWidget
  • JavaScript系列02-函数深入理解
  • 通过统计学视角解读机器学习:从贝叶斯到正则化
  • 华为在不同发展时期的战略选择(节选)
  • Java多线程与高并发专题——深入ReentrantReadWriteLock
  • Python 数据可视化(一)熟悉Matplotlib
  • iOS中的设计模式(六)- 单利模式
  • 问题解决:word导出的pdf图片不清晰?打印机导出的不是pdf,是.log文本文档?
  • 性能测试丨JMeter 分布式加压机制
  • uniapp 阿里云点播 播放bug
  • 目标检测——数据处理
  • 前端清除浮动有哪些方式?
  • 微服务即时通信系统---(七)文件管理子服务
  • 关于延迟任务线程池,Java提供的ScheduledThreadPoolExecutor,Spring提供的ThreadPoolTaskScheduler