当前位置: 首页 > article >正文

PyTorch gather 方法详解:作用、应用场景与示例解析(中英双语)

PyTorch gather 方法详解:作用、应用场景与示例解析

在深度学习和自然语言处理(NLP)任务中,我们经常需要从高维张量中提取特定索引的数据
PyTorch 提供的 torch.gather 方法可以高效地从张量的指定维度收集数据,广泛应用于语言模型(Transformer)、分类任务、强化学习等场景

在本文中,我们将详细介绍:

  • gather 方法的作用
  • 使用 gather 进行索引操作
  • gather 在 NLP 模型中的应用
  • gather 的计算效率与优化

1. torch.gather 的作用

1.1 gather 的基本用法

gather 允许我们在张量的指定维度上,按照给定的索引提取数据
其基本语法如下:

torch.gather(input, dim, index)
  • input:输入张量,形状为 (B, L, V)(可以是任意维度)。
  • dim:指定在哪个维度上收集数据(例如 dim=-1 代表在最后一个维度索引)。
  • index:索引张量,形状必须与 inputdim 之外的维度相同

1.2 gather 的核心逻辑

给定 inputindexgather 沿 dim 维度逐元素地获取 input 中指定索引位置的值


2. gather 的基础示例

2.1 从二维张量中提取元素

import torch

# 定义一个 3x4 的张量
input_tensor = torch.tensor([[10, 20, 30, 40], 
                             [50, 60, 70, 80], 
                             [90, 100, 110, 120]])

# 定义索引张量
index_tensor = torch.tensor([[0, 1], 
                             [2, 3], 
                             [1, 2]])

# 在 `dim=1` 维度上使用 gather
output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

输出:

tensor([[ 10,  20],
        [ 70,  80],
        [100, 110]])

解释:

  • input_tensor 形状为 (3,4),即 3 行 4 列
  • index_tensor 形状为 (3,2),其中的值指示要从 input_tensordim=1(列) 选取的数据:
    • 第 1 行:取 input_tensor[0,0]input_tensor[0,1],即 [10, 20]
    • 第 2 行:取 input_tensor[1,2]input_tensor[1,3],即 [70, 80]
    • 第 3 行:取 input_tensor[2,1]input_tensor[2,2],即 [100, 110]

3. gather 在 NLP 中的应用

3.1 计算 Token 的对数概率

在语言模型(如 Transformer)中,我们通常需要计算目标 token 的概率,即:
P ( y t ) = e logit y t ∑ e logit v P(y_t) = \frac{e^{\text{logit}_{y_t}}}{\sum e^{\text{logit}_{v}}} P(yt)=elogitvelogityt

其中:

  • logits 形状为 (B, L, V),表示 batch 里每个 token 对整个词表(vocabulary)中所有词的 logit 分数。
  • input_ids 形状为 (B, L),表示实际的 token 索引(即每个 token 在词表中的 ID)。

我们使用 gather 取出每个 input_idlogits 中对应的 logit 分值:

import torch

# 假设 batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2], 
                         [0.1, -0.5, 2.2, 1.5, 0.0], 
                         [1.1, 3.5, 0.8, -0.2, -1.5]],

                        [[0.0, 2.3, -0.5, 1.0, 0.8], 
                         [-1.2, 1.7, 2.0, 0.3, -0.8], 
                         [2.5, -0.1, -1.2, 0.5, 3.0]]])

input_ids = torch.tensor([[0, 2, 1],  # 对应每个 token 在词表中的索引
                          [1, 3, 4]])

# 取出 input_ids 在 logits 中的 logit 值
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(token_logits)

输出:

tensor([[ 2.0000,  2.2000,  3.5000],
        [ 2.3000,  0.3000,  3.0000]])

解释:

  • logits.gather(dim=-1, index=input_ids.unsqueeze(-1))
    • dim=-1 代表从 Vocab 维度(最后一维)索引数据。
    • input_ids.unsqueeze(-1) 扩展维度,让 input_ids 形状变为 (B, L, 1),符合 gather 要求。
    • squeeze(-1) 还原到 (B, L) 形状,使结果是 每个 token 的 logit 值

4. gatherscatter 的对比

除了 gather(用于提取数据),PyTorch 还提供 scatter(用于写入数据)。

4.1 scatter 基本用法

import torch

