PyTorch 损失函数解惑:为什么 nn.CrossEntropyLoss 和 nn.BCELoss 的公式看起来一样?
PyTorch 损失函数解惑:为什么 nn.CrossEntropyLoss
和 nn.BCELoss
的公式看起来一样?
在使用 PyTorch 时,我们经常会用到 nn.CrossEntropyLoss
(交叉熵损失)和 nn.BCELoss
/ nn.BCEWithLogitsLoss
(二元交叉熵损失)。如果你仔细看它们的公式,会发现它们长得几乎一模一样:
-
交叉熵损失(CrossEntropyLoss): (二分类时)
Loss = − 1 N ∑ i = 1 N [ y i log ( y ^ i ) + ( 1 − y i ) log ( 1 − y ^ i ) ] \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] Loss=−N1i=1∑N[yilog(y^i)+(1−yi)log(1−y^i)] -
二元交叉熵损失(BCELoss):
BCE = − 1 N ∑ i = 1 N [ y i log ( y ^ i ) + ( 1 − y i ) log ( 1 − y ^ i ) ] \text{BCE} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] BCE=−N1i=1∑N[yilog(y^i)+(1−yi)log(1−y^i)]
这不就是同一个公式吗?为什么 PyTorch 要分成两个不同的类呢?今天我们就来解开这个谜团,从数学原理到实现细节,彻底搞清楚它们的联系与区别。
1. 数学上的联系:交叉熵的本质
首先,我们得明白交叉熵(Cross-Entropy)的定义。交叉熵是信息论中的一个概念,用来衡量两个概率分布之间的差异。对于两个分布 ( p p p )(真实分布)和 ( q q q )(预测分布),交叉熵定义为:
H ( p , q ) = − ∑ x p ( x ) log q ( x ) H(p, q) = -\sum_{x} p(x) \log q(x) H(p,q)=−x∑p(x)logq(x)
- 在二分类任务中:
- 真实分布 ( p p p ) 是 0 或 1(比如 ( y i = 1 y_i = 1 yi=1 ) 表示正类,( y i = 0 y_i = 0 yi=0 ) 表示负类)。
- 预测分布 ( q q q ) 是模型输出的概率 ( y ^ i \hat{y}_i y^i )(介于 0 和 1 之间)。
- 交叉熵简化为:
( H ( p , q ) = − [ y i log ( y ^ i ) + ( 1 − y i ) log ( 1 − y ^ i ) ] H(p, q) = -[y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] H(p,q)=−[yilog(y^i)+(1−yi)log(1−y^i)] )
这正是二元交叉熵(Binary Cross-Entropy)的公式!所以,从数学上看,二元交叉熵就是交叉熵在二分类问题上的具体形式。
- 在多分类任务中:
- 真实分布 ( p p p ) 是 one-hot 编码(比如 ( [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0] ) 表示第 1 类)。
- 预测分布 ( q q q ) 是模型输出的概率向量(比如 ( [ 0.1 , 0.7 , 0.2 ] [0.1, 0.7, 0.2] [0.1,0.7,0.2] ))。
- 交叉熵变成:
( H ( p , q ) = − ∑ c = 1 C y i , c log ( y ^ i , c ) H(p, q) = -\sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c}) H(p,q)=−∑c=1Cyi,clog(y^i,c) )
对于 one-hot 编码,只有正确类别的 ( y i , c = 1 y_{i,c} = 1 yi,c=1 ),其他为 0,所以简化为 ( − log ( y ^ i , correct ) -\log(\hat{y}_{i,\text{correct}}) −log(y^i,correct) )。
这说明:交叉熵是一个通用概念,二元交叉熵是它的特例,多分类交叉熵是它的扩展。
2. PyTorch 的实现差异
虽然数学上交叉熵和二元交叉熵有紧密联系,但 PyTorch 的 nn.CrossEntropyLoss
和 nn.BCELoss
在设计和使用上有明显区别,主要体现在以下几点:
(1) 输入的处理方式
-
nn.CrossEntropyLoss
:- 输入:未归一化的 logits(原始分数,比如
[1.0, 2.0, 0.5]
)。 - 内部操作:自动应用
LogSoftmax
(将 logits 转为概率)+NLLLoss
(负对数似然损失)。 - 目标:类别索引(比如
1
表示第 1 类)。 - 适用:多分类任务(类别数 ( C ≥ 2 C \geq 2 C≥2 ))。
- 输入:未归一化的 logits(原始分数,比如
-
nn.BCELoss
:- 输入:归一化后的概率(介于 0 和 1,比如
[0.7, 0.2]
),需要手动加 Sigmoid。 - 内部操作:直接计算二元交叉熵。
- 目标:0 或 1 的浮点数。
- 适用:二分类任务。
- 输入:归一化后的概率(介于 0 和 1,比如
-
nn.BCEWithLogitsLoss
:- 输入:未归一化的 logits。
- 内部操作:自动应用 Sigmoid + 二元交叉熵,比
nn.BCELoss
更稳定。 - 目标:0 或 1。
- 适用:二分类任务。
关键区别:nn.CrossEntropyLoss
针对多分类,包含 Softmax 操作;nn.BCELoss
和 nn.BCEWithLogitsLoss
针对二分类,分别处理概率和 logits。
(2) 使用场景的差异
- 如果你做一个手写数字识别(10 类),用
nn.CrossEntropyLoss
,因为它是多分类任务。 - 如果你判断图片里有没有猫(0 或 1),用
nn.BCEWithLogitsLoss
,因为它是二分类任务。
(3) 公式“看起来一样”的原因
在二分类情况下,nn.CrossEntropyLoss
和 nn.BCEWithLogitsLoss
的数学形式确实可以等价:
- 假设有两个类别(0 和 1),
nn.CrossEntropyLoss
的输入是[logit_0, logit_1]
,经过 Softmax 后:- ( y ^ 1 = e logit 1 e logit 0 + e logit 1 \hat{y}_1 = \frac{e^{\text{logit}_1}}{e^{\text{logit}_0} + e^{\text{logit}_1}} y^1=elogit0+elogit1elogit1 )
- ( y ^ 0 = 1 − y ^ 1 \hat{y}_0 = 1 - \hat{y}_1 y^0=1−y^1 )
- 损失为 ( − log ( y ^ correct ) -\log(\hat{y}_\text{correct}) −log(y^correct) )。
- 而
nn.BCEWithLogitsLoss
输入是单个 logit,经过 Sigmoid:- ( y ^ = 1 1 + e − logit \hat{y} = \frac{1}{1 + e^{-\text{logit}}} y^=1+e−logit1 )
- 损失为 ( − [ y log ( y ^ ) + ( 1 − y ) log ( 1 − y ^ ) ] -[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})] −[ylog(y^)+(1−y)log(1−y^)] )。
当类别数为 2 时,Softmax 和 Sigmoid 的结果可以互相转换,数学上损失值一致。但 PyTorch 实现时,nn.CrossEntropyLoss
处理多维 logits,nn.BCEWithLogitsLoss
针对标量二分类,接口和语义不同。
3. 代码对比:直观感受区别
来看一个简单的二分类例子:
import torch
import torch.nn as nn
# 数据
logits = torch.tensor([[1.0, 2.0], [0.5, -0.5]]) # [batch_size, 2]
target = torch.tensor([1, 0]) # 类别索引
# nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())
# nn.BCEWithLogitsLoss
bce_logits_fn = nn.BCEWithLogitsLoss()
logits_for_bce = logits[:, 1] - logits[:, 0] # 转换为单值 logit
target_for_bce = target.float()
bce_loss = bce_logits_fn(logits_for_bce, target_for_bce)
print("BCEWithLogitsLoss:", bce_loss.item())
nn.CrossEntropyLoss
直接用[logit_0, logit_1]
,目标是索引。nn.BCEWithLogitsLoss
需要把 logits 转换为单值(比如 ( logit 1 − logit 0 \text{logit}_1 - \text{logit}_0 logit1−logit0 )),目标是 0 或 1。
输出值可能略有不同,但数学上它们在二分类时等价。
4. 为什么分开设计?
既然二分类是多分类的特例,为什么 PyTorch 不统一用 nn.CrossEntropyLoss
?
- 语义清晰:二分类和多分类的任务需求不同,分开设计更直观。
- 计算效率:二分类用 Sigmoid 比 Softmax 更简单,节省计算量。
- 使用习惯:二分类任务常输出概率,
nn.BCELoss
符合传统机器学习习惯。
5. 小结:公式相同,设计不同
- 数学上:
nn.CrossEntropyLoss
和nn.BCELoss
的公式在二分类时一致,因为它们都源于交叉熵。 - 实现上:
nn.CrossEntropyLoss
针对多分类,处理 logits + Softmax。nn.BCELoss
和nn.BCEWithLogitsLoss
针对二分类,分别处理概率和 logits。
- 选择建议:
- 多分类(( C > 2 )):用
nn.CrossEntropyLoss
。 - 二分类:用
nn.BCEWithLogitsLoss
(更稳定)或nn.BCELoss
(手动 Sigmoid)。
- 多分类(( C > 2 )):用
后记
2025年2月28日17点05分于上海,在grok3 大模型辅助下完成。