当前位置: 首页 > article >正文

tensor连接和拆分

文章目录

    • 连接
      • torch.cat()
        • 案例准备
      • torch.stack()
        • 区别
    • 拆分
      • torch.split()

连接

torch.cat()

函数目的: 在给定维度上对输入的张量序列 进行连接操作。

案例准备
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
b = torch.tensor([[10,10,10,],[10,10,10],[10,10,10,]], dtype=torch.float)

在这里插入图片描述

# dim指的是维度,dim = 0就是行,所以下面的代码就是按行拼接
print("按行拼接:\n",torch.cat((a,b),dim=0))
print("按行拼接:\n",torch.cat((a,b),dim=0).shape) #6行3列

在这里插入图片描述

print("按列拼接:\n",torch.cat((a,b),dim=1))
print("按列拼接:\n",torch.cat((a,b),dim=1).shape)#3行6列

在这里插入图片描述

torch.stack()

沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
也就是2维拼成3维,3维拼4维,以此类推。

print("按行拼接:\n",torch.stack((a,b),dim=0))
print("按行拼接:\n",torch.stack((a,b),dim=0).shape) 

在这里插入图片描述

print("按行拼接:\n",torch.stack((a,b),dim=1))
print("按行拼接:\n",torch.stack((a,b),dim=1).shape)

在这里插入图片描述

print("按行拼接:\n",torch.stack((a,b),dim=2))
print("按行拼接:\n",torch.stack((a,b),dim=2).shape)

在这里插入图片描述

区别

stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。

c = torch.tensor([[10,20],[30,40],[50,60]], dtype=torch.float)
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
torch.cat((a,c),dim=1)

在这里插入图片描述

#但是以下情况就会出错
torch.cat((a,c),dim=0)

在这里插入图片描述
如图,按行拼接会缺数据,报错吗,应该的。
在这里插入图片描述

torch.stack((a,c),dim=0)
###运行结果
RuntimeError: stack expects each tensor to be equal size, but got [3, 3] at entry 0 and [3, 2] at entry 1

再次验证stack需要两个大小一样的张量

拆分

torch.split()

def split(
tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, …]:

  • 按块大小拆分张量 除不尽的取余数,返回一个元组
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
print(torch.split(a,2,dim=0))	#按行拆,两行拆成一个
print(torch.split(a,1,dim=0))	#按行拆,一行拆成一个
print(torch.split(a,1,dim=1))	#按列拆,一列拆成一个
print(torch.split(a,2,dim=1)) 	#按列拆,两列拆成一个

在这里插入图片描述

  • 按块数拆分张量
torch.chunk(a,2,dim=0)	#按行拆成两块
torch.split(a,2,dim=1)	#按列拆成两块

在这里插入图片描述


http://www.kler.cn/news/308672.html

相关文章:

  • mysql学习教程,从入门到精通,MySQL 子查询 子句(11)
  • 高级I/O知识分享【epoll || Reactor ET,LT模式】
  • Numba基础
  • vue.nextTick()方法的使用
  • python怎么运行cmd命令
  • 网络协议头分析
  • 【PostgreSQL-patroni维护命令】
  • 基于vue框架的宠物寄养系统3d388(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。
  • USB开启ADB设置流程
  • 麒麟操作系统 MySQL 主从搭建
  • Qt QDialog点击界面自动激活问题解决办法
  • 枚举类题目练习心得
  • Golang | Leetcode Golang题解之第403题青蛙过河
  • 【题解】CF2009G1
  • QtC++截图支持获取鼠标光标
  • 运维工程师面试整理-虚拟化与容器
  • 实时数仓3.0DWD层
  • vulnhub(7):Toppo(经典的suid滥用提权)
  • ArcGIS Pro SDK (十四)地图探索 1 地图视图
  • 探索 InternLM 模型能力边界
  • 什么是外贸专用路由器?
  • 后端开发 每天六道面试题之打卡第一天
  • python中的各类比较与计算
  • Android14 蓝牙 BluetoothService 启动和相关代码介绍
  • 【Vue】- 生命周期和数据请求案例分析
  • phpstudy 建站使用 php8版本打开 phpMyAdmin后台出现网页提示致命错误:(phpMyAdmin这是版本问题导致的)
  • k8s中的存储
  • 【设计模式-外观】
  • 【计算机网络 - 基础问题】每日 3 题(七)
  • 【编译原理】看书笔记