# 初始化 3x3 零张量
x = torch.zeros(3, 3)

# 指定索引
index = torch.tensor([[0, 2], 
                      [1, 1], 
                      [2, 0]])

# 指定填充值
updates = torch.tensor([[5, 8], 
                         [3, 7], 
                         [6, 2]])

# 在 dim=1 维度上 scatter
x.scatter_(dim=1, index=index, src=updates)
print(x)

输出:

tensor([[5., 0., 8.],
        [0., 7., 0.],
        [2., 0., 6.]])

scatter_() 用于替换 index 位置的值,而 gather() 用于提取 index 位置的值


5. 总结

  • torch.gather(input, dim, index) 提取 input dim 维度上指定 index 位置的值
  • 常用于 NLP 任务中计算 token 对数概率、分类任务中提取预测分数
  • 通过 gather 提取 logits 对应 input_ids,可高效计算 对数概率损失函数
  • gather从索引获取数据,而 scatter根据索引写入数据

🚀 掌握 gather 让你在深度学习项目中更高效地处理索引操作! 🚀

深入理解 gather(dim=-1):作用、计算过程与 dim=-2 的对比

在 PyTorch 中,torch.gather 是一个强大的索引操作函数,它可以根据提供的 index 张量,从 input 张量的指定维度(dim)中提取相应的数据。
在 NLP(自然语言处理)任务中,我们常用 gather(dim=-1) 来从 logits 中获取 输入 token(input_ids)对应的 logits 值,用于计算损失或评估模型表现。


1. gather(dim=-1) 的作用

1.1 dim=-1 的含义

  • dim=-1 代表最后一个维度(即词表维度)。

  • logits.shape = (batch_size, sequence_length, vocab_size) 这样一个张量中:

    • dim=0:表示 batch 维度(不同样本)。
    • dim=1:表示 sequence 维度(句子中的不同 token)。
    • dim=2(即 dim=-1:表示词汇表(vocab),即每个 token 对所有单词的 logits 评分。
  • gather(dim=-1, index=input_ids.unsqueeze(-1)) 的作用:

    • dim=-1(词表维度)上提取 input_ids 对应的 logits 值
    • 这样,每个 token 只保留它对应的 logits,而不是整个词表的所有 logits。

2. 代码示例与计算过程

2.1 示例:计算 Token Logits

import torch

# 假设 batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([
    [[2.0, 1.0, 0.5, -1.0, 0.2], 
     [0.1, -0.5, 2.2, 1.5, 0.0], 
     [1.1, 3.5, 0.8, -0.2, -1.5]],

    [[0.0, 2.3, -0.5, 1.0, 0.8], 
     [-1.2, 1.7, 2.0, 0.3, -0.8], 
     [2.5, -0.1, -1.2, 0.5, 3.0]]
])

input_ids = torch.tensor([
    [0, 2, 1],  # 第一个样本的 token 索引
    [1, 3, 4]   # 第二个样本的 token 索引
])

# 扩展维度,使 input_ids 形状变为 (batch_size, sequence_length, 1)
expanded_index = input_ids.unsqueeze(-1)

# 使用 gather 从 logits 中提取相应的 token logits
token_logits = logits.gather(dim=-1, index=expanded_index).squeeze(-1)
print(token_logits)

输出:

tensor([[2.0000, 2.2000, 3.5000],
        [2.3000, 0.3000, 3.0000]])

2.2 gather(dim=-1) 计算过程解析

对于 logits.shape = (2, 3, 5)

  • dim=-1 代表最后一维,即 vocab_size=5 维度。
  • input_ids.shape = (2, 3),表示每个 batch 的 token 在词表中的索引。

让我们手动解析 gather(dim=-1) 的计算步骤:

BatchToken 索引 (dim=1)Index (dim=-1)Extracted Logit (gather 结果)
第 1 个样本第 1 个 tokeninput_ids[0,0] = 0logits[0,0,0] = 2.0
第 2 个 tokeninput_ids[0,1] = 2logits[0,1,2] = 2.2
第 3 个 tokeninput_ids[0,2] = 1logits[0,2,1] = 3.5
第 2 个样本第 1 个 tokeninput_ids[1,0] = 1logits[1,0,1] = 2.3
第 2 个 tokeninput_ids[1,1] = 3logits[1,1,3] = 0.3
第 3 个 tokeninput_ids[1,2] = 4logits[1,2,4] = 3.0

