当前位置: 首页 > article >正文

torch.stack 张量维度的变化

torch.stack 是 PyTorch 中用于将一系列张量沿一个新的维度堆叠的函数。与 torch.cat 不同的是,torch.stack会在指定的维度上增加一个新的维度,而不是将张量直接拼接。

基本用法

语法:

torch.stack(tensors, dim=0)
  • tensors: 一个张量列表,包含多个形状相同的张量(shape 必须相同)。
  • dim: 新增维度的位置,默认是 0

举例说明

假设有三个形状为 (2, 3) 的张量:

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])
c = torch.tensor([[13, 14, 15], [16, 17, 18]])

沿 dim=0 堆叠

stacked = torch.stack([a, b, c], dim=0)
print(stacked.shape)  # torch.Size([3, 2, 3])
  • 在维度 0 上增加一个新的维度,原始的 (2, 3) 形状变成 (3, 2, 3)
  • stacked 的第 0 维度有 3 个元素,对应原来的 abc 张量。

沿 dim=1 堆叠

stacked = torch.stack([a, b, c], dim=1)
print(stacked.shape)  # torch.Size([2, 3, 3])
  • 新的维度插入到原第 1 维的位置。
  • stacked 的第 1 维度有 3 个元素,对应原来的 abc 张量。

沿 dim=2 堆叠

stacked = torch.stack([a, b, c], dim=2)
print(stacked.shape)  # torch.Size([2, 3, 3])
  • 新的维度插入到原第 2 维的位置,形状变为 (2, 3, 3)

torch.stack 的形状变化总结

假设堆叠前的每个张量形状是 (A, B, C),在 dim=0dim=1 和 dim=2 堆叠后的形状分别为:

  • dim=0(N, A, B, C)
  • dim=1(A, N, B, C)
  • dim=2(A, B, N, C)

其中 N 是堆叠的张量数量。

和torch.cat函数的区别:

cat:在指定维度拼接多个张量。不增加维度。

c1 = torch.tensor([[1, 2], [3, 4]])
c2 = torch.tensor([[5, 6], [7, 8]])
c_cat = torch.cat([c1, c2], dim=0)  # shape (4, 2)


http://www.kler.cn/a/398235.html

相关文章:

  • 【项目开发】理解SSL延迟:为何HTTPS比HTTP慢?
  • cocosCreator视频web模式播放踩坑解决
  • 类和对象——拷贝构造函数,赋值运算符重载(C++)
  • SwanLab安装教程
  • Springboot集成ElasticSearch实现minio文件内容全文检索
  • MongoDB分布式集群搭建----副本集----PSS/PSA
  • 记录大学Linux运维上机考试题目和流程
  • 使用Python实现对接Hadoop集群(通过Hive)并提供API接口
  • STM32F103移植FreeRTOS
  • Scala-字符串(拼接、printf格式化输出等)-用法详解
  • Spring Boot编程训练系统:开发与部署
  • SpringBoot 创建对象常见的几种方式
  • UEFI学习(五)——启动框架
  • web-02
  • DB-GPT系列(六):数据Agent开发part1-光速创建AWEL Agent应用
  • Java 全栈知识体系
  • Oracle Instant Client 23.5安装配置完整教程
  • django框架-settings.py文件的配置说明
  • 【C语言】前端未来
  • 公开一下我的「个人学习视频」!
  • 【系统架构设计师】真题论文: 论基于 REST 服务的 Web 应用系统设计(包括解题思路和素材)
  • SQL面试题——日期交叉问题
  • PMP–一、二、三模、冲刺–分类–5.范围管理–技巧–引导
  • 三种网络模式固定IP
  • python关键字和内置函数有哪些?
  • AIGC学习笔记(5)——AI大模型开发工程师