当前位置: 首页 > article >正文

PyTorch维度操作的函数介绍

在 PyTorch 中,操作张量的维度是常见的需求,特别是在处理多维数据时。PyTorch 提供了一系列函数来操作张量的维度,包括改变维度顺序、添加或删除维度、扩展维度等。下面是一些常用的维度操作函数及其示例代码。

1. view()

  • 作用:重新调整张量的形状(维度),但不改变其数据内容。view() 是基于张量的原始内存布局进行操作的,要求重新调整的形状能与原始数据兼容。
  • 示例
import torch

# 创建一个形状为 [2, 3, 4] 的张量
tensor = torch.randn(2, 3, 4)

# 调整为形状为 [6, 4] 的张量
reshaped = tensor.view(6, 4)
print(reshaped.shape)  # 输出: torch.Size([6, 4])

2. permute()

  • 作用:重新排列张量的维度顺序。
  • 示例
import torch

# 创建一个形状为 [2, 3, 4] 的张量
tensor = torch.randn(2, 3, 4)

# 交换第一个维度和第二个维度,得到形状为 [3, 2, 4] 的张量
permuted = tensor.permute(1, 0, 2)
print(permuted.shape)  # 输出: torch.Size([3, 2, 4])

3. unsqueeze()

  • 作用:在指定位置插入一个大小为 1 的新维度。
  • 示例
import torch

# 创建一个形状为 [3, 4] 的张量
tensor = torch.randn(3, 4)

# 在第 0 维添加一个新维度,结果形状为 [1, 3, 4]
unsqueezed = tensor.unsqueeze(0)
print(unsqueezed.shape)  # 输出: torch.Size([1, 3, 4])

4. squeeze()

  • 作用:移除张量中所有大小为 1 的维度。
  • 示例
import torch

# 创建一个形状为 [1, 3, 1, 4] 的张量
tensor = torch.randn(1, 3, 1, 4)

# 移除所有大小为 1 的维度,结果形状为 [3, 4]
squeezed = tensor.squeeze()
print(squeezed.shape)  # 输出: torch.Size([3, 4])

5. transpose()

  • 作用:交换张量的两个指定维度。
  • 示例
import torch

# 创建一个形状为 [2, 3, 4] 的张量
tensor = torch.randn(2, 3, 4)

# 交换第 1 维和第 2 维,结果形状为 [2, 4, 3]
transposed = tensor.transpose(1, 2)
print(transposed.shape)  # 输出: torch.Size([2, 4, 3])

6. expand()

  • 作用:将张量的某些维度扩展为更大的尺寸,不会复制数据,而是通过广播机制扩展。
  • 示例
import torch

# 创建一个形状为 [2, 1, 4] 的张量
tensor = torch.randn(2, 1, 4)

# 扩展第 1 维到大小为 3,结果形状为 [2, 3, 4]
expanded = tensor.expand(2, 3, 4)
print(expanded.shape)  # 输出: torch.Size([2, 3, 4])

7. repeat()

  • 作用:沿着指定的维度重复张量的元素。
  • 示例
import torch

# 创建一个形状为 [2, 3] 的张量
tensor = torch.randn(2, 3)

# 沿着第 0 维和第 1 维分别重复 2 次和 3 次,结果形状为 [4, 9]
repeated = tensor.repeat(2, 3)
print(repeated.shape)  # 输出: torch.Size([4, 9])

8. cat()

  • 作用:在指定维度上连接多个张量。
  • 示例
import torch

# 创建两个形状为 [2, 3] 的张量
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

# 在第 0 维连接,结果形状为 [4, 3]
concatenated = torch.cat([tensor1, tensor2], dim=0)
print(concatenated.shape)  # 输出: torch.Size([4, 3])

9. stack()

  • 作用:在新的维度上堆叠多个张量。
  • 示例
import torch

# 创建两个形状为 [2, 3] 的张量
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

# 在新的第 0 维堆叠,结果形状为 [2, 2, 3]
stacked = torch.stack([tensor1, tensor2], dim=0)
print(stacked.shape)  # 输出: torch.Size([2, 2, 3])

总结

PyTorch 提供了丰富的维度操作函数,使得张量的操作非常灵活。在处理多维数据时,合理使用这些函数可以极大地简化代码,并提高数据处理的效率。


http://www.kler.cn/a/297497.html

相关文章:

  • linux高级学习12
  • 运维学习————Zabbix监控框架(1)
  • 高级算法设计与分析 学习笔记3 哈希表
  • LaTeX中算法环境横线/宽度调整(Algorithm)
  • 收银系统源码-收银台(exe、apk安装包)自由灵活操作简单!
  • 【阿雄不会写代码】全国职业院校技能大赛GZ036第五套
  • HTTP1.0 到 HTTP3.0 的优化
  • 【网络安全 | 渗透工具】IIS 短文件名枚举工具—shortscan安装使用教程
  • @Transactional 参数详解
  • Charles - 夜神模拟器证书安装App抓包-charles监控手机出现unknown 已解决
  • 子网ip和ip地址一样吗?子网ip地址怎么算
  • Google AI 概述——喜欢的三点和不喜欢的两点
  • 使用Python海龟绘图画出奥运五环图
  • Android消息类型及事件分发流程
  • 99.WEB渗透测试-信息收集-网络空间搜索引擎shodan(1)
  • 神经网络的线性部分和非线性部分
  • 漫谈设计模式 [2]:工厂方法模式
  • 动手学深度学习(pytorch)学习记录26-卷积神经网路(LeNet)[学习记录]
  • 基于云函数的自习室预约微信小程序+LW示例参考
  • 560.和为k的子数组