当前位置: 首页 > article >正文

pytorch 笔记:index_select

1 基本使用方法

index_select 是 PyTorch 中的一个非常有用的函数,允许从给定的维度中选择指定索引的张量值

torch.index_select(input, dim, index, out=None) -> Tensor
input从中选择数据的源张量
dim从中选择数据的维度
index

一个 1D 张量,包含你想要从 dim 维度中选择的索引

此张量应该是 LongTensor 类型

out

一个可选的参数,用于指定输出张量。

如果没有提供,将创建一个新的张量。

2 举例

import torch
import numpy as np

x = torch.tensor(np.arange(16).reshape(4,4))
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]], dtype=torch.int32)
'''

torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]], dtype=torch.int32)
'''

torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1,  3],
        [ 5,  7],
        [ 9, 11],
        [13, 15]], dtype=torch.int32)
'''

3 index_select保存梯度

import torch
import numpy as np

x = torch.tensor(np.arange(16).reshape(4,4),dtype=torch.float32, requires_grad=True)
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]], requires_grad=True)
'''

torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4.,  5.,  6.,  7.],
        [12., 13., 14., 15.]], grad_fn=<IndexSelectBackward0>)
'''

torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1.,  3.],
        [ 5.,  7.],
        [ 9., 11.],
        [13., 15.]], grad_fn=<IndexSelectBackward0>)
'''


http://www.kler.cn/news/108090.html

相关文章:

  • Java面试八股文之暑假合集
  • Seata入门系列【15】@GlobalLock注解使用场景及源码分析
  • 面试经典150题——Day24
  • React Router初学者入门指南(2023版)
  • Pytorch代码入门学习之分类任务(三):定义损失函数与优化器
  • 【Qt】绘图与绘图设备
  • C++不能在子类中构造函数的初始化成员列表中直接初始化基类成员变量
  • C++ 运算符
  • Linux touch命令:创建文件及修改文件时间
  • 底层驱动day8作业
  • 【C++】智能指针:auto_ptr、unique_ptr、share_ptr、weak_ptr(技术介绍 + 代码实现)(待更新)
  • Megatron-LM GPT 源码分析(三) Pipeline Parallel分析
  • AWS SAP-C02教程11-解决方案
  • C#,数值计算——分类与推理,基座向量机的 Svmgenkernel的计算方法与源程序
  • 中微爱芯74逻辑兼容替代TI/ON/NXP工规品质型号全
  • 【杂记】Ubuntu20.04装系统,安装CUDA等
  • python爬虫之feapder.AirSpider轻量爬虫案例:豆瓣
  • PHP简单实现预定义钩子和自定义钩子
  • Linux国产系统无法连接身份证读卡器USB权限解决办法
  • nrf52832 开发板入手笔记:J-Flash 蓝牙协议栈烧写
  • Nginx 的配置文件(负载均衡,反向代理)
  • Spring Security: 整体架构
  • uniapp-图片压缩(适配H5,APP)
  • 10月Java行情 回暖?
  • 【机器学习可解释性】4.SHAP 值
  • 第10期 | GPTSecurity周报
  • scratch接钻石 2023年9月中国电子学会图形化编程 少儿编程 scratch编程等级考试三级真题和答案解析
  • 力扣第763题 划分字母区间 c++ 哈希 + 双指针 + 小小贪心
  • 制作自己的前端组件库并上传到npm上
  • MySQL实战2