当前位置: 首页 > article >正文

pytorch torch.gather函数介绍

torch.gather 是 PyTorch 中的一个用于从给定维度上按索引取值的函数。它根据一个索引张量 index,从源张量 input 中收集值,并返回一个新的张量。torch.gather 常用于需要从张量的特定位置抽取元素的操作。

1. 函数签名

torch.gather(input, dim, index, *, sparse_grad=False, out=None)
  • input:输入张量,表示要从中收集元素的源张量。
  • dim:要收集的维度索引。例如,对于一个二维张量,0 表示沿着行的维度,1 表示沿着列的维度。
  • index:索引张量,其形状应与input张量在除了dim维度之外的其他维度上保持一致。索引张量中的值表示在input张量对应维度上要收集的元素的索引。
  • out(可选):输出张量,如果提供,结果将存储在这个张量中。

2. 工作原理

torch.gather 在 dim 维度上,通过 index 指定的索引,从 input 中选取元素。 返回的张量的形状与 index 的形状相同。

3. 示例代码

以下是一个简单的示例代码,演示如何使用 torch.gather 函数:

import torch

# 创建一个源张量
input = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# 创建一个索引张量
index = torch.tensor([[0, 2, 1],
                      [2, 0, 1],
                      [1, 2, 0]])

# 在 dim=1 维度上使用 gather 函数
result = torch.gather(input, dim=1, index=index)

print("Input Tensor:")
print(input)
print("\nIndex Tensor:")
print(index)
print("\nResult Tensor:")
print(result)

4. 输出结果

Input Tensor:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

Index Tensor:
tensor([[0, 2, 1],
        [2, 0, 1],
        [1, 2, 0]])

Result Tensor:
tensor([[1, 3, 2],
        [6, 4, 5],
        [8, 9, 7]])

5. 解释

  • 输入张量 (input) 是一个 3x3 的矩阵,每个元素代表一个值。
  • 索引张量 (index) 指定了要从 input 中提取的元素的索引。
  • 结果张量 (result) 是根据 index 从 input 中提取的元素形成的张量。

在这个例子中:

  • 对于 input 的第一行,index 提取了索引 0, 2, 1 对应的元素 1, 3, 2
  • 对于 input 的第二行,index 提取了索引 2, 0, 1 对应的元素 6, 4, 5
  • 对于 input 的第三行,index 提取了索引 1, 2, 0 对应的元素 8, 9, 7

6. 总结

  • torch.gather 通过索引在指定维度上提取张量中的元素,是用于基于索引选择数据的有用工具。
  • 函数对批处理数据特别有用,例如在分类任务中提取对应类别的概率或得分。
  • 索引张量的形状必须与源张量在指定维度的形状相匹配,以确保正确的取值操作。


http://www.kler.cn/news/294911.html

相关文章:

  • 运维工程师面试题--Linux加分项
  • Mysql(一) - 数据库操作, 表操作, CRUD
  • CMU 10423 Generative AI:lec3(阅读材料:GPT1论文解读)
  • 申万宏源证券完善金融服务最后一公里闭环,让金融服务“零距离、全天候”
  • 币安/欧易合约对冲APP系统开发
  • 【BuuCTF】BadySQli
  • C语言 | Leetcode C语言题解之第392题判断子序列
  • 小程序端pinia持久化
  • 2工作队列
  • 如何应对日益复杂的网络攻击?Edge SCDN(边缘安全加速)的应用场景探讨
  • 解决yarn安装依赖报错:certificate has expired at TLSSocket.onConnectSecure
  • 探索 MATLAB 中的 rem 函数:余数计算与应用
  • Find 方法、where 子句以及 AsNoTracking 方法各自有不同的用途和性能
  • 为libpng不同架构创建构建目录、编译、安装以及合并库文件的所有步骤。
  • python基础语法四-数据可视化
  • HTTP与HTTPS在软件测试中的解析
  • 使用modelsim小技巧
  • Mysql数据库表结构迁移PostgreSQL
  • springboot组件使用-mybatis组件使用
  • 《云原生安全攻防》-- K8s攻击案例:高权限Service Account接管集群
  • IPv6归属地查询-IPv6归属地接口-IPv6归属地离线库
  • 【有啥问啥】什么是扩散模型(Diffusion Models)?
  • [论文笔记] LLaVA
  • Effective Java学习笔记--39-41条 注解
  • 【LVI-SAM】激光雷达点云处理特征提取LIO-SAM 之FeatureExtraction实现细节
  • 把Django字典格式的数据库配置转成tortoise-orm的URL格式
  • k8s集群版部署
  • 排序算法-std::sort的使用(待学习第一天)
  • llama.cpp demo
  • 【H2O2|全栈】关于HTML(2)HTML基础(一)