当前位置: 首页 > article >正文

【NLP 50、损失函数 KL散度】

目录

一、定义与公式

1.核心定义

2.数学公式

3.KL散度与交叉熵的关系

二、使用场景

1.生成模型与变分推断

2.知识蒸馏

3.模型评估与优化

4.信息论与编码优化

三、原理与特性

1.信息论视角

​2.优化目标

3.​局限性

四、代码示例

代码运行流程

核心代码解析


抵达梦想靠的不是狂热的想象,而是谦卑的务实,甚至你自己都看不起的可怜的隐忍

                                                                                                                        —— 25.3.27

一、定义与公式

1.核心定义

        KL散度(相对熵)是衡量两个概率分布 P 和 Q 之间差异的非对称性指标。它量化了当用分布 Q 近似真实分布 P 时的信息损失

非对称性:,即P和Q的顺序不能交换

非负性:,当且仅当P = Q时取等号

2.数学公式

离散形式:

连续形式:

其中,P是真实分布,Q是近似分布

3.KL散度与交叉熵的关系

KL散度可以分解为交叉熵H(P,Q与P的熵H(P):

交叉熵常用于分类任务,而KL散度更关注分布间的信息差异


二、使用场景

1.生成模型与变分推断

变分自编码器(VAE)​:通过最小化,使编码器输出的隐变量分布Q(z|x)逼近先验分布P(z)

生成对抗网络(GAN)​:辅助衡量生成分布与真实分布的差异

2.知识蒸馏

        将复杂教师模型的输出概率(软标签)作为监督信号,指导学生模型学习,损失函数中常包含KL散度项

3.模型评估与优化

多模态分布对齐:在推荐系统中对齐用户行为分布与模型预测分布​

异常检测:通过KL散度衡量测试数据分布与正常数据分布的偏离程度

4.信息论与编码优化

最小化编码长度:KL散度表示用 Q 编码 P 时所需的额外比特数


三、原理与特性

1.信息论视角

信息增益:KL散度表示从 Q 中获取 P 的信息时需要增加的“惊讶度”(Surprisal)。

凸性:KL散度是凸函数,可通过梯度下降法优化。

​2.优化目标

前向KL散度 DKL​(P∥Q):要求 Q 覆盖 P 的主要模式,避免 Q 的“零概率陷阱”(即 Q(x)=0 但 P(x)>0 会导致无穷大)​反向KL散度 DKL​(Q∥P):鼓励 Q 聚焦于 P 的单一主峰,适用于稀疏分布近似。

3.​局限性

非对称性:需根据任务选择方向(如VAE使用前向KL,部分GAN变体使用反向KL)

数值稳定性:需避免 Q(x)=0 或极端概率值,可通过平滑或温度参数(Temperature Scaling)调整。


四、代码示例

代码运行流程

KL散度计算流程
├── 1. 输入预处理
│   ├── a. 获取学生/教师模型原始输出
│   │   ├─ student_logits: 形状(batch=32, classes=10)
│   │   └─ teacher_logits: 同左[1,3](@ref)
│   └── b. 温度参数初始化
│       └─ temperature=5.0 (默认值)
├── 2. 概率变换
│   ├── a. 温度缩放
│   │   ├─ student_logits → student_logits / 5.0
│   │   └─ teacher_logits → teacher_logits / 5.0
│   ├── b. 概率归一化
│   │   ├─ student_probs = log_softmax(...)  # 对数空间
│   │   └─ teacher_probs = softmax(...)      # 线性空间
├── 3. 损失计算
│   ├── a. 初始化KLDivLoss
│   │   └─ reduction='batchmean' (符合数学期望)
│   ├── b. 执行KL散度计算
│   │   └─ KL(student_probs || teacher_probs)
│   └── c. 梯度补偿
│       └─ 乘以temperature²=25 恢复梯度幅值
└── 4. 结果输出
    └── 打印损失值 (标量Tensor转float)

student_logits:学生模型的原始输出(未归一化),形状为 (batch_size, num_classes),表示每个样本的预测得分

teacher_logits:教师模型的原始输出(未归一化),作为知识蒸馏的监督信号,形状同student_logits

temperature:温度缩放参数,软化概率分布(值越大分布越平滑,值越小越接近原始分布)

student_probs:学生模型经温度缩放后的对数概率

teacher_probs:教师模型经温度缩放后的概率

loss:KL散度损失的计算结果,表示学生模型输出分布与教师模型输出分布之间的差异程度。该值是一个标量(Scalar),用于指导反向传播优化学生模型的参数

batch_size:表示 ​单次输入模型的样本数量,即一次前向传播和反向传播处理32个样本。

nums_classes: 表示 ​分类任务的类别总数,即模型需区分的不同标签种类数。

F.log_softmax():将输入张量通过Softmax函数归一化为概率分布后,再对每个元素取自然对数,常用于分类任务的损失计算(如交叉熵损失)。

参数名类型说明默认值
​**input**Tensor输入张量必填
​**dim**int指定归一化的维度(如dim=1表示按行计算)必填

F.softmax():将输入张量通过指数函数归一化为概率分布,输出值范围为(0,1)且和为1。

参数名类型说明默认值
​**input**Tensor输入张量必填
​**dim**int归一化维度(如dim=0按列归一化)必填

nn.KLDivLoss():计算两个概率分布之间的Kullback-Leibler散度(KL散度),用于衡量分布差异。

参数名类型说明可选值默认值
​**reduction**str损失聚合方式'none''mean''sum''batchmean''mean'

torch.randn():生成服从标准正态分布(均值为0,标准差为1)的随机数张量,常用于初始化权重或生成噪声数据。

参数名类型说明默认值
​***size**int或tuple张量形状(如(3,4)生成3行4列矩阵)必填
​**dtype**torch.dtype数据类型(如torch.float32None(自动推断)
​**device**torch.device设备(如'cuda'CPU
​**requires_grad**bool是否需要梯度跟踪False

item():PyTorch中torch.Tensor类的方法,用于从单元素张量中提取Python标量值(如intfloat等)

核心代码解析

loss = nn.KLDivLoss(reduction='batchmean')(student_probs, teacher_probs) * (temperature ​** 2)

nn.KLDivLoss(reduction='batchmean')计算学生模型输出 (student_probs) 与教师模型输出 (teacher_probs) 之间的 ​KL散度,衡量两者的概率分布差异

        ​参数 reduction='batchmean'将每个样本的KL散度求和后除以批量大小 (batch_size),确保损失值符合KL散度的数学定义

   mean对所有元素取平均(总和除以元素总数)。

   sum直接求和。

   none保留每个样本的独立损失值。

(student_probs, teacher_probs):输入参数student_probs 和 teacher_probs

* (temperature ​** 2):温度缩放与梯度补偿

        温度的作用:软化概率分布:高温值会使教师模型的概率分布更平滑,避免过度关注高置信度类别

        ​为何乘以 temperature²① 梯度补偿:温度缩放会缩小梯度的幅值,乘以 temperature² 可恢复原始梯度量级,确保优化方向正确​ ② 数学推导:KL散度计算中,温度参数会引入缩放因子 T1​,反向传播时梯度需乘以 T2 以抵消缩放效应。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义KL散度损失函数(带温度参数)
def kl_div_loss_with_temperature(student_logits, teacher_logits, temperature=5.0):
    # 对logits应用温度缩放
    student_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # 计算KL散度
    loss = nn.KLDivLoss(reduction='batchmean')(student_probs, teacher_probs) * (temperature ​** 2)
    return loss

# 模拟输入数据
batch_size, num_classes = 32, 10
student_logits = torch.randn(batch_size, num_classes)  # 学生模型输出(未归一化)
teacher_logits = torch.randn(batch_size, num_classes)  # 教师模型输出(未归一化)

# 计算损失
loss = kl_div_loss_with_temperature(student_logits, teacher_logits)
print(f"KL散度损失: {loss.item()}")


http://www.kler.cn/a/613546.html

相关文章:

  • Unity 简单使用Addressables加载SpriteAtlas图集资源
  • java使用aspose添加多个图片到word
  • 3.27-1 pymysql下载及使用
  • Stable Diffusion 基础模型结构超级详解!
  • 用 pytorch 从零开始创建大语言模型(七):根据指示进行微调
  • TextGrad:案例
  • 横扫SQL面试——事件流处理(峰值统计)问题
  • SDL —— 将sdl渲染画面嵌入Qt窗口显示(附:源码)
  • CSS回顾-Flex弹性盒布局
  • Vue $bus被多次触发
  • 【WPF】ListView数据绑定
  • 【AI工具开发】Notepad++插件开发实践:从基础交互到ScintillaCall集成
  • C语言之链表
  • 分布式光伏防逆流如何实现?
  • 每日免费分享之精品wordpress主题系列~DAY16
  • 云原生四重涅槃·破镜篇:混沌工程证道心,九阳真火锻金身
  • 可视化图解算法:递归基础
  • Pyside6介绍和开发第一个程序
  • GPT4o漫画制作(小白教程)
  • 后端开发中的文件上传的实现