pytorch交叉熵损失函数
nn.CrossEntropyLoss
是 PyTorch 中非常常用的损失函数,特别适用于分类任务。它结合了 nn.LogSoftmax
和 nn.NLLLoss
(负对数似然损失)的功能,可以直接处理未经过 softmax 的 logits 输出,计算预测值与真实标签之间的交叉熵损失。
1. 交叉熵损失的原理
交叉熵损失衡量的是两个概率分布之间的差异。在分类任务中,模型输出的 logits 通过 softmax 转换成概率分布,然后与真实标签的概率分布进行比较。交叉熵损失会鼓励模型输出的概率分布尽可能接近真实标签的概率分布。
对于一个类别标签 y
,预测概率 p(y)
,交叉熵损失定义为:
对于一个多分类任务,如果真实标签是 y
,预测的 logits 是 z_i
,则交叉熵损失计算为:
其中 z_y
是模型输出的与真实类别对应的 logit 值,分母是所有类别的 logits 的指数和。