当前位置: 首页 > article >正文

深度学习-交叉熵损失函数

        交叉熵损失函数(Cross-Entropy Loss)是机器学习和深度学习中常用的一种损失函数,特别适用于分类问题。它用于评估模型的预测概率分布和真实分布之间的差异,能够衡量模型预测的准确性。一般来说,交叉熵损失越小,模型的预测结果就越接近真实标签。

1. 什么是交叉熵(Cross-Entropy)

交叉熵来源于信息论,用来衡量两个概率分布之间的差异。若已知一个真实分布 P和一个预测分布 Q,交叉熵可以被定义为:

其中:

  • P(i)P是真实分布中第 i 类的概率;
  • Q(i)是预测分布中第 i 类的概率。

交叉熵的值越小,说明 Q越接近 P,即模型的预测分布与真实分布越一致;反之,交叉熵的值越大,则表明预测分布与真实分布差异越大。

2. 交叉熵损失函数在机器学习中的应用

在分类问题中,我们通常知道每个样本属于某一个类别,可以将其表示成一个概率分布。例如,对于一个二分类问题,标签可以是 [1, 0](第一个类别)或 [0, 1](第二个类别),这表示样本属于某一类的“确定”概率。我们希望模型的预测结果也接近于这种“确定”的分布。

交叉熵损失函数将模型的输出概率分布与真实标签的概率分布做比较,进而度量模型预测的误差。此损失函数在分类任务中特别常用,尤其是多分类任务。

3. 交叉熵损失函数的公式推导

3.1 多分类任务中的交叉熵损失

对于多分类任务,假设有 C 个类别,真实标签为 y,模型的预测输出为 ,交叉熵损失函数定义为:

其中:

  • yi是真实标签第 i 类的值,通常用独热编码表示,即只有真实类别对应的 yi 等于 1,其余为 0。
  • i是模型输出的第 i 类的概率。

该公式的意义在于,仅考虑实际类别的预测概率;非真实类别的预测不会影响损失值。

3.2 二分类任务中的交叉熵损失

对于二分类任务,交叉熵损失公式可以简化。设样本标签 ,模型预测的概率为 ,则交叉熵损失为:

其中:

  • 当 y=1时,损失函数变为
  • 当 y=0时,损失函数变为

这两个公式描述了在二分类情况下预测值与真实标签的偏差。

4. 为什么交叉熵损失函数有效?

4.1 强制模型做出明确的分类

交叉熵损失鼓励模型输出更“确定”的分类,即倾向于给真实类别分配更高的概率,从而在训练过程中优化模型的分类效果。

4.2 梯度良好,利于优化

交叉熵损失函数在实际使用中配合 softmax 层,softmax 输出的概率具有可微性,这样可以通过反向传播有效地进行梯度计算,有助于提升训练的效率和模型的收敛性。

4.3 与最大似然估计的关系

交叉熵损失函数等价于最大似然估计(MLE)的负对数似然,最大化预测结果的概率也就意味着最小化交叉熵损失。因此,交叉熵损失也是在模型训练中优化参数的一种直观方法。

5. 交叉熵损失的优缺点

优点
  1. 适合分类问题:特别适用于分类任务,能够很好地衡量模型的输出概率分布与真实分布的相似程度。
  2. 有效的梯度传播:在反向传播中,交叉熵损失的导数相对稳定,有助于优化的收敛。
  3. 概率解释:通过 softmax 层输出概率分布,能够直观理解模型对不同类别的预测信心。
缺点
  1. 对异常点敏感:交叉熵损失会对错误的高置信预测(如模型对错类别预测出高概率)产生高损失,导致对异常样本过于敏感。
  2. 数值不稳定:在计算 时,当 很接近 0 时会导致数值不稳定问题。因此,常会加上一个小的平滑项避免计算误差。

6. 实际应用中的一些注意事项

  • 平滑技巧:在某些情况下,使用标签平滑(Label Smoothing)可以防止模型过拟合。例如,将真实标签从 1 和 0 平滑为 0.9 和 0.1,这在实际应用中可以提升模型的泛化能力。
  • softmax+交叉熵:交叉熵损失函数通常与 softmax 函数配合使用,因为 softmax 输出的概率分布更符合交叉熵的定义。
  • 类别不平衡问题:在类别不平衡的任务中,交叉熵损失可能会忽略小类别的损失,因此常采用加权交叉熵(Weighted Cross-Entropy)或调和采样方法。

7. 总结

交叉熵损失函数在分类问题中得到了广泛应用。通过衡量模型预测概率与真实分布的差异,交叉熵损失帮助模型逐步调整参数,最终实现最优的分类效果。


ps:torch.nn.CrossEntropyLoss内置了Softmax运算,而Softmax的作用是计算分类结果中最大的类,因此在使用torch.nn.CrossEntropyLoss作为损失函数时,不需要在网络层的最后添加Softmax层。


http://www.kler.cn/a/371927.html

相关文章:

  • 毕业项目推荐:基于yolov8/yolov5的行人检测识别系统(python+卷积神经网络)
  • STM32-BKP备份寄存器RTC实时时钟
  • RedisTemplate执行lua脚本及Lua 脚本语言详解
  • SpringMVC的消息转换器
  • 25/1/6 算法笔记<强化学习> 初玩V-REP
  • 现代前端框架
  • Django ORM 数据库管理 提高查询、更新性能的技巧和编程习惯:
  • ECharts 折线图 / 柱状图 ,通用配置标注示例
  • OpenCV基本操作(python开发)——(8)实现芯片瑕疵检测
  • 【GPT模型的大小】GPT3模型到底多大,如果训练需要什么条件?
  • 盘古信息IMS系统助力制造企业释放新质生产力
  • 上市公司数字经济与实体经济融合发展程度测算数据(2008-2022年)-最新出炉_附下载链接
  • 基于华为atlas环境下的OpenPose人体关键点检测的人员跨越、坐立检测
  • Mybatis-15.动态SQL-if
  • 【Hadoop之hdfs】hdfs一些简单明了的总结(一篇足以,字少但都是精华)
  • pytest 单元框架里,前置条件
  • MySQL数据集成至金蝶云星空的解决方案
  • 【Fastjson反序列化漏洞:深入了解与防范】
  • 类加载机制123
  • HTML入门教程9:HTML引用
  • java 大集合切分成一个集合中有多个小集合
  • Java程序设计基础 第十七章:反射和设计模式
  • 大话PM | 从项目管理软件看项目管理的三个原则两个思维两个工具
  • 深入 Prometheus 监控生态 - 第五篇:利用 API 信息进行监控(NAS 备份任务监控 + 解决思路)
  • 【约束优化】一次搞定拉格朗日,对偶问题,弱对偶定理,Slater条件和KKT条件
  • 画思维导图的app有哪些?5个软件让你轻松画思维导图不求人