【PyTorch】6.张量形状操作:在深度学习的 “魔方” 里,玩转张量形状
目录
1. reshape 函数的用法
2. transpose 和 permute 函数的使用
4. squeeze 和 unsqueeze 函数的用法
5. 小节
个人主页:Icomi
专栏地址:PyTorch入门
在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络与人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。
咱们已经学习了张量的花式索引操作,它就像一把精巧的工具,让我们能够在数据的 “宝藏库” 里精准地提取和修改信息。我们接下来要学习—— 掌握对张量形状的操作。
想象一下,我们即将搭建的网络模型就像一座宏伟而复杂的建筑,而数据则是构建这座建筑的基石。这些数据在我们的深度学习世界里,都是以张量的形式存在。在这座 “建筑” 中,不同的楼层(网络层)有着不同的功能和设计,它们之间的数据传递和运算就如同建筑中不同楼层之间的物资运输和协作。
每一层网络对数据的处理方式都不尽相同,这就导致数据在网络层与层之间流动时,会以不同的形状(shape)进行表现和运算。比如说,有的层可能接收的是二维的张量数据,经过处理后输出一个三维的张量,就像把方形的积木经过加工变成了一个立体的模型。
如果我们不掌握对张量形状的操作,就好比一个建筑工人不熟悉不同建筑材料的尺寸和拼接方式,那么在搭建这座 “网络建筑” 时,各层之间的数据连接就会出现问题,就像积木无法正确拼接,最终导致整个建筑摇摇欲坠。
为了能够更好地处理网络各层之间的数据连接,顺利搭建出稳固而强大的网络模型,掌握对张量形状的操作就显得尤为重要。接下来,我们就一同深入学习如何巧妙地调整和管理张量的形状,让我们在深度学习的建筑之路上稳步前行。
1. reshape 函数的用法
reshape 函数可以在保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状,在后面的神经网络学习时,会经常使用该函数来调节数据的形状,以适配不同网络层之间的数据传递。
import torch
import numpy as np
def tensor_shape_operations():
# 创建一个二维张量
tensor = torch.tensor([[10, 20, 30], [40, 50, 60]])
# 1. 使用 shape 属性或者 size 方法都可以获得张量的形状
print(f"使用 shape 属性获取的形状: {tensor.shape},第 0 维大小: {tensor.shape[0]},第 1 维大小: {tensor.shape[1]}")
print(f"使用 size 方法获取的形状: {tensor.size()},第 0 维大小: {tensor.size(0)},第 1 维大小: {tensor.size(1)}")
# 2. 使用 reshape 函数修改张量形状
reshaped_tensor = tensor.reshape(1, 6)
print(f"修改形状后的张量形状: {reshaped_tensor.shape}")
if __name__ == '__main__':
tensor_shape_operations()
2. transpose 和 permute 函数的使用
transpose 函数可以实现交换张量形状的指定维度, 例如: 一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4 进行交换, 将张量的形状变为 (2, 4, 3)
permute 函数可以一次交换更多的维度。
import torch
import numpy as np
def test():
data = torch.tensor(np.random.randint(0, 10, [3, 4, 5]))
print('data shape:', data.size())
# 1. 交换1和2维度
new_data = torch.transpose(data, 1, 2)
print('data shape:', new_data.size())
# 2. 将 data 的形状修改为 (4, 5, 3)
new_data = torch.transpose(data, 0, 1)
new_data = torch.transpose(new_data, 1, 2)
print('new_data shape:', new_data.size())
# 3. 使用 permute 函数将形状修改为 (4, 5, 3)
new_data = torch.permute(data, [1, 2, 0])
print('new_data shape:', new_data.size())
if __name__ == '__main__':
test()
4. squeeze 和 unsqueeze 函数的用法
squeeze 函数用删除 shape 为 1 的维度,unsqueeze 在每个维度添加 1, 以增加数据的形状
import torch
import numpy as np
def test():
data = torch.tensor(np.random.randint(0, 10, [1, 3, 1, 5]))
print('data shape:', data.size())
# 1. 去掉值为1的维度
new_data = data.squeeze()
print('new_data shape:', new_data.size()) # torch.Size([3, 5])
# 2. 去掉指定位置为1的维度,注意: 如果指定位置不是1则不删除
new_data = data.squeeze(2)
print('new_data shape:', new_data.size()) # torch.Size([3, 5])
# 3. 在2维度增加一个维度
new_data = data.unsqueeze(-1)
print('new_data shape:', new_data.size()) # torch.Size([3, 1, 5, 1])
if __name__ == '__main__':
test()
5. 小节
本小节我们学习了经常使用的关于张量形状的操作,我们用到的主要函数有:
- reshape 函数可以在保证张量数据不变的前提下改变数据的维度.
- transpose 函数可以实现交换张量形状的指定维度, permute 可以一次交换更多的维度.
- view 函数也可以用于修改张量的形状, 但是它要求被转换的张量内存必须连续,所以一般配合 contiguous 函数使用.
- squeeze 和 unsqueeze 函数可以用来增加或者减少维度.