当前位置: 首页 > article >正文

【PyTorch】5.张量索引操作

目录

1. 简单行、列索引

2. 列表索引

3. 范围索引

4. 布尔索引

5. 多维索引


个人主页:Icomi

在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络与人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。

大家好,我是一颗米,前面咱们已经深入学习了张量的拼接操作,这让我们在搭建神经网络时能够更灵活地组合数据,就像掌握了一套精妙的 “组合拳”,让模型构建更加得心应手。然而,在我们使用张量的过程中,还有一项关键技能,如同在茂密的数据丛林中开辟道路的利器,那就是张量的花式索引操作。

当我们在操作张量时,就好比在管理一个庞大而复杂的仓库,里面存放着各种各样的数据。经常会遇到这样的情况,我们需要从这个 “数据仓库” 里精准地获取特定的数据,或者对某些数据进行修改。这时候,张量的花式索引操作就派上用场了。

想象一下,你置身于一个堆满了各种规格货物(数据)的巨大仓库,普通的索引方式就像只能按顺序依次寻找货物,效率低下。而花式索引则像是给了你一把神奇的钥匙,能够让你直接定位到那些分散在不同位置的特定货物,快速地获取或者修改它们。

掌握张量的花式索引操作,是我们必须具备的一项能力。它就像一把万能钥匙,能让我们在处理张量数据时更加高效、精准。无论是构建复杂的神经网络模型,还是对模型输出的数据进行后处理,花式索引都能帮助我们更灵活地操作张量。

接下来,我们就一起深入探索张量的花式索引操作,看看这把神奇的 “钥匙” 究竟有哪些神奇的用法,如何帮助我们在数据的海洋中畅行无阻。

1. 简单行、列索引

准备数据

import torch

data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)

结果

tensor([[0, 7, 6, 5, 9],
        [6, 8, 3, 1, 0],
        [6, 3, 8, 7, 3],
        [4, 9, 5, 3, 1]])
--------------------------------------------------
# 1. 简单行、列索引
def test01():

    print(data[0])
    print(data[:, 0])
    print('-' * 50)

if __name__ == '__main__':
    test01()

程序输出结果

tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
--------------------------------------------------

2. 列表索引

# 2. 列表索引
def test02():

    # 返回 (0, 1)、(1, 2) 两个位置的元素
    print(data[[0, 1], [1, 2]])
    print('-' * 50)

    # 返回 0、1 行的 1、2 列共4个元素
    print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
    test02()

输出结果

tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
        [8, 3]])

3. 范围索引

# 3. 范围索引
def test03():
    # 前3行的前2列数据
    print(data[:3, :2])
    # 第2行到最后的前2列数据
    print(data[2:, :2])
if __name__ == '__main__':
    test03()

结果

tensor([[0, 7],
        [6, 8],
        [6, 3]])
tensor([[6, 3],
        [4, 9]])

4. 布尔索引

# 布尔索引
def test():

    # 第三列大于5的行数据
    print(data[data[:, 2] > 5])
    # 第二行大于5的列数据
    print(data[:, data[1] > 5])
if __name__ == '__main__':
    test04()

输出结果

tensor([[0, 7, 6, 5, 9],
        [6, 3, 8, 7, 3]])
tensor([[0, 7],
        [6, 8],
        [6, 3],
        [4, 9]])

5. 多维索引

# 多维索引
def test05():

    data = torch.randint(0, 10, [3, 4, 5])
    print(data)
    print('-' * 50)

    print(data[0, :, :])
    print(data[:, 0, :])
    print(data[:, :, 0])


if __name__ == '__main__':
    test05()

输出结果

tensor([[[2, 4, 1, 2, 3],
         [5, 5, 1, 5, 0],
         [1, 4, 5, 3, 8],
         [7, 1, 1, 9, 9]],

        [[9, 7, 5, 3, 1],
         [8, 8, 6, 0, 1],
         [6, 9, 0, 2, 1],
         [9, 7, 0, 4, 0]],

        [[0, 7, 3, 5, 6],
         [2, 4, 6, 4, 3],
         [2, 0, 3, 7, 9],
         [9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
        [5, 5, 1, 5, 0],
        [1, 4, 5, 3, 8],
        [7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
        [9, 7, 5, 3, 1],
        [0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
        [9, 8, 6, 9],
        [0, 2, 2, 9]])


http://www.kler.cn/a/522129.html

相关文章:

  • STM32 TIM输入捕获 测量频率
  • 【后端】Flask
  • 计算机网络 IP 网络层 2 (重置版)
  • 【HarmonyOS之旅】基于ArkTS开发(三) -> 兼容JS的类Web开发(三)
  • 消息队列篇--通信协议篇--应用层协议和传输层协议理解
  • 团体程序设计天梯赛-练习集——L1-025 正整数A+B
  • UE求职Demo开发日志#13 完善所有伤害判定
  • 【ESP32】ESP-IDF开发 | WiFi开发 | UDP用户数据报协议 + UDP客户端和服务器例程
  • AI软件外包需要注意什么 外包开发AI软件的关键因素是什么 如何选择AI外包开发语言
  • 2025年寒假ACM训练赛1
  • Tailwind CSS 正式发布了 4.0 版本
  • 跨平台物联网漏洞挖掘算法评估框架设计与实现项目经费使用记录和参考文献
  • 除了layui.js还有什么比较好的纯JS组件WEB UI?在谷歌浏览上显示
  • 【2025年最新版】Java JDK安装、环境配置教程 (图文非常详细)
  • java构建工具之Gradle
  • 【AI论文】FilmAgent: 一个用于虚拟3D空间中端到端电影制作自动化的多智能体框架
  • 常用的npm镜像源配置方法
  • 刀客doc:禁令影响下,TikTok广告业务正在被对手截胡
  • JVM垃圾回收器的原理和调优详解!
  • 深圳大学-智能网络与计算-实验四:云-边协同计算实验
  • Java多线程与高并发专题——保障可见性和有序性
  • 分布式组件底层逻辑是什么?
  • 在计算机上本地运行 Deepseek R1
  • Couchbase UI: Bucket
  • 小程序 uniapp 地图 自定义内容呈现,获取中心点,获取对角经纬度,首次获取对角经纬度
  • 蓝桥村打花结的花纸选择问题