pytorch torch.scatter_reduce函数介绍
PyTorch torch.scatter_reduce
函数
torch.scatter_reduce
是 PyTorch 中的一种高级操作,用于在特定维度上将源张量的值按索引归约到目标张量中。它结合了 scatter 和 reduce 操作,非常适合处理需要对特定索引进行归约(如求和、最大值等)的场景。
函数签名
torch.scatter_reduce(input, dim, index, src, reduce, *, include_self=True, out=None)
参数说明
-
input
:- 目标张量,表示归约操作的初始值。
-
dim
:- 指定在目标张量
input
中进行归约操作的维度。
- 指定在目标张量
-
index
:- 张量,表示目标张量中归约操作的索引位置。
index
的形状必须与src
兼容,或者可以广播成src
的形状。
-
src
:- 源张量,提供要归约到
input
中的值。
- 源张量,提供要归约到
-
reduce
:- 指定归约操作的类型,支持以下选项:
"sum"
:按索引进行求和。"prod"
:按索引进行乘积。"mean"
:按索引计算平均值。"amax"
:按索引取最大值。"amin"
:按索引取最小值。
- 指定归约操作的类型,支持以下选项:
-
include_self
(可选, 默认True
):- 是否在归约时包括
input
中的原始值。 - 如果为
False
,只使用src
中的值进行归约。
- 是否在归约时包括
-
out
(可选):- 用于存储结果的张量。如果提供,将直接修改此张量。
返回值
返回一个张量,包含归约操作的结果,形状与 input
相同。
示例
1. 按索引求和 (reduce="sum"
)
import torch
input = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0, 1],
[1, 2, 0, 1, 2]])
src = torch.tensor([[10., 20., 30., 40., 50.],
[1., 2., 3., 4., 5.]])
result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="sum")
print(result)
输出:
tensor([[50., 70., 30., 0., 0.],
[ 3., 5., 7., 0., 0.],
[ 0., 0., 0., 0., 0.]])
2. 按索引取最大值 (reduce="amax"
)
result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="amax")
print(result)
输出:
tensor([[40., 50., 30., 0., 0.],
[ 3., 4., 5., 0., 0.],
[ 0., 0., 0., 0., 0.]])
3. 使用 include_self=False
result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="sum", include_self=False)
print(result)
输出:
tensor([[50., 70., 30., 0., 0.],
[ 3., 5., 7., 0., 0.],
[ 0., 0., 0., 0., 0.]])
注意事项
-
index
范围:index
的值必须在[0, input.shape[dim])
范围内,否则会引发错误。
-
广播规则:
index
和src
必须具有相同的形状,或者可以通过广播匹配。
-
性能优化:
torch.scatter_reduce
对于稀疏更新和归约非常高效,避免了循环操作。
应用场景
- 聚合数据(如按索引分组求和或求最大值)。
- 构造稀疏张量。
- 实现自定义的归约操作(如图神经网络中的消息传递)。