当前位置: 首页 > article >正文

深度学习02-pytorch-05-张量的索引操作

在 PyTorch 中,张量的索引操作是非常灵活且强大的,允许你访问和修改张量中的特定元素或子集。与 NumPy 的数组操作类似,PyTorch 的张量支持多种索引方法,包括单纯的基础索引、高级索引和切片等。下面详细介绍 PyTorch 中常见的张量索引操作。

1. 基础索引

单个元素索引

你可以通过指定行、列来访问张量中的某个元素:

import torch
​
data = torch.tensor([[2, 8, 6, 9, 2],
                    [8, 7, 0, 8, 2],
                    [2, 7, 4, 0, 4],
                    [8, 9, 4, 3, 6]])
​
# 访问第0行,第1列的元素
print(data[0, 1])  # 输出 8
负索引

你可以使用负数索引来从末尾开始访问张量中的元素:

# 访问倒数第一行,倒数第二列的元素
print(data[-1, -2])  # 输出 3

2. 切片操作

你可以对张量的某个维度进行切片操作,获取其中的一部分数据。

切片示例
# 获取第0行到第2行(不包括第2行)和第1列到第3列(不包括第3列)的子张量
print(data[0:2, 1:3])
# 输出:
# tensor([[8, 6],
#         [7, 0]])
使用 : 选择所有行或列
# 获取所有行的第1列
print(data[:, 1])
# 输出 tensor([8, 7, 7, 9])

3. 布尔索引

布尔索引允许你通过条件来选择符合条件的元素。例如,选取所有大于 5 的元素:

mask = data > 5
print(mask)
# 输出:
# tensor([[False, True, True, True, False],
#         [ True, True, False, True, False],
#         [False, True, False, False, False],
#         [ True, True, False, False, True]])
​
# 选取大于 5 的所有元素
print(data[mask])
# 输出 tensor([8, 6, 9, 8, 7, 7, 8, 8, 9, 9, 6])

4. 高级索引

高级索引允许你通过张量、列表或其他数组来索引特定的位置。

多个索引

你可以使用两个列表/张量来同时索引行和列:

rows = [0, 1]
cols = [2, 3]
​
# 访问第 0 行第 2 列和第 1 行第 3 列的元素
print(data[rows, cols])
# 输出 tensor([6, 8])
结合切片与索引
# 获取第0行和第1行,并选取这些行的第2列和第3列的元素
print(data[[0, 1], 2:4])
# 输出:
# tensor([[6, 9],
#         [0, 8]])

5. None 和 ellipsis (...) 索引

None 索引
None 索引可以用于增加一个新的维度,类似于 NumPy 中的 np.newaxis:

data = torch.tensor([1, 2, 3])
expanded_data = data[:, None]
print(expanded_data)
# 输出:
# tensor([[1],
#         [2],
#         [3]])
...(省略号)操作符

... 操作符可以用于简化多维张量的索引,它会选择所有没有明确指定索引的维度。例如:

data = torch.randn(3, 4, 5)

# 使用省略号,访问所有行和列,并获取第2维的第0列
print(data[..., 0].shape)  # 输出形状是 (3, 4)

6. 张量赋值操作

通过索引操作,你还可以修改张量中的某些元素:

# 将第0行第1列的元素设为10
data[0, 1] = 10
print(data)

你也可以使用布尔索引或条件选择元素并赋值:

# 将所有大于5的元素设置为5
data[data > 5] = 5
print(data)

7. 高级操作:索引与广播

广播索引

当你使用多个不同形状的索引时,PyTorch 会自动应用广播机制,使索引适应不同维度的张量:

data = torch.tensor([[2, 8, 6, 9, 2],
                     [8, 7, 0, 8, 2],
                     [2, 7, 4, 0, 4],
                     [8, 9, 4, 3, 6]])

# 使用广播机制,选择第2行第2列、第2行第3列和第3行第2列、第3行第3列的元素
print(data[[[2], [3]], [2, 3]])
# 输出:
# tensor([[4, 0],
#         [4, 3]])

8. 张量选择操作

PyTorch 提供了 torch.index_select 等函数用于更加灵活的选择操作:

# 选择第 1 和第 2 行
indices = torch.tensor([1, 2])
selected_rows = torch.index_select(data, 0, indices)
print(selected_rows)

9. Gather 函数

torch.gather 函数允许你在指定的维度上根据索引选择张量中的元素:

data = torch.tensor([[2, 8, 6, 9, 2],
                    [8, 7, 0, 8, 2],
                    [2, 7, 4, 0, 4],
                    [8, 9, 4, 3, 6]])
​
index = torch.tensor([[0, 2], [3, 1], [1, 0], [2, 3]])
​
# gather 操作在第 1 维度上根据给定的 index 选择数据
result = torch.gather(data, 1, index)
print(result)
# 输出:
# tensor([[2, 6],
#         [8, 7],
#         [7, 2],
#         [4, 3]])

总结

PyTorch 提供了非常灵活和强大的张量索引机制。你可以使用基础索引、切片、布尔索引、以及高级索引来访问或修改张量中的数据。此外,PyTorch 的广播机制允许你对不同形状的张量进行操作,从而简化代码并提高计算效率。


http://www.kler.cn/news/315672.html

相关文章:

  • 2024 年最新前端ES-Module模块化、webpack打包工具详细教程(更新中)
  • Android 车载应用开发指南 - CarService 详解(下)
  • 在Spring Boot中实现多环境配置
  • 汽车总线之----FlexRay总线
  • LeetCode_sql_day31(1384.按年度列出销售总额)
  • C# 委托与事件 观察者模式
  • Java项目实战II基于Java+Spring Boot+MySQL的植物健康系统(开发文档+源码+数据库)
  • 设计模式之复合模式
  • 高级java每日一道面试题-2024年9月16日-框架篇-Spring MVC和Struts的区别是什么?
  • Redis发布和订阅
  • 波分技术基础 -- Liquid OTN技术特性
  • 高效打造知识图谱,使用LlamaIndex Relik实现实体关联和关系抽取
  • 火车站高铁站站点时刻查询网站计算机毕设/动车站点时刻查询
  • WebRTC编译后替换libwebrtc.aar时提示找不到libjingle_peerconnection_so.so库
  • 基于单片机控制的程控开关电源研究
  • list(一)
  • 基于微信小程序的健身房管理系统
  • ROS第五梯:ROS+VSCode+C++单步调试
  • [Golang] Context
  • GNU链接器(LD):设置入口点(ENTRY命令)的用法及实例解析
  • 科研绘图系列:R语言箱线图(boxplot)
  • error -- unsupported GNU version gcc later than 10 are not supported;(gcc、g++)
  • 计算机毕业设计 基于SpringBoot的小区运动中心预约管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解
  • 【python】深度优先搜索文件夹并移动全部doc文件
  • 自闭症儿童寄宿学校:打造良好的学习和生活环境
  • 速盾:高防cdn除了快还有什么好处?
  • Maven国内镜像(四种)
  • 20240923 每日AI必读资讯
  • vue源码分析(九)—— 合并配置
  • ChromaDB教程_2024最新版(上)