当前位置: 首页 > article >正文

深度学习|表示学习|作为损失函数的交叉熵|04

如是我闻: Cross Entropy(交叉熵)是一种用于衡量两个概率分布之间差异的度量方式,被广泛的应用,但是总是记不住。

在这里插入图片描述


背后的理论

交叉熵源于信息论,它是描述两个概率分布之间相似度的重要工具。假设有两个概率分布:

  • p ( x ) p(x) p(x) : 真实分布(通常是one-hot编码的标签)
  • q ( x ) q(x) q(x) : 预测分布(通常是模型的输出,经过Softmax)

交叉熵衡量的是使用预测分布 q ( x ) q(x) q(x) 去表达真实分布 p ( x ) p(x) p(x) 所需的信息量(或不确定性)。

交叉熵的公式为:
H ( p , q ) = − ∑ x p ( x ) log ⁡ q ( x ) H(p, q) = - \sum_{x} p(x) \log q(x) H(p,q)=xp(x)logq(x)

在分类任务中,通常 p ( x ) p(x) p(x) 是one-hot编码的概率分布,例如对于分类为类别 i i i,只有 p ( x i ) = 1 p(x_i) = 1 p(xi)=1,其余为0。此时,交叉熵简化为:

H ( p , q ) = − log ⁡ q ( x i ) H(p, q) = - \log q(x_i) H(p,q)=logq(xi)

其中 q ( x i ) q(x_i) q(xi) 是模型预测类别 i i i 的概率。


在机器学习中的应用

在神经网络中,交叉熵通常作为分类问题的损失函数。例如:

  1. 多分类任务(Softmax 输出):模型的输出是类别的概率分布。

    • 损失函数为:
      L = − 1 N ∑ i = 1 N ∑ j = 1 C y i j log ⁡ ( y ^ i j ) L = - \frac{1}{N} \sum_{i=1}^N \sum_{j=1}^C y_{ij} \log(\hat{y}_{ij}) L=N1i=1Nj=1Cyijlog(y^ij)
      其中:
      • N N N: 样本数量
      • C C C: 类别数量
      • y i j y_{ij} yij: 第 i i i 个样本的真实标签(one-hot编码)
      • y ^ i j \hat{y}_{ij} y^ij: 第 i i i 个样本预测为类别 j j j 的概率
  2. 二分类任务(Sigmoid 输出):用于只有两个类别的场景。

    • 损失函数为:
      L = − 1 N ∑ i = 1 N ( y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ) L = - \frac{1}{N} \sum_{i=1}^N \left( y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right) L=N1i=1N(yilog(y^i)+(1yi)log(1y^i))
      其中:
      • y i y_i yi: 第 i i i 个样本的真实标签(0 或 1)
      • y ^ i \hat{y}_i y^i: 第 i i i个样本的预测概率

为什么使用 Cross Entropy?

  1. 惩罚错误预测

    • 如果预测结果 q ( x i ) q(x_i) q(xi) 接近真实值 p ( x i ) p(x_i) p(xi),交叉熵值很低。
    • 如果预测值远离真实值,交叉熵值较高,模型会被更强烈地惩罚。
  2. 理论基础

    • 交叉熵直接源自最大似然估计(MLE)。通过最小化交叉熵,实际上是在最大化模型对训练数据的似然。
  3. 概率解释

    • 交叉熵可以理解为用预测分布 q ( x ) q(x) q(x) 表达真实分布 p ( x ) p(x) p(x) 的代价,反映了预测分布的准确性。

举例说明

假设我们有 3 类分类任务,真实标签为 [ 1 , 0 , 0 ] [1, 0, 0] [1,0,0](表示类别 1),模型输出的概率分布为 [ 0.7 , 0.2 , 0.1 ] [0.7, 0.2, 0.1] [0.7,0.2,0.1]

交叉熵计算:
H ( p , q ) = − ( 1 ⋅ log ⁡ ( 0.7 ) + 0 ⋅ log ⁡ ( 0.2 ) + 0 ⋅ log ⁡ ( 0.1 ) ) H(p, q) = - \left( 1 \cdot \log(0.7) + 0 \cdot \log(0.2) + 0 \cdot \log(0.1) \right) H(p,q)=(1log(0.7)+0log(0.2)+0log(0.1))
H ( p , q ) = − log ⁡ ( 0.7 ) ≈ 0.356 H(p, q) = - \log(0.7) \approx 0.356 H(p,q)=log(0.7)0.356

如果模型输出错误的概率分布 [ 0.1 , 0.3 , 0.6 ] [0.1, 0.3, 0.6] [0.1,0.3,0.6],交叉熵将变大:
H ( p , q ) = − log ⁡ ( 0.1 ) ≈ 2.302 H(p, q) = - \log(0.1) \approx 2.302 H(p,q)=log(0.1)2.302

从中可以看出:交叉熵随着预测接近真实分布而减小。


以上


http://www.kler.cn/a/505451.html

相关文章:

  • (三)html2canvas将HTML 转为图片并实现下载
  • VM(虚拟机)和Linux的安装
  • 蓝桥杯备赛:顺序表和单链表相关算法题详解(上)
  • 爬虫请求失败时如何处理?
  • JVM之垃圾回收器ZGC概述以及垃圾回收器总结的详细解析
  • 使用docker-compose安装Redis的主从+哨兵模式
  • 单片机存储器和C程序编译过程
  • vue3封装el-tour漫游式引导
  • 09.VSCODE:安装 Git for Windows
  • .NetCore 使用 NPOI 读取带有图片的excel数据
  • 软件测试 —— Selenium(等待)
  • 物联网云平台:智能硬件芯片 esp32 的开放式管理设计
  • 【Elasticsearch复合查询】
  • 基于spingboot+html技术的博客网站
  • 1.1.1 C语言常用的一些函数(持续更新)
  • 最好用的图文识别OCR -- PaddleOCR(4) 模型微调
  • 【JVM中的三色标记法是什么?】
  • ssh, git 配置多对公私钥
  • 简识MySQL中ReadView、RC、RR的关系
  • 二级缓存(缓存到Redis)
  • Electron 开发者的 Tauri 2.0 实战指南:文件系统操作
  • LeetCode热题100(三十四) —— 23.合并K个升序链表
  • git报错处理
  • linux服务器 常用脚本(超全)
  • SpringBoot项目中解决CORS跨域资源共享问题
  • 比较分析:Windsurf、Cody、Cline、Roo Cline、Copilot 和 通义灵码