这正是 gather(dim=-1) 提取的值。


3. dim=-1 vs. dim=-2 的区别

3.1 什么是 dim=-2

如果改成 gather(dim=-2),它会尝试在 sequence 维度(dim=1 上进行索引,这会导致错误的行为。
因为 input_ids 只包含 token 在词表中的索引,而不是 token 在句子中的索引。

3.2 如果错误地使用 dim=-2

wrong_gather = logits.gather(dim=-2, index=input_ids.unsqueeze(-1))
print(wrong_gather.shape)
  • dim=-2 代表第二个维度(sequence 维度)
  • 这意味着 PyTorch 会尝试从 logits 中选取整个 token 级别的数据,而不是单独的 token logits

❌ 结果错误,因为 input_ids 里的索引根本不适用于 dim=-2

3.3 为什么 dim=-1 才是正确的?

  • input_ids 里的索引指向 词表索引(vocab index),所以应该沿着词表维度(dim=-1)索引数据。
  • dim=-1 选择的是 单个 token 对应的 logits,不会影响整个句子结构。

4. 结论

dim含义是否正确
dim=-1 (最后一维)提取每个 token 在词表中的 logits正确
dim=-2 (倒数第二维)尝试索引整个句子级别的数据错误

核心要点

dim=-1(最后一维)用于 获取输入 token 对应的 logits 值,常用于 NLP 任务。
gather(dim=-1, index=input_ids.unsqueeze(-1))input_ids 选择 logits 里的正确位置。
dim=-2 会错误地索引整个 token 级别的数据,而不是单个 token logits。

🚀 正确理解 gather(dim=-1),能够帮助你高效地提取模型输出,用于计算损失、评估模型! 🚀

Understanding torch.gather: Purpose, Use Cases, and Implementation in NLP

In deep learning, particularly in Natural Language Processing (NLP) and reinforcement learning, we often need to extract specific values from high-dimensional tensors using given indices. torch.gather is a powerful PyTorch function that efficiently retrieves data along a specified dimension based on an index tensor.

This article will cover:

  • The purpose of torch.gather
  • How gather works with examples
  • Practical applications in NLP and deep learning
  • Performance considerations and comparisons with scatter

1. What is torch.gather?

1.1 Basic Syntax

torch.gather(input, dim, index)
  • input: The source tensor from which values are gathered.
  • dim: The dimension along which to index values.
  • index: A tensor containing the indices of elements to extract.

1.2 How gather Works

  • It retrieves values from input at positions specified by index along the dim dimension.
  • The index tensor must have the same shape as input, except for the dim dimension.

2. Basic Examples of torch.gather

2.1 Extracting Elements from a 2D Tensor

import torch

# Define a 3x4 tensor
input_tensor = torch.tensor([[10, 20, 30, 40], 
                             [50, 60, 70, 80], 
                             [90, 100, 110, 120]])

# Define the index tensor
index_tensor = torch.tensor([[0, 1], 
                             [2, 3], 
                             [1, 2]])

# Gather values along dimension 1 (columns)
output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

Output:

tensor([[ 10,  20],
        [ 70,  80],
        [100, 110]])

Explanation:

  • dim=1 means we are indexing columns.
  • index_tensor[i, j] determines which element to select from input_tensor[i].

3. gather in NLP: Extracting Token Logits

3.1 Why Use gather in NLP?

In Transformer-based language models (GPT, BERT, etc.), we often need to compute the log probability of specific tokens. Given:

  • logits: The model’s output scores for each token.
  • input_ids: The actual token indices.

We use gather to efficiently retrieve the logits corresponding to each token.

3.2 Extracting Token Logits for Loss Calculation

import torch

# Simulated logits for batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2], 
                         [0.1, -0.5, 2.2, 1.5, 0.0], 
                         [1.1, 3.5, 0.8, -0.2, -1.5]],

                        [[0.0, 2.3, -0.5, 1.0, 0.8], 
                         [-1.2, 1.7, 2.0, 0.3, -0.8], 
                         [2.5, -0.1, -1.2, 0.5, 3.0]]])

input_ids = torch.tensor([[0, 2, 1],  # Token indices
                          [1, 3, 4]])

