当前位置: 首页 > article >正文

pytorch 笔记:index_select

1 基本使用方法

index_select 是 PyTorch 中的一个非常有用的函数,允许从给定的维度中选择指定索引的张量值

torch.index_select(input, dim, index, out=None) -> Tensor
input从中选择数据的源张量
dim从中选择数据的维度
index

一个 1D 张量,包含你想要从 dim 维度中选择的索引

此张量应该是 LongTensor 类型

out

一个可选的参数,用于指定输出张量。

如果没有提供,将创建一个新的张量。

2 举例

import torch
import numpy as np

x = torch.tensor(np.arange(16).reshape(4,4))
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]], dtype=torch.int32)
'''

torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]], dtype=torch.int32)
'''

torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1,  3],
        [ 5,  7],
        [ 9, 11],
        [13, 15]], dtype=torch.int32)
'''

3 index_select保存梯度

import torch
import numpy as np

x = torch.tensor(np.arange(16).reshape(4,4),dtype=torch.float32, requires_grad=True)
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]], requires_grad=True)
'''

torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4.,  5.,  6.,  7.],
        [12., 13., 14., 15.]], grad_fn=<IndexSelectBackward0>)
'''

torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1.,  3.],
        [ 5.,  7.],
        [ 9., 11.],
        [13., 15.]], grad_fn=<IndexSelectBackward0>)
'''


http://www.kler.cn/a/108090.html

相关文章:

  • Autosar CP DDS规范导读
  • 安全生产管理的重要性:现状、痛点与改进之路
  • pycharm快速更换虚拟环境
  • 【MATLAB代码】二维平面上的TDOA,使用加权最小二乘法,不限制锚点数量,代码可复制粘贴
  • layui 文件上传前检查文件大小,后面再点上传出现重复提交的问题
  • 微服务容器化部署实践(FontConfiguration.getVersion)
  • Java面试八股文之暑假合集
  • Seata入门系列【15】@GlobalLock注解使用场景及源码分析
  • 面试经典150题——Day24
  • React Router初学者入门指南(2023版)
  • Pytorch代码入门学习之分类任务(三):定义损失函数与优化器
  • 【Qt】绘图与绘图设备
  • C++不能在子类中构造函数的初始化成员列表中直接初始化基类成员变量
  • C++ 运算符
  • Linux touch命令:创建文件及修改文件时间
  • 底层驱动day8作业
  • 【C++】智能指针:auto_ptr、unique_ptr、share_ptr、weak_ptr(技术介绍 + 代码实现)(待更新)
  • Megatron-LM GPT 源码分析(三) Pipeline Parallel分析
  • AWS SAP-C02教程11-解决方案
  • C#,数值计算——分类与推理,基座向量机的 Svmgenkernel的计算方法与源程序
  • 中微爱芯74逻辑兼容替代TI/ON/NXP工规品质型号全
  • 【杂记】Ubuntu20.04装系统,安装CUDA等
  • python爬虫之feapder.AirSpider轻量爬虫案例:豆瓣
  • PHP简单实现预定义钩子和自定义钩子
  • Linux国产系统无法连接身份证读卡器USB权限解决办法
  • nrf52832 开发板入手笔记:J-Flash 蓝牙协议栈烧写