当前位置: 首页 > article >正文

深度学习02-pytorch-07-张量的拼接操作

在 PyTorch 中,张量的拼接操作主要通过 torch.cat()torch.stack() 两个函数来完成。拼接操作允许你将多个张量沿着指定的维度连接在一起,构建更大的张量。以下是详细解释和举例说明:

1. torch.cat()

功能: 沿着指定的维度连接(拼接)多个张量。torch.cat() 是最常用的拼接函数,它不会增加新的维度,只是在指定维度上将张量的值连接在一起。

语法:

torch.cat(tensors, dim=0)
  • tensors: 需要拼接的张量列表。

  • dim: 沿着哪一个维度进行拼接。

示例:

import torch
# 创建两个形状相同的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])
​
# 沿着第0个维度拼接
z0 = torch.cat((x, y), dim=0)
print(z0)

输出:

tensor([[ 1, 2, 3],
      [ 4, 5, 6],
      [ 7, 8, 9],
      [10, 11, 12]])

在这个例子中,xy 沿着第0维拼接(相当于竖直方向连接)。

# 沿着第1个维度拼接
z1 = torch.cat((x, y), dim=1)
print(z1)

输出:

tensor([[ 1, 2, 3, 7, 8, 9],
      [ 4, 5, 6, 10, 11, 12]])

在这个例子中,xy 沿着第1维拼接(相当于水平方向连接)。

2. torch.stack()

功能: 沿着新维度拼接多个张量。与 torch.cat() 不同,torch.stack() 会在指定维度插入一个新的维度,并将张量叠加在该维度上。

语法:

torch.stack(tensors, dim=0)
  • tensors: 需要叠加的张量列表。

  • dim: 在哪一个维度上插入新的维度。

示例:

import torch
# 创建两个相同形状的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])
​
# 沿着新维度叠加张量
z = torch.stack((x, y), dim=0)
print(z)

输出:

tensor([[[ 1, 2, 3],
        [ 4, 5, 6]],
      [[ 7, 8, 9],
        [10, 11, 12]]])

在这个例子中,torch.stack() 插入了一个新的维度,最终的形状是 (2, 2, 3)

# 沿着第1个维度叠加
z1 = torch.stack((x, y), dim=1)
print(z1)

输出:

tensor([[[ 1, 2, 3],
        [ 7, 8, 9]],
      [[ 4, 5, 6],
        [10, 11, 12]]])

在这个例子中,新的维度插入到了第1维,最终的形状是 (2, 2, 3)

3. torch.chunk()

功能: 将一个张量沿着指定的维度分割成若干个小张量。 语法:

torch.chunk(tensor, chunks, dim=0)
  • tensor: 需要被分割的张量。

  • chunks: 分割成多少个张量。

  • dim: 沿着哪一个维度分割。

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 沿着第0维分割为3块
chunks = torch.chunk(x, 3, dim=0)
for chunk in chunks:
   print(chunk)

输出:

tensor([[1, 2, 3]])
tensor([[4, 5, 6]])
tensor([[7, 8, 9]])

4. torch.split()

功能: 与 torch.chunk() 类似,但它允许指定每个子张量的大小。

语法:

torch.split(tensor, split_size_or_sections, dim=0)
  • tensor: 要分割的张量。

  • split_size_or_sections: 每个子张量的大小,或按指定的切片。

  • dim: 沿着哪一个维度分割。

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 将张量分割为每块大小为2和1
splits = torch.split(x, [2, 1], dim=0)
for split in splits:
   print(split)

输出:

tensor([[1, 2, 3],
      [4, 5, 6]])
tensor([[7, 8, 9]])

5. torch.unbind()

功能: 沿着指定维度将张量解开为多个子张量。 语法:

torch.unbind(tensor, dim=0)
  • tensor: 要解开的张量。

  • dim: 沿着哪个维度解开。

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 沿着第0维度解开
unbinded = torch.unbind(x, dim=0)
for t in unbinded:
   print(t)

输出:

tensor([1, 2, 3])
tensor([4, 5, 6])

在这个例子中,张量 x 沿着第0维被解开为两个子张量。

总结

  • torch.cat() 是最常用的拼接方法,用于沿着指定维度拼接多个张量。

  • torch.stack() 可以插入新的维度,叠加多个张量。

  • torch.chunk()torch.split() 用于将张量分割成多个子张量。

  • torch.unbind() 用于沿指定维度解开张量为多个张量。

这些操作允许你灵活地操作张量的维度,方便进行数据预处理和模型设计。


http://www.kler.cn/a/314333.html

相关文章:

  • python 2小时学会八股文-数据结构
  • Servlet入门 Servlet生命周期 Servlet体系结构
  • HTTP 客户端怎么向 Spring Cloud Sleuth 传输跟踪 ID
  • 新版 idea 编写 idea 插件时,启动出现 ClassNotFound
  • Python提取PDF和DOCX中的文本、图片和表格
  • 深入理解接口测试:实用指南与最佳实践5.0(三)
  • 剖析Spark Shuffle原理(图文详解)
  • go 以太坊代币查余额
  • Python | Leetcode Python题解之第424题替换后的最长重复字符
  • 是德科技Keysight N4433D ECal模块 26.5GHz 4端口3.5毫米
  • 在python爬虫中xpath方式提取lxml.etree._ElementUnicodeResult转化为字符串str类型
  • RAG+Agent人工智能平台:RAGflow实现GraphRA知识库问答,打造极致多模态问答与AI编排流体验
  • 演示jvm锁存在的问题
  • Java集合(三)
  • Centos7安装chrome的问题
  • WebApi开发中依赖注入和RESTful 详解
  • OceanBase 的并发简述笔记
  • Navicate 链接Oracle 提示 Oracle Library is not loaded ,账号密码都正确地址端口也对
  • 【变化检测】基于ChangeStar建筑物(LEVIR-CD)变化检测实战及ONNX推理
  • php变量赋值javascipt变量
  • 13.面试算法-字符串常见算法题(二)
  • 【论文阅读】3D Diffuser Actor: Policy Diffusion with 3D Scene Representations
  • 人工智能与机器学习原理精解【25】
  • 【电路笔记】-运算放大器积分器
  • 数模方法论-整数规划
  • Python类及元类的创建流程