当前位置: 首页 > article >正文

PyTorch reshape函数介绍

torch.reshape 是 PyTorch 用于改变张量形状的函数之一。它不会改变张量的数据,而是重新组织其元素以适应新的形状。


reshape 的使用

torch.reshape(input, shape) → Tensor
  • input:输入张量。
  • shape:新形状,使用整数或 -1 指定各维度大小。
    • -1 表示自动推断该维度大小,使总元素数保持不变。
示例
import torch

# 创建一个形状为 (2, 3) 的张量
x = torch.arange(6).view(2, 3)

# 使用 reshape 改变形状为 (3, 2)
y = torch.reshape(x, (3, 2))

print(y)
# 输出:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

使用 -1 自动推断

z = torch.reshape(x, (-1, 2))
print(z)
# 输出:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

与其他张量形状改变函数的区别

1. view
  • 特点view 也用于改变张量形状,但它要求输入张量在内存中是连续的。
  • 限制:如果张量不是连续的(即非 contiguous),使用 view 会报错,需要先调用 contiguous 方法。
  • 示例
x = torch.arange(6).view(2, 3)
y = x.view(3, 2)  # 可以直接使用

x = x.T  # 转置操作使张量变为非连续
y = x.view(3, 2)  # 会报错
2. permute
  • 特点:用于交换张量的维度,而不是改变形状。
  • 用途:适用于维度重新排列。
x = torch.rand(2, 3, 4)
y = x.permute(1, 0, 2)  # 改变维度顺序
3. resize_
  • 特点:修改张量形状,可能破坏原始数据,慎用。
  • 用途:多用于临时调整张量形状,不推荐在计算中使用。
4. squeeze / unsqueeze
  • 特点
    • squeeze:移除长度为 1 的维度。
    • unsqueeze:添加长度为 1 的维度。
  • 示例
x = torch.rand(1, 3, 1, 4)
y = x.squeeze()  # 去掉长度为 1 的维度
z = x.unsqueeze(2)  # 在第 2 个位置添加一个长度为 1 的维度
5. flatten
  • 特点:将多维张量展平为一维张量,或在指定维度范围内展平。
  • 用途:简化张量为线性输入。
  • 示例
    x = torch.rand(2, 3, 4)
    y = torch.flatten(x)  # 展平为 1D
    z = torch.flatten(x, start_dim=1)  # 从第 1 维开始展平
    print(z.shape)  # torch.Size([2, 12])

    reshape 的优势

  • 灵活性:不需要张量是连续的。
  • 安全性:自动处理非连续张量(相比 view)。
  • 性能:通常不会引入额外开销,尤其在连续内存情况下。
reshape 与 view 的选择
  • 如果确定张量是连续的,可用 view 提高性能。
  • 如果不确定张量是否连续,使用 reshape 更安全。

以下函数在改变张量形状或维度时不会破坏原始数据:

  • reshape
  • view(前提是张量连续)
  • permute
  • transpose
  • squeeze / unsqueeze
  • flatten
  • contiguous

这些操作只会影响数据的组织形式或内存布局,而不会修改数据本身。

总结

  • reshape 是 PyTorch 中改变张量形状的通用函数,灵活且易用。
  • 与其他形状操作函数(如 viewpermutesqueeze 等)的主要区别在于适用场景和对张量内存布局的要求。


http://www.kler.cn/a/501672.html

相关文章:

  • 任务调度系统Quartz.net详解2-Scheduler、Calendar及Listener
  • 微信小程序中 隐藏scroll-view 滚动条 网页中隐藏滚动条
  • 一 rk3568 Android 11固件开发环境搭建 (docker)
  • OpenCV基础:视频的采集、读取与录制
  • iOS 逆向学习 - Inter-Process Communication:进程间通信
  • Mysql 性能优化:覆盖索引
  • 使用Cilium/eBPF实现大规模云原生网络和安全
  • MongoDB 删除集合
  • nginx增加新模块
  • Python orjson ujson有什么区别?
  • 【DevOps】Jenkins使用Pipeline构建java代码
  • AIGC是什么?怎么用?简单三步ToDesk云电脑快速用
  • 前端学习-焦点事件以及键盘事件与典型案例(二十五)
  • Node.js——http 模块(二)
  • (Arxiv-2023)LORA-FA:针对大型语言模型微调的内存高效低秩自适应
  • 软件系统安全逆向分析-混淆对抗
  • HTML + CSS:如何强制div内容保持一行?
  • 26个开源Agent开发框架调研总结(2)
  • 如何使用高性能内存数据库Redis
  • 基于异步IO的io_uring
  • 【论文阅读+复现】High-fidelity Person-centric Subject-to-Image Synthesis
  • HAMi + prometheus-k8s + grafana实现vgpu虚拟化监控
  • 【Spring Boot 应用开发】-01 初识
  • 夯实前端基础之CSS篇
  • Edge浏览器内置的截长图功能
  • 品牌账号矩阵如何打造?来抄作业