PyTorch gather 方法详解:作用、应用场景与示例解析(中英双语)
PyTorch gather
方法详解:作用、应用场景与示例解析
在深度学习和自然语言处理(NLP)任务中,我们经常需要从高维张量中提取特定索引的数据。
PyTorch 提供的 torch.gather
方法可以高效地从张量的指定维度收集数据,广泛应用于语言模型(Transformer)、分类任务、强化学习等场景。
在本文中,我们将详细介绍:
gather
方法的作用- 使用
gather
进行索引操作 gather
在 NLP 模型中的应用gather
的计算效率与优化
1. torch.gather
的作用
1.1 gather
的基本用法
gather
允许我们在张量的指定维度上,按照给定的索引提取数据。
其基本语法如下:
torch.gather(input, dim, index)
input
:输入张量,形状为(B, L, V)
(可以是任意维度)。dim
:指定在哪个维度上收集数据(例如dim=-1
代表在最后一个维度索引)。index
:索引张量,形状必须与input
在dim
之外的维度相同。
1.2 gather
的核心逻辑
给定 input
和 index
,gather
沿 dim
维度逐元素地获取 input
中指定索引位置的值。
2. gather
的基础示例
2.1 从二维张量中提取元素
import torch
# 定义一个 3x4 的张量
input_tensor = torch.tensor([[10, 20, 30, 40],
[50, 60, 70, 80],
[90, 100, 110, 120]])
# 定义索引张量
index_tensor = torch.tensor([[0, 1],
[2, 3],
[1, 2]])
# 在 `dim=1` 维度上使用 gather
output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)
输出:
tensor([[ 10, 20],
[ 70, 80],
[100, 110]])
解释:
input_tensor
形状为(3,4)
,即 3 行 4 列。index_tensor
形状为(3,2)
,其中的值指示要从input_tensor
的 dim=1(列) 选取的数据:- 第 1 行:取
input_tensor[0,0]
和input_tensor[0,1]
,即[10, 20]
- 第 2 行:取
input_tensor[1,2]
和input_tensor[1,3]
,即[70, 80]
- 第 3 行:取
input_tensor[2,1]
和input_tensor[2,2]
,即[100, 110]
- 第 1 行:取
3. gather
在 NLP 中的应用
3.1 计算 Token 的对数概率
在语言模型(如 Transformer)中,我们通常需要计算目标 token 的概率,即:
P
(
y
t
)
=
e
logit
y
t
∑
e
logit
v
P(y_t) = \frac{e^{\text{logit}_{y_t}}}{\sum e^{\text{logit}_{v}}}
P(yt)=∑elogitvelogityt
其中:
logits
形状为(B, L, V)
,表示 batch 里每个 token 对整个词表(vocabulary)中所有词的 logit 分数。input_ids
形状为(B, L)
,表示实际的 token 索引(即每个 token 在词表中的 ID)。
我们使用 gather
取出每个 input_id
在 logits
中对应的 logit 分值:
import torch
# 假设 batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2],
[0.1, -0.5, 2.2, 1.5, 0.0],
[1.1, 3.5, 0.8, -0.2, -1.5]],
[[0.0, 2.3, -0.5, 1.0, 0.8],
[-1.2, 1.7, 2.0, 0.3, -0.8],
[2.5, -0.1, -1.2, 0.5, 3.0]]])
input_ids = torch.tensor([[0, 2, 1], # 对应每个 token 在词表中的索引
[1, 3, 4]])
# 取出 input_ids 在 logits 中的 logit 值
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(token_logits)
输出:
tensor([[ 2.0000, 2.2000, 3.5000],
[ 2.3000, 0.3000, 3.0000]])
解释:
logits.gather(dim=-1, index=input_ids.unsqueeze(-1))
:dim=-1
代表从Vocab
维度(最后一维)索引数据。input_ids.unsqueeze(-1)
扩展维度,让input_ids
形状变为(B, L, 1)
,符合gather
要求。squeeze(-1)
还原到(B, L)
形状,使结果是 每个 token 的 logit 值。
4. gather
与 scatter
的对比
除了 gather
(用于提取数据),PyTorch 还提供 scatter
(用于写入数据)。
4.1 scatter
基本用法
import torch
# 初始化 3x3 零张量
x = torch.zeros(3, 3)
# 指定索引
index = torch.tensor([[0, 2],
[1, 1],
[2, 0]])
# 指定填充值
updates = torch.tensor([[5, 8],
[3, 7],
[6, 2]])
# 在 dim=1 维度上 scatter
x.scatter_(dim=1, index=index, src=updates)
print(x)
输出:
tensor([[5., 0., 8.],
[0., 7., 0.],
[2., 0., 6.]])
scatter_()
用于替换 index
位置的值,而 gather()
用于提取 index
位置的值。
5. 总结
torch.gather(input, dim, index)
提取input
在dim
维度上指定index
位置的值。- 常用于 NLP 任务中计算 token 对数概率、分类任务中提取预测分数。
- 通过
gather
提取logits
对应input_ids
,可高效计算 对数概率 和 损失函数。 gather
是 从索引获取数据,而scatter
是 根据索引写入数据。
🚀 掌握 gather
让你在深度学习项目中更高效地处理索引操作! 🚀
深入理解 gather(dim=-1)
:作用、计算过程与 dim=-2
的对比
在 PyTorch 中,torch.gather
是一个强大的索引操作函数,它可以根据提供的 index
张量,从 input
张量的指定维度(dim
)中提取相应的数据。
在 NLP(自然语言处理)任务中,我们常用 gather(dim=-1)
来从 logits
中获取 输入 token(input_ids
)对应的 logits 值,用于计算损失或评估模型表现。
1. gather(dim=-1)
的作用
1.1 dim=-1
的含义
-
dim=-1
代表最后一个维度(即词表维度)。 -
在
logits.shape = (batch_size, sequence_length, vocab_size)
这样一个张量中:dim=0
:表示 batch 维度(不同样本)。dim=1
:表示 sequence 维度(句子中的不同 token)。dim=2
(即dim=-1
):表示词汇表(vocab),即每个 token 对所有单词的 logits 评分。
-
gather(dim=-1, index=input_ids.unsqueeze(-1))
的作用:- 在
dim=-1
(词表维度)上提取input_ids
对应的 logits 值。 - 这样,每个 token 只保留它对应的 logits,而不是整个词表的所有 logits。
- 在
2. 代码示例与计算过程
2.1 示例:计算 Token Logits
import torch
# 假设 batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([
[[2.0, 1.0, 0.5, -1.0, 0.2],
[0.1, -0.5, 2.2, 1.5, 0.0],
[1.1, 3.5, 0.8, -0.2, -1.5]],
[[0.0, 2.3, -0.5, 1.0, 0.8],
[-1.2, 1.7, 2.0, 0.3, -0.8],
[2.5, -0.1, -1.2, 0.5, 3.0]]
])
input_ids = torch.tensor([
[0, 2, 1], # 第一个样本的 token 索引
[1, 3, 4] # 第二个样本的 token 索引
])
# 扩展维度,使 input_ids 形状变为 (batch_size, sequence_length, 1)
expanded_index = input_ids.unsqueeze(-1)
# 使用 gather 从 logits 中提取相应的 token logits
token_logits = logits.gather(dim=-1, index=expanded_index).squeeze(-1)
print(token_logits)
输出:
tensor([[2.0000, 2.2000, 3.5000],
[2.3000, 0.3000, 3.0000]])
2.2 gather(dim=-1)
计算过程解析
对于 logits.shape = (2, 3, 5)
:
dim=-1
代表最后一维,即 vocab_size=5 维度。input_ids.shape = (2, 3)
,表示每个 batch 的 token 在词表中的索引。
让我们手动解析 gather(dim=-1)
的计算步骤:
Batch | Token 索引 (dim=1 ) | Index (dim=-1 ) | Extracted Logit (gather 结果) |
---|---|---|---|
第 1 个样本 | 第 1 个 token | input_ids[0,0] = 0 | logits[0,0,0] = 2.0 |
第 2 个 token | input_ids[0,1] = 2 | logits[0,1,2] = 2.2 | |
第 3 个 token | input_ids[0,2] = 1 | logits[0,2,1] = 3.5 | |
第 2 个样本 | 第 1 个 token | input_ids[1,0] = 1 | logits[1,0,1] = 2.3 |
第 2 个 token | input_ids[1,1] = 3 | logits[1,1,3] = 0.3 | |
第 3 个 token | input_ids[1,2] = 4 | logits[1,2,4] = 3.0 |
这正是 gather(dim=-1)
提取的值。
3. dim=-1
vs. dim=-2
的区别
3.1 什么是 dim=-2
?
如果改成 gather(dim=-2)
,它会尝试在 sequence 维度(dim=1
) 上进行索引,这会导致错误的行为。
因为 input_ids
只包含 token 在词表中的索引,而不是 token 在句子中的索引。
3.2 如果错误地使用 dim=-2
wrong_gather = logits.gather(dim=-2, index=input_ids.unsqueeze(-1))
print(wrong_gather.shape)
dim=-2
代表第二个维度(sequence 维度)- 这意味着 PyTorch 会尝试从
logits
中选取整个 token 级别的数据,而不是单独的 token logits
❌ 结果错误,因为 input_ids
里的索引根本不适用于 dim=-2
!
3.3 为什么 dim=-1
才是正确的?
input_ids
里的索引指向 词表索引(vocab index),所以应该沿着词表维度(dim=-1
)索引数据。dim=-1
选择的是 单个 token 对应的 logits,不会影响整个句子结构。
4. 结论
dim 值 | 含义 | 是否正确 |
---|---|---|
dim=-1 (最后一维) | 提取每个 token 在词表中的 logits | ✅ 正确 |
dim=-2 (倒数第二维) | 尝试索引整个句子级别的数据 | ❌ 错误 |
核心要点
✅ dim=-1
(最后一维)用于 获取输入 token 对应的 logits 值,常用于 NLP 任务。
✅ gather(dim=-1, index=input_ids.unsqueeze(-1))
让 input_ids
选择 logits
里的正确位置。
❌ dim=-2
会错误地索引整个 token 级别的数据,而不是单个 token logits。
🚀 正确理解 gather(dim=-1)
,能够帮助你高效地提取模型输出,用于计算损失、评估模型! 🚀
Understanding torch.gather
: Purpose, Use Cases, and Implementation in NLP
In deep learning, particularly in Natural Language Processing (NLP) and reinforcement learning, we often need to extract specific values from high-dimensional tensors using given indices. torch.gather
is a powerful PyTorch function that efficiently retrieves data along a specified dimension based on an index tensor.
This article will cover:
- The purpose of
torch.gather
- How
gather
works with examples - Practical applications in NLP and deep learning
- Performance considerations and comparisons with
scatter
1. What is torch.gather
?
1.1 Basic Syntax
torch.gather(input, dim, index)
input
: The source tensor from which values are gathered.dim
: The dimension along which to index values.index
: A tensor containing the indices of elements to extract.
1.2 How gather
Works
- It retrieves values from
input
at positions specified byindex
along thedim
dimension. - The
index
tensor must have the same shape asinput
, except for thedim
dimension.
2. Basic Examples of torch.gather
2.1 Extracting Elements from a 2D Tensor
import torch
# Define a 3x4 tensor
input_tensor = torch.tensor([[10, 20, 30, 40],
[50, 60, 70, 80],
[90, 100, 110, 120]])
# Define the index tensor
index_tensor = torch.tensor([[0, 1],
[2, 3],
[1, 2]])
# Gather values along dimension 1 (columns)
output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)
Output:
tensor([[ 10, 20],
[ 70, 80],
[100, 110]])
Explanation:
dim=1
means we are indexing columns.index_tensor[i, j]
determines which element to select frominput_tensor[i]
.
3. gather
in NLP: Extracting Token Logits
3.1 Why Use gather
in NLP?
In Transformer-based language models (GPT, BERT, etc.), we often need to compute the log probability of specific tokens. Given:
logits
: The model’s output scores for each token.input_ids
: The actual token indices.
We use gather
to efficiently retrieve the logits corresponding to each token.
3.2 Extracting Token Logits for Loss Calculation
import torch
# Simulated logits for batch_size=2, sequence_length=3, vocab_size=5
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2],
[0.1, -0.5, 2.2, 1.5, 0.0],
[1.1, 3.5, 0.8, -0.2, -1.5]],
[[0.0, 2.3, -0.5, 1.0, 0.8],
[-1.2, 1.7, 2.0, 0.3, -0.8],
[2.5, -0.1, -1.2, 0.5, 3.0]]])
input_ids = torch.tensor([[0, 2, 1], # Token indices
[1, 3, 4]])
# Extracting logits corresponding to input_ids
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
print(token_logits)
Output:
tensor([[ 2.0000, 2.2000, 3.5000],
[ 2.3000, 0.3000, 3.0000]])
3.3 Explanation
input_ids.unsqueeze(-1)
converts shape(B, L)
→(B, L, 1)
, making it compatible withgather
.gather(dim=-1, index=input_ids.unsqueeze(-1))
retrieves the logits corresponding toinput_ids
.squeeze(-1)
removes the unnecessary last dimension.
This operation is efficient and memory-friendly compared to iterating through tokens manually.
4. gather
vs. scatter
While gather
retrieves values from an input tensor using an index, scatter
does the opposite: it writes values to an output tensor at specific indices.
4.1 Using scatter
to Modify a Tensor
import torch
# Initialize a 3x3 zero tensor
x = torch.zeros(3, 3)
# Define index positions
index = torch.tensor([[0, 2],
[1, 1],
[2, 0]])
# Define values to write
updates = torch.tensor([[5, 8],
[3, 7],
[6, 2]])
# Use scatter_ to update x
x.scatter_(dim=1, index=index, src=updates)
print(x)
Output:
tensor([[5., 0., 8.],
[0., 7., 0.],
[2., 0., 6.]])
4.2 Key Difference
gather
: Extracts values from specific indices.scatter
: Writes values to specific indices.
5. Performance and Memory Efficiency
5.1 Why gather
is Efficient?
- Vectorized indexing: Instead of looping through individual indices,
gather
efficiently extracts multiple values in parallel. - Lower memory footprint: Since
gather
does not require additional tensor allocations, it is more memory-efficient than manually indexing with loops. - Optimized for GPU: PyTorch internally optimizes
gather
to run efficiently on CUDA devices.
5.2 Performance Benchmark
import time
x = torch.randn(1000, 1000)
index = torch.randint(0, 1000, (1000, 500))
start = time.time()
_ = x.gather(dim=1, index=index)
end = time.time()
print(f"gather execution time: {end - start:.6f} s")
Results (Example):
gather execution time: 0.002341 s
This is much faster than manually iterating over indices.
6. Summary
torch.gather(input, dim, index)
efficiently extracts values from a tensor using an index tensor.- Common use cases:
- Extracting token logits for NLP tasks (e.g., loss computation in Transformer models).
- Indexing probability distributions in reinforcement learning.
- Selecting specific elements from multi-dimensional tensors.
gather
is memory-efficient, parallelized, and optimized for GPU acceleration.- Comparison with
scatter
:gather
extracts values from an input tensor.scatter
writes values into an output tensor.
🚀 Mastering torch.gather
will help you write more efficient deep learning models! 🚀
后记
2025年2月21日19点14分于上海,在GPT4o大模型辅助下完成。