# Extracting logits corresponding to input_ids
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(token_logits)

Output:

tensor([[ 2.0000,  2.2000,  3.5000],
        [ 2.3000,  0.3000,  3.0000]])

3.3 Explanation

  1. input_ids.unsqueeze(-1) converts shape (B, L)(B, L, 1), making it compatible with gather.
  2. gather(dim=-1, index=input_ids.unsqueeze(-1)) retrieves the logits corresponding to input_ids.
  3. squeeze(-1) removes the unnecessary last dimension.

This operation is efficient and memory-friendly compared to iterating through tokens manually.


4. gather vs. scatter

While gather retrieves values from an input tensor using an index, scatter does the opposite: it writes values to an output tensor at specific indices.

4.1 Using scatter to Modify a Tensor

import torch

# Initialize a 3x3 zero tensor
x = torch.zeros(3, 3)

# Define index positions
index = torch.tensor([[0, 2], 
                      [1, 1], 
                      [2, 0]])

# Define values to write
updates = torch.tensor([[5, 8], 
                         [3, 7], 
                         [6, 2]])

# Use scatter_ to update x
x.scatter_(dim=1, index=index, src=updates)
print(x)

Output:

tensor([[5., 0., 8.],
        [0., 7., 0.],
        [2., 0., 6.]])

4.2 Key Difference

  • gather: Extracts values from specific indices.
  • scatter: Writes values to specific indices.

5. Performance and Memory Efficiency

5.1 Why gather is Efficient?

  • Vectorized indexing: Instead of looping through individual indices, gather efficiently extracts multiple values in parallel.
  • Lower memory footprint: Since gather does not require additional tensor allocations, it is more memory-efficient than manually indexing with loops.
  • Optimized for GPU: PyTorch internally optimizes gather to run efficiently on CUDA devices.

5.2 Performance Benchmark

import time

x = torch.randn(1000, 1000)
index = torch.randint(0, 1000, (1000, 500))

start = time.time()
_ = x.gather(dim=1, index=index)
end = time.time()
print(f"gather execution time: {end - start:.6f} s")

Results (Example):

gather execution time: 0.002341 s

This is much faster than manually iterating over indices.


6. Summary

  • torch.gather(input, dim, index) efficiently extracts values from a tensor using an index tensor.
  • Common use cases:
    • Extracting token logits for NLP tasks (e.g., loss computation in Transformer models).
    • Indexing probability distributions in reinforcement learning.
    • Selecting specific elements from multi-dimensional tensors.
  • gather is memory-efficient, parallelized, and optimized for GPU acceleration.
  • Comparison with scatter:
    • gather extracts values from an input tensor.
    • scatter writes values into an output tensor.

🚀 Mastering torch.gather will help you write more efficient deep learning models! 🚀

后记

2025年2月21日19点14分于上海,在GPT4o大模型辅助下完成。


http://www.kler.cn/a/556624.html

相关文章:

  • 华为云ECS命名规则解析与规格选型实战指南
  • 利用 OpenCV 进行棋盘检测与透视变换
  • 算法:选择排序(以排队为例)
  • Linux 内核网络设备驱动编程:私有协议支持
  • wps中zotero插件消失,解决每次都需要重新开问题
  • Liunx(CentOS-6-x86_64)系统安装MySql(5.6.50)
  • vue3项目axios最简单封装 - ajax请求封装
  • 51单片机入门_10_数码管动态显示(数字的使用;简单动态显示;指定值的数码管动态显示)
  • 低代码技术在医院的应用与思考
  • 计算机专业知识【深入理解子网中的特殊地址:为何 192.168.0.1 和 192.168.0.255 不能随意分配】
  • AI汽车新风向:「死磕」AI底盘,引爆线控底盘新增长拐点
  • RTSP场景下RTP协议详解及音视频打包全流程
  • 如何在 ubuntu 上使用 Clash 与 docker 开启代理拉起
  • python制图之小提琴图
  • webpack和grunt以及gulp有什么不同?
  • Github 2025-02-20 Go开源项目日报 Top10
  • linux 安装启动zookeeper全过程及遇到的坑
  • Qt/C++面试【速通笔记一】
  • 蓝桥杯备赛-基础训练(一)数组 day13
  • [文末数据集]ML.NET库学习010:URL是否具有恶意性分类