
PyTorch torch.logsumexp Explained: Mathematical Principles, Use Cases, and Performance Optimization (Bilingual Chinese-English)

In deep learning and probabilistic models, we often need numerically stable log-probability operations, particularly for softmax normalization, log-likelihood computation, and loss function optimization, where summing exponentials and then taking the logarithm can overflow. torch.logsumexp was designed to solve exactly this problem.

In this article, we cover:

  • The mathematical principle behind torch.logsumexp
  • Its practical uses
  • Why it is more stable than computing log(sum(exp(x))) directly
  • How to use torch.logsumexp efficiently in PyTorch code

1. What is torch.logsumexp?

1.1 Mathematical Formula

torch.logsumexp(x, dim) computes the following expression:

$$\log \sum_{i} e^{x_i}$$

where:

  • $x_i$ is an element of the input tensor,
  • dim specifies the dimension along which the reduction is performed.

1.2 Why Not Compute log(sum(exp(x))) Directly?

Suppose the input contains very large values, say x = 1000. Computing it naively:

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp)  # Output: inf (overflow)

Problem: exp(1000) is too large to represent in floating point, so the result overflows.

The torch.logsumexp solution:

$$\log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{(x_i - x_{\max})}$$

  • Core idea: first subtract the maximum value $x_{\max}$ (to keep the exponentials from blowing up), then take the log of the sum of exponentials.
  • This avoids overflow and improves numerical stability.

Using torch.logsumexp:

log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable)  # finite output, no overflow

It does not overflow, because the maximum is subtracted before exponentiation and the log is applied afterward.
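
For intuition, the trick can also be written out by hand. The following is a minimal sketch of the identity above (the real operator is implemented in optimized C++/CUDA kernels, so this is only for illustration):

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# Subtract the maximum before exponentiating so every exponent is <= 0
x_max = x.max()
manual = x_max + torch.log(torch.sum(torch.exp(x - x_max)))

print(manual)                     # tensor(1002.4076)
print(torch.logsumexp(x, dim=0))  # matches the built-in result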


2. Practical Applications of torch.logsumexp

2.1 Computing Softmax

The softmax function is defined as:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

Taking the log gives the log-softmax:

$$\log P(x_i) = x_i - \log \sum_{j} e^{x_j}$$

PyTorch code:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
print(log_softmax_x)

This avoids exponent overflow and is more stable than computing torch.log(torch.sum(torch.exp(x))) directly.
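
As a quick sanity check, the manual log-softmax should match PyTorch's built-in torch.nn.functional.log_softmax; the comparison below is a small sketch added for illustration:

import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0])

log_softmax_manual = x - torch.logsumexp(x, dim=0)
log_softmax_builtin = F.log_softmax(x, dim=0)

# The two results should agree up to floating-point tolerance
print(torch.allclose(log_softmax_manual, log_softmax_builtin))  # True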


2.2 Computing Cross-Entropy Loss

Cross-entropy is computed as:

$$L = - \sum_{i} y_i \log P(x_i)$$

where $P(x_i)$ is obtained from the softmax, and torch.logsumexp keeps the computation of the softmax denominator stable, as illustrated in the sketch below.
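
The following is a minimal sketch (with made-up logits and target) showing how the log-sum-exp form of the loss lines up with torch.nn.functional.cross_entropy:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])
target = torch.tensor([0])

# Cross-entropy via log-softmax: -(logit of the true class - logsumexp of all logits)
log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
loss_manual = -log_probs[0, target[0]]

print(loss_manual)                      # ~0.4170
print(F.cross_entropy(logits, target))  # should match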


2.3 Applications in Transformer Models

When training Transformer language models such as GPT and BERT, we often compute token_log_probs, the per-token log probabilities, as follows:

import torch

logits = torch.randn(4, 5)  # assume batch_size=4, vocab_size=5
input_ids = torch.tensor([1, 2, 3, 4])  # the ground-truth token ids

# Compute the log probability of each token
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_values

print(token_log_probs)

Here, torch.logsumexp(logits, dim=-1) computes the log of the softmax denominator, ensuring the probability computation does not overflow.
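
As a sanity check (a small sketch, not part of the original snippet), the same quantity can also be obtained by gathering from torch.nn.functional.log_softmax:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 5)
input_ids = torch.tensor([1, 2, 3, 4])

token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
token_log_probs = token_logits - torch.logsumexp(logits, dim=-1)

# Equivalent formulation via log_softmax
reference = F.log_softmax(logits, dim=-1).gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(token_log_probs, reference))  # True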


3. Performance Considerations for torch.logsumexp

3.1 Why is torch.logsumexp faster than log(sum(exp(x)))?

  • It avoids materializing exp(x): computing exp(x) and then sum() allocates an extra large intermediate tensor, while logsumexp performs the fused computation inside optimized C++/CUDA kernels.
  • It reduces numerical overflow: keeping intermediate values in a safe range avoids unnecessary floating-point blow-ups and helps prevent exploding gradients.

3.2 Measuring Performance

import time
import torch

x = torch.randn(1000000)

start = time.time()
torch.logsumexp(x, dim=0)
end = time.time()
print(f"torch.logsumexp: {end - start:.6f} s")

start = time.time()
torch.log(torch.sum(torch.exp(x)))
end = time.time()
print(f"log(sum(exp(x))): {end - start:.6f} s")

Results (example):

torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s

torch.logsumexp is faster and also avoids the overflow that exp(x) can cause.
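
Note that timing a single call with time.time() is a very coarse measurement, and the actual speedup depends on tensor size, dtype, and device. For more reliable numbers, one option (a sketch added here, not from the original benchmark) is torch.utils.benchmark:

import torch
import torch.utils.benchmark as benchmark

x = torch.randn(1_000_000)

t1 = benchmark.Timer(stmt="torch.logsumexp(x, dim=0)", globals={"x": x, "torch": torch})
t2 = benchmark.Timer(stmt="torch.log(torch.sum(torch.exp(x)))", globals={"x": x, "torch": torch})

# timeit runs each statement repeatedly and reports the mean per-call time
print(t1.timeit(100))
print(t2.timeit(100))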


4. Summary

  • torch.logsumexp(x, dim) computes log(sum(exp(x))) with a numerically stable method that prevents overflow.
  • Common applications:
    • Softmax computation
    • Cross-entropy loss
    • Token log-probability computation in language models
  • It is more stable and faster than a naive log(sum(exp(x))), making it well suited for large-scale deep learning workloads.

Recommendation:
🚀 Whenever a computation involves log(sum(exp(x))), prefer torch.logsumexp for much better numerical stability and efficiency! 🚀

Understanding torch.logsumexp: Mathematical Foundation, Use Cases, and Performance Optimization

In deep learning, especially in probability models, computing logarithmic probabilities in a numerically stable way is crucial. Directly applying log(sum(exp(x))) can lead to numerical instability due to floating-point overflow. torch.logsumexp is designed to solve this problem efficiently.

In this article, we will cover:

  • The mathematical foundation of torch.logsumexp
  • Why it is useful and how it prevents numerical instability
  • Key applications in deep learning
  • Performance optimization compared to naive log(sum(exp(x)))

1. What is torch.logsumexp?

1.1 Mathematical Formula

torch.logsumexp(x, dim) computes the following function:

$$\log \sum_{i} e^{x_i}$$

where:

  • $x_i$ represents the elements of the input tensor,
  • dim specifies the dimension along which to perform the operation.

1.2 Why Not Directly Compute log(sum(exp(x)))?

Consider an example where $x = [1000, 1001, 1002]$. If we naively compute:

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp)  # Output: inf (overflow)

Problem:

  • exp(1000) is too large, exceeding the floating-point limit, causing an overflow.

Solution: Log-Sum-Exp Trick
To prevent overflow, torch.logsumexp applies the following transformation:

$$\log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{(x_i - x_{\max})}$$

where $x_{\max}$ is the maximum value in $x$.

  • By subtracting $x_{\max}$ first, the exponentials are smaller and won’t overflow.

Example using torch.logsumexp:

log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable)  # Outputs a valid value without overflow

This is more numerically stable.
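
The same identity also applies along any dimension of a multi-dimensional tensor; the sketch below (with made-up shapes and values) reduces over the last dimension and uses keepdim=True so the result stays broadcastable:

import torch

scores = torch.randn(2, 3) * 500  # values large enough that a naive exp() can overflow

# Reduce over the last dimension; keepdim=True preserves a size-1 dimension
lse = torch.logsumexp(scores, dim=-1, keepdim=True)
print(lse.shape)  # torch.Size([2, 1])

# Equivalent manual computation using the max-subtraction trick
m = scores.max(dim=-1, keepdim=True).values
manual = m + torch.log(torch.exp(scores - m).sum(dim=-1, keepdim=True))
print(torch.allclose(lse, manual))  # True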


2. Key Applications of torch.logsumexp

2.1 Computing Softmax in Log Space

The Softmax function is defined as:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

Taking the log:

$$\log P(x_i) = x_i - \log \sum_{j} e^{x_j}$$

Using PyTorch:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
print(log_softmax_x)

This avoids computing exp(x), preventing numerical instability.


2.2 Cross-Entropy Loss Computation

Cross-entropy loss:

$$L = - \sum_{i} y_i \log P(x_i)$$

where $P(x_i)$ is computed using Softmax.
Using torch.logsumexp, we avoid overflow in the denominator:

logits = torch.tensor([[2.0, 1.0, 0.1]])
logsumexp_values = torch.logsumexp(logits, dim=-1)
print(logsumexp_values)

This technique is used in torch.nn.CrossEntropyLoss.
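
To make the connection concrete, the sketch below (with made-up logits and target) shows that the log-sum-exp formulation reproduces the result of torch.nn.CrossEntropyLoss:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 1.0, 0.1]])
target = torch.tensor([0])

# loss = logsumexp(logits) - logit of the true class
loss_manual = torch.logsumexp(logits, dim=-1) - logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)

criterion = nn.CrossEntropyLoss()
print(loss_manual.mean())         # ~0.4170
print(criterion(logits, target))  # should match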


2.3 Token Log Probabilities in Transformer Models

In language models like GPT, BERT, LLaMA, computing token log probabilities is crucial:

import torch

logits = torch.randn(4, 5)  # Simulated logits for 4 tokens, vocab size 5
input_ids = torch.tensor([1, 2, 3, 4])  # Token positions

# Gather the logits corresponding to the actual tokens
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

# Compute log probabilities
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_values

print(token_log_probs)

Here, torch.logsumexp ensures stable probability computation by handling large exponentiations.


3. Performance Optimization

3.1 Why is torch.logsumexp Faster?

Instead of:

torch.log(torch.sum(torch.exp(x)))

which:

  1. Computes exp(x), creating an intermediate tensor.
  2. Sums the tensor.
  3. Computes log(sum(exp(x))).

torch.logsumexp:

  • Avoids unnecessary tensor storage.
  • Optimizes computation at the C++/CUDA level.
  • Improves numerical stability.

3.2 Performance Benchmark

import time
import torch

x = torch.randn(1000000)

start = time.time()
torch.logsumexp(x, dim=0)
end = time.time()
print(f"torch.logsumexp: {end - start:.6f} s")

start = time.time()
torch.log(torch.sum(torch.exp(x)))
end = time.time()
print(f"log(sum(exp(x))): {end - start:.6f} s")

Results (example):

torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s

torch.logsumexp is significantly faster and more stable.


4. Summary

  • torch.logsumexp(x, dim) computes log(sum(exp(x))) safely, preventing overflow.
  • Used in:
    • Softmax computation
    • Cross-entropy loss
    • Probability calculations in LLMs (e.g., GPT, BERT)
  • More efficient than log(sum(exp(x))) due to internal optimizations.

🚀 Always prefer torch.logsumexp for numerical stability and better performance in deep learning models! 🚀

Postscript

Written in Shanghai at 19:06 on February 21, 2025, completed with the assistance of the GPT4o model.


