当前位置: 首页 > article >正文

PyTorch 的 nn.NLLLoss:负对数似然损失全解析

PyTorch 的 nn.NLLLoss:负对数似然损失全解析

在 PyTorch 的损失函数家族中,nn.NLLLoss(Negative Log Likelihood Loss,负对数似然损失)是一个不太起眼但非常重要的成员。它经常跟 LogSoftmax 搭配出现,尤其在分类任务中扮演关键角色。今天我们就来聊聊 nn.NLLLoss 的数学原理、使用方法,以及它适用的场景,带你彻底搞懂这个损失函数。

1. 什么是负对数似然损失?

先从名字拆解:

  • 似然(Likelihood):在统计学中,似然表示“给定模型参数时,观察到数据的概率”。对数似然(Log Likelihood)是它的对数形式,常用于简化计算。
  • 负对数似然(Negative Log Likelihood, NLL):把对数似然取负数,作为损失函数,目标是最小化它。

在机器学习中,负对数似然损失通常用来衡量模型预测的概率分布与真实标签的差距,尤其是在分类任务中。

数学公式

假设我们有一个多分类任务,有 ( C C C ) 个类别。对于一个样本:

  • ( y ^ \hat{y} y^ ) 是模型输出的概率分布,比如经过 Softmax 或 LogSoftmax 处理后的结果。
  • ( y y y ) 是真实类别,用索引表示(比如 2 表示第 2 类)。

nn.NLLLoss 的公式是:

NLL = − 1 N ∑ i = 1 N log ⁡ ( y ^ i , y i ) \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log(\hat{y}_{i, y_i}) NLL=N1i=1Nlog(y^i,yi)

  • ( N N N ):样本数量(batch size)。
  • ( y i y_i yi ):第 ( i i i ) 个样本的真实类别索引。
  • ( y ^ i , y i \hat{y}_{i, y_i} y^i,yi ):第 ( i i i ) 个样本在真实类别 ( y i y_i yi ) 上的预测概率(对数值)。

简单来说,nn.NLLLoss 取预测概率的对数(已经由 LogSoftmax 计算好),然后取负号,只关心正确类别的概率值。

2. 为什么搭配 LogSoftmax

你可能会注意到,nn.NLLLoss 的文档里总是提到“通常与 LogSoftmax 搭配使用”。这是为什么?

  • 模型输出:神经网络的最后一层通常输出未归一化的 logits(比如 [1.0, 2.0, 0.5]),而不是概率。
  • Softmax:将 logits 转为概率分布,比如 [0.2, 0.5, 0.3],满足 ( ∑ y ^ = 1 \sum \hat{y} = 1 y^=1)。公式是:
    y ^ j = e z j ∑ k = 1 C e z k \hat{y}_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} y^j=k=1Cezkezj
  • LogSoftmax:在 Softmax 基础上取对数,输出的是对数概率,比如 [-1.6, -0.7, -1.2]。公式是:
    log ⁡ ( y ^ j ) = z j − log ⁡ ( ∑ k = 1 C e z k ) \log(\hat{y}_j) = z_j - \log(\sum_{k=1}^{C} e^{z_k}) log(y^j)=zjlog(k=1Cezk)

nn.NLLLoss 要求输入是对数概率(log probabilities),而不是原始概率。所以:

  • 如果你直接给它 Softmax 后的概率,会出错,因为它期待的是 ( log ⁡ ( y ^ ) \log(\hat{y}) log(y^))。
  • LogSoftmax 处理后,输入正好符合要求,计算时直接取负号即可。
3. 代码使用示例

我们来看一个简单的例子,展示 nn.NLLLossLogSoftmax 的搭配:

import torch
import torch.nn as nn

# 假设一个 3 分类任务,batch_size = 2
logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]])  # 原始 logits
target = torch.tensor([1, 2])  # 真实类别索引,0~2

# 定义 LogSoftmax 和 NLLLoss
log_softmax = nn.LogSoftmax(dim=1)  # dim=1 表示在类别维度上归一化
loss_fn = nn.NLLLoss()

# 计算损失
log_probs = log_softmax(logits)  # 先转为对数概率
loss = loss_fn(log_probs, target)
print("NLL Loss:", loss.item())

运行过程

  1. logits[batch_size, num_classes] 的张量,表示每个样本在每个类别上的得分。
  2. nn.LogSoftmax 把 logits 转为对数概率,比如 [[-1.9, -0.9, -2.4], [-2.3, -1.9, -0.4]]
  3. nn.NLLLoss 提取每个样本在真实类别上的对数概率(比如第一个样本取 -0.9,第二个取 -0.4),取负并平均。

