当前位置: 首页 > article >正文

torch.gather和torch.take和torch.stack的等效替换

文章目录

  • dim = 1
  • dim = 0


dim = 1

import torch

# torch.gather(input, dim, index, out=None),dim = 1的情况。
input = torch.tensor([[1, 2, 3], 
                      [4, 5, 6]])
index = torch.tensor([[0, 1], 
                      [1, 2]])
output_gather = torch.gather(input, 1, index)
print('torch.gather result\n',output_gather)


# 1.使用索引和广播
rows = torch.arange(input.size(0)).unsqueeze(1).expand_as(index)
'''
torch.arange(input.size(0)) = tensor([0, 1])
torch.arange(input.size(0)).unsqueeze(1) = tensor([[0],
                                                [1]])
rows = tensor([[0, 0],
               [1, 1]])
'''
output_index = input[rows, index]
print('index result\n',output_gather)


# 2.使用使用 torch.take 和 torch.flatten

linear_index = index + torch.arange(input.size(0)).unsqueeze(1) * input.size(1)
'''
tensor([[0],  * 3 
        [1]])
        
linear_index = tensor([[0, 1],
                       [4, 5]])        
'''
output = torch.take(input, linear_index)
output_3 = output.view(index.shape)
print('take and flatten result\n',output_3)
# 3.使用 torch.index_select 和 torch.stack
output_4 = torch.stack([input[i, index[i]] for i in range(input.size(0))])
print('stack result\n',output_4)
import pdb;pdb.set_trace()

dim = 0

import torch


input = torch.tensor([[1, 2, 3], 
                      [4, 5, 6]])
index = torch.tensor([[0, 1], 
                      [1, 0]])
output_gather = torch.gather(input, 0, index)
print('torch.gather result\n',output_gather)



# index1 = torch.arange(input.shape[0])
# index2 = index1.expand_as(index)

# out = input[index, index2]
# print(out)
# import pdb;pdb.set_trace()

# num_cols = input.size(1)
# linear_index = index * num_cols + torch.arange(input.size(0)).unsqueeze(0)# 加纵坐标,从0开始
# output = torch.take(input, linear_index)
# print(output)

# import pdb;pdb.set_trace()  
a= [input[index[i],i] for i in range(input.size(0))]
output_4 = torch.stack([input[index[i],i] for i in range(input.size(0))],dim=1)
print('torch.stack result\n',output_4)


import pdb;pdb.set_trace()

http://www.kler.cn/a/376087.html

相关文章:

  • 数据结构与算--堆实现线段树
  • 基于springboot的自习室预订系统
  • 线程池底部工作原理
  • 【Rust练习】28.use and pub
  • 项目练习:若依管理系统字典功能-Vue前端部分
  • candb++ windows11运行报错,找不到mfc140.dll
  • AI-基本概念-多层感知器模型/CNN/RNN/自注意力模型
  • 操作符习题练习
  • C语言 | Leetcode C语言题解之第519题随机翻转矩阵
  • 金华迪加 现场大屏互动系统 mobile.do.php 任意文件上传漏洞复现
  • R 数据框
  • RabbitMQ 存储机制
  • 像`npm i`作为`npm install`的简写一样,使用`pdm i`作为`pdm install`的简写
  • ARM base instruction -- madd
  • 函数的多返回值及多种传参方式
  • Python 的基本语法
  • 【C#】异步和多线程
  • 速度!双击文件就可以运行本地大模型!神奇的AI大模型开源项目——llamafile
  • Redis中储存含LocalDateTime属性对象的序列化实现
  • R数据结构向量基础
  • 公有云开发基础教程
  • 汽车固态电池深度报告
  • 4K双模显示器值得买吗?
  • Python WordCloud库与jieba分词生成词云图的完整指南
  • Ollama:本地部署与运行大型语言模型的高效工具
  • 在kanzi 3.9.8里使用API创建自定义材质