当前位置: 首页 > article >正文

pytorch索引操作函数介绍

PyTorch 提供了一系列强大的索引操作函数,用于高效地操作张量数据。以下是常用的 PyTorch 索引操作函数及其用途:


1. 基础索引操作

这些操作类似于 NumPy 的基本索引方式。

1.1 方括号索引 ([])

  • 支持 切片布尔索引高级索引
  • 示例:
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
print(x[0, 1])  # 取第 0 行第 1 列,输出 20
print(x[:, 1])  # 取所有行的第 1 列,输出 tensor([20, 50])
print(x[1])     # 取第 1 行,输出 tensor([40, 50, 60])

2. 索引操作函数

2.1 torch.index_select

按指定维度的索引提取元素。

  • 用法
torch.index_select(input, dim, index, *, out=None)
  • 参数
    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须是一维张量。
  • 示例
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2])
result = torch.index_select(x, dim=1, index=index)
print(result)  # 输出 tensor([[10, 30], [40, 60]])

2.2 torch.gather

根据索引张量,从输入张量中收集数据。

  • 用法
torch.gather(input, dim, index, *, sparse_grad=False, out=None)
  • 参数
    • input:输入张量。
    • dim:指定的维度。
    • index:索引张量,必须与输入张量形状兼容。
  • 示例
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([[0, 2], [1, 0]])
result = torch.gather(x, dim=1, index=index)
print(result)  # 输出 tensor([[10, 30], [50, 40]])

2.3 torch.scatter

将源张量的值按索引写入目标张量。

  • 用法
torch.scatter(input, dim, index, src, *, reduce=None)

示例

x = torch.zeros(2, 3)
index = torch.tensor([[0, 1, 2], [0, 1, 2]])
src = torch.tensor([[10., 20., 30.], [40., 50., 60.]])
result = torch.scatter(x, dim=1, index=index, src=src)
print(result)  # 输出 tensor([[10., 20., 30.], [40., 50., 60.]])

2.4 torch.take

按照扁平化索引从张量中提取元素。

  • 用法
torch.take(input, index)

示例

x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([0, 2, 5])
result = torch.take(x, index)
print(result)  # 输出 tensor([10, 30, 60])

torch.gather和 torch.scatter区别

特性torch.gathertorch.scatter
作用从指定位置提取数据向指定位置写入数据
索引意义定义要从 input 中提取数据的位置定义要向 input 中写入数据的位置
输出形状与 index 形状相同与 input 形状相同
操作方向从 input 提取值将 src 的值写入到 input
使用场景数据提取(如选择性采样)数据写入(如更新某些特定位置的值)

2.5 torch.masked_select

根据布尔掩码从张量中提取元素。

  • 用法
torch.masked_select(input, mask, *, out=None)

示例

x = torch.tensor([[10, 20, 30], [40, 50, 60]])
mask = x > 30
result = torch.masked_select(x, mask)
print(result)  # 输出 tensor([40, 50, 60])

2.6 torch.nonzero

返回所有非零元素的索引。

  • 用法
torch.nonzero(input, *, as_tuple=False)

示例

x = torch.tensor([[0, 1, 0], [2, 0, 3]])
result = torch.nonzero(x)
print(result)  # 输出 tensor([[0, 1], [1, 0], [1, 2]])

2.7 torch.where

根据条件选择值。

  • 用法
torch.where(condition, x, y)

示例

x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])
condition = x > 1
result = torch.where(condition, x, y)
print(result)  # 输出 tensor([10,  2,  3])

2.8 torch.advanced_indexing

  • PyTorch 支持 布尔张量索引 和 整形张量索引
  • 示例:
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
index = torch.tensor([[0, 1], [1, 0]])
result = x[index]
print(result)

输出:

tensor([[[10, 20, 30],
         [40, 50, 60]],

        [[40, 50, 60],
         [10, 20, 30]]])

注:第0维度操作

3. 总结

函数功能
torch.index_select按指定维度和索引提取元素
torch.gather根据索引张量从输入中收集元素
torch.scatter按索引将源张量的值写入目标张量
torch.take按扁平化索引从张量中提取元素
torch.masked_select根据布尔掩码从张量中提取元素
torch.nonzero找出非零元素的索引
torch.where根据条件选择元素

这些索引操作函数可以帮助简化张量操作,提高代码效率。


http://www.kler.cn/a/468915.html

相关文章:

  • shell脚本总结2
  • 力扣23.合并K个升序链表
  • 解决openpyxl操纵带公式的excel或者csv之后,pandas无法读取数值的问题
  • 《学校一卡通管理系统》数据库MySQL的设计与实现
  • Java中如何实现线程安全的单例模式?
  • Trimble天宝X9三维扫描仪为建筑外墙检测提供了全新的解决方案【沪敖3D】
  • DOM HTML
  • 【Vim Masterclass 笔记05】第 4 章:Vim 的帮助系统与同步练习(L14+L15+L16)
  • 银行账户类别详解
  • 【Springboot知识】Springboot监控工具SpringbootAdmin
  • 游泳溺水识别数据集,对25729张图片进行YOLO,COCO JSON, VOC XML 格式的标注,溺水平均识别率在89.9%
  • 数据结构复习 (顺序查找,对半查找,斐波那契查找,插值查找,分块查找)
  • 鸿蒙UI开发——Toast即时提示框的使用
  • 【Qt】QLabel显示图片
  • 【STM32项目】智能物联网驱动的生物样本培育与管理辅助系统(完整工程资料源码)
  • 低空管控技术-无人机云监视技术详解
  • 功能篇:页面实现实时的时钟功能
  • 【NLP高频面题 - Transformer篇】Transformer的输入中为什么要添加位置编码?
  • java基础之代理
  • Qt 绘图
  • 9.课程分类查询
  • Linux(Centos 7.6)命令行快捷键
  • I.MX6ull-PWM
  • 封装/前线修饰符/Idea项目结构/package/impore
  • Linux菜鸟级常用的基本指令和基础知识
  • Vue.js组件开发-Provide/Inject的使用及高级应用