输出可能是 1.15,具体值取决于输入。

4. 与 nn.CrossEntropyLoss 的关系

你可能听说过 nn.CrossEntropyLoss,它也很常见。实际上:

  • nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss
    PyTorch 把这两步合二为一,直接接受 logits 作为输入,内部自动完成 LogSoftmax 和 NLL 计算。具体过程可以参考笔者的另一篇博客:Pytorch为什么 nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss?

代码对比:

# 用 nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())  # 与上面结果相同
  • 区别
    • nn.NLLLoss:输入是对数概率,需手动加 LogSoftmax
    • nn.CrossEntropyLoss:输入是 logits,自动处理。
5. 使用场景

nn.NLLLoss 适用于以下场景:

  • 多分类任务:比如图像分类(CIFAR-10 的 10 类)、文本分类。
  • 需要分离 Softmax 的情况
    • 你想在模型里显式控制 LogSoftmax 的位置,而不是交给损失函数。
    • 调试时单独检查对数概率的值。
  • 概率输出的模型:如果你的模型已经输出对数概率(比如某些预训练模型),直接用 nn.NLLLoss 更高效。

典型例子

  • 一个简单的 CNN 分类器:
    class SimpleCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 16, 3)
            self.fc = nn.Linear(16 * 26 * 26, 10)  # 假设 28x28 输入
            self.log_softmax = nn.LogSoftmax(dim=1)
    
        def forward(self, x):
            x = self.conv(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return self.log_softmax(x)
    
    model = SimpleCNN()
    loss_fn = nn.NLLLoss()
    
    这里模型输出对数概率,搭配 nn.NLLLoss 计算损失。
6. 注意事项
  • 输入形状
    • 输入:[batch_size, num_classes](对数概率)。
    • 目标:[batch_size](类别索引)。
  • 目标类型:必须是整数(long 类型),不能是 one-hot 或浮点数。
  • 数值稳定性LogSoftmax 比单独的 Softmax + log 更稳定,因为它避免了溢出问题。
7. 小结:nn.NLLLoss 的核心
  • 数学原理:计算正确类别对数概率的负值,最小化它等价于最大化似然。
  • 使用方式:搭配 LogSoftmax,输入对数概率,输出标量损失。
  • 场景:多分类任务,尤其是需要显式控制概率计算时。
  • CrossEntropyLoss 的关系:前者是后者的组成部分,功能更模块化。

nn.NLLLoss 就像一个“半成品”,需要你自己搭配 LogSoftmax,但这也给了你更多灵活性。相比直接用 nn.CrossEntropyLoss,它更适合喜欢拆解步骤或调试模型的开发者。

8. 调试小技巧
  • 检查输入:打印 log_probs 确保是对数概率(负值)。
  • 验证目标:确保 target 是整数,且范围在 [0, num_classes-1]
  • 对比结果:用 nn.CrossEntropyLoss 验证是否一致。

希望这篇博客让你对 nn.NLLLoss 有了全面认识!

后记

2025年2月28日18点59分于上海,在Grok3大模型辅助下完成。


http://www.kler.cn/a/572417.html

相关文章:

  • electron.vite + better-sqlite3 + serialport 完整使用教程
  • Qt/C++音视频开发-检查是否含有B帧/转码推流/拉流显示/监控拉流推流/海康大华宇视监控
  • 基于Python的新闻采集与分析:新闻平台的全面数据采集实践
  • 爬虫技术结合淘宝商品快递费用API接口(item_fee):电商物流数据的高效获取与应用
  • 用DeepSeek-R1-Distill-data-110k蒸馏中文数据集 微调Qwen2.5-7B-Instruct!
  • 【leetcode】实现Tire(前缀树)
  • FastGPT 源码:基于 LLM 实现 Rerank (含Prompt)
  • android_viewtracker 原理
  • 【cuda学习日记】5.4 常量内存
  • leetcode383 赎金信
  • 【详解 | 辨析】“单跳多跳,单天线多天线,单信道多信道” 之间的对比
  • Git-cherry pick
  • 迷你世界脚本世界UI接口:UI
  • c++面试常见问题:虚表指针存在于内存哪个分区
  • Node.js学习分享(上)
  • python爬虫数据库概述
  • 【Java】IO流
  • Linux·数据库INSERT优化
  • PyTorch 与 NVIDIA GPU 的适配版本及安装
  • NO.23十六届蓝桥杯备战|二维数组|创建|初始化|遍历|memset(C++)