【学习笔记】深入浅出详解Pytorch中的View, reshape, unfold,flatten等方法。
文章目录
- 一、写在前面
- 二、Reshape
- (一)用法
- (二)代码展示
- 三、Unfold
- (一)torch.unfold 的基本概念
- (二)torch.unfold 的工作原理
- (三) 示例代码
- (四)torch.unfold 的应用场景
- (五)注意事项
- (六)总结
- 四、View
- (一)用法
- (二)注意事项
- (三)其他方法
- 五、Flatten
- (一)torch.flatten 的基本概念
- (二)torch.flatten 的工作原理
- (三)示例代码
- 六、Permute
- (一)torch.permute 的基本概念
- (二)torch.permute 的工作原理
- (三) 示例代码
- (四) torch.permute 的应用场景
- 七、总结
一、写在前面
最近在解析transformer源码的时候突然看到了unfold?我在想unfold是什么意思?为什么不用reshape,他们的底层逻辑有什么区别呢?于是便相对比一下他们之间的区别,便有了本篇博客,希望对大家有帮助!
二、Reshape
(一)用法
1. torch.reshape(input, shape)
输入是tensor和shape,其中原始shape和目标shape的元素数量要一致。
2. Tensor.reshape(shape) → Tensor
与上述用法一致,只不过这个是直接在tensor的基础上进行reshape。
reshpe是按照顺序进行重新排列组合的。
[16,2] 其实与 [4,2,2,2] 是一样的,只要最后一维的数字是一样,其实结果都是一样的。
(二)代码展示
三、Unfold
(一)torch.unfold 的基本概念
torch.unfold 的作用是将输入张量的某个维度展开为多个滑动窗口。每个窗口包含一个局部区域,这些区域可以用于后续的计算。
- 语法:
torch.Tensor.unfold(dimension, size, step)
- 参数:
- dimension:要展开的维度(整数)。
- size:滑动窗口的大小(整数)。
- step:滑动窗口的步长(整数)。
- 返回值:返回一个新的张量,其中指定维度的每个元素被展开为多个滑动窗口。
(二)torch.unfold 的工作原理
假设我们有一个形状为 (N, C, H, W) 的张量(例如图像数据),我们希望在高度维度(H)上提取滑动窗口。
- 输入张量:(N, C, H, W)
- 展开维度:dimension=2(即高度维度 H)
- 窗口大小:size=k(例如 k=3)
- 步长:step=s(例如 s=1)
torch.unfold 会将高度维度 H 展开为多个大小为 k 的滑动窗口,每个窗口之间间隔 s。
(三) 示例代码
- 示例 1:一维张量的展开
import torch
# 创建一个一维张量
x = torch.arange(10)
print("原始张量:", x)
# 使用 unfold 展开
unfolded = x.unfold(dimension=0, size=3, step=1)
print("展开后的张量:\n", unfolded)
- 输出:
原始张量: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
展开后的张量:
tensor([[0, 1, 2],
[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[4, 5, 6],
[5, 6, 7],
[6, 7, 8],
[7, 8, 9]])
- 示例 2:二维张量的展开(图像处理)
import torch
# 创建一个二维张量(模拟图像)
x = torch.arange(16).reshape(1, 1, 4, 4) # 形状为 (1, 1, 4, 4)
print("原始张量:\n", x)
# 使用 unfold 展开
unfolded_1 = x.unfold(dimension=2, size=3, step=1)
unfolded_2 = unfolded_1.unfold(dimension=3, size=3, step=1)
print("展开后的张量形状:", unfolded_1.shape)
print("展开后的张量:\n", unfolded_1)
print("展开后的张量形状:", unfolded_2.shape)
print("展开后的张量:\n", unfolded_2)
- 输出:
原始张量:
tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]]]])
展开后的张量形状: torch.Size([1, 1, 2, 4, 3])
展开后的张量:
tensor([[[[[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]],
[[ 4, 8, 12],
[ 5, 9, 13],
[ 6, 10, 14],
[ 7, 11, 15]]]]])
展开后的张量形状: torch.Size([1, 1, 2, 2, 3, 3])
展开后的张量:
tensor([[[[[[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]],
[[ 1, 2, 3],
[ 5, 6, 7],
[ 9, 10, 11]]],
[[[ 4, 5, 6],
[ 8, 9, 10],
[12, 13, 14]],
[[ 5, 6, 7],
[ 9, 10, 11],
[13, 14, 15]]]]]])
数组的运算主要看最后两维,倒数第二维代表行,倒数第一维代表列。
(四)torch.unfold 的应用场景
- 卷积操作
在卷积神经网络(CNN)中,卷积核通过滑动窗口的方式提取图像的局部特征。torch.unfold 可以用于手动实现卷积操作。
- 图像处理
在图像处理中,torch.unfold 可以用于提取图像的局部区域(例如提取图像的滑动窗口)。
- 时间序列分析
在时间序列分析中,torch.unfold 可以用于提取时间序列的滑动窗口,用于特征提取或模型训练。
(五)注意事项
- 维度选择:需要明确指定要展开的维度(dimension)。
- 窗口大小和步长:窗口大小(size)和步长(step)的选择会影响展开后的张量形状。
- 内存消耗:torch.unfold 可能会生成较大的张量,尤其是在高维数据上使用时,需要注意内存消耗。
(六)总结
- torch.unfold 的作用:从张量的某个维度提取滑动窗口。
- 常用参数:dimension(展开维度)、size(窗口大小)、step(步长)。
- 应用场景:卷积操作、图像处理、时间序列分析等。
- 注意事项:选择合适的维度、窗口大小和步长,避免内存消耗过大。
四、View
torch.view 用于返回一个与原始张量共享相同数据存储的新张量,但具有不同的形状。换句话说,view 只是改变了张量的视图(view),而不会复制数据。
(一)用法
- Tensor.view(*shape):
- 参数:
*shape:新的形状(可以是整数或元组)。
- 返回值:
返回一个新的张量,具有指定的形状,并与原始张量共享相同的数据存储。
(二)注意事项
- view 返回的张量与原始张量共享相同的数据存储。
- 如果原始张量的数据发生变化,view 返回的张量也会随之变化。
- view 要求张量的内存必须是连续的(即张量在内存中是连续存储的)。如果内存不连续,view 会抛出错误。
(三)其他方法
- view_as
- 作用:将当前张量转换为与另一个张量相同的形状。
- 语法:
torch.Tensor.view_as(other)
- 参数:
other:目标张量,当前张量的形状将被转换为与 other 相同的形状。
- 返回值:
返回一个新的张量,具有与 other 相同的形状,并与原始张量共享数据存储。
- 示例:
import torch
# 创建一个张量
x = torch.arange(12)
print("原始张量:", x)
# 创建目标张量
other = torch.empty(3, 4)
# 使用 view_as 改变形状
y = x.view_as(other)
print("改变形状后的张量:\n", y)
- 输出:
原始张量: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
改变形状后的张量:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
- view_as_real
- 作用:将复数张量转换为实数张量。
- 语法:
torch.Tensor.view_as_real()
- 返回值:
返回一个新的实数张量,形状为 (…, 2),其中最后一个维度包含复数的实部和虚部。
- 示例:
import torch
# 创建一个复数张量
x = torch.tensor([1 + 2j, 3 + 4j])
print("原始张量:", x)
# 使用 view_as_real 转换为实数张量
y = x.view_as_real()
print("转换后的张量:\n", y)
- 输出:
原始张量: tensor([1.+2.j, 3.+4.j])
转换后的张量:
tensor([[1., 2.],
[3., 4.]])
- view_as_complex
- 作用:将实数张量转换为复数张量。
- 语法:
torch.Tensor.view_as_complex()
- 返回值:
返回一个新的复数张量,形状为 (…,),其中最后一个维度被解释为复数的实部和虚部。
- 示例:
import torch
# 创建一个实数张量
x = torch.tensor([[1., 2.], [3., 4.]])
print("原始张量:\n", x)
# 使用 view_as_complex 转换为复数张量
y = x.view_as_complex()
print("转换后的张量:", y)
- 输出:
原始张量:
tensor([[1., 2.],
[3., 4.]])
转换后的张量: tensor([1.+2.j, 3.+4.j])
- view_as_strided
- 作用:返回一个具有指定步长和内存布局的张量视图。
- 语法:
torch.Tensor.view_as_strided(size, stride)
- 参数:
size:新的形状(元组)。
stride:新的步长(元组)。
- 返回值:
返回一个新的张量,具有指定的形状和步长,并与原始张量共享数据存储。
- 示例:
import torch
# 创建一个张量
x = torch.arange(9).view(3, 3)
print("原始张量:\n", x)
# 使用 view_as_strided 改变形状和步长
y = x.view_as_strided((2, 2), (1, 2))
print("改变形状和步长后的张量:\n", y)
- 输出:
原始张量:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
改变形状和步长后的张量:
tensor([[0, 2],
[1, 3]])
- view_as_real 和 view_as_complex 的应用场景
- 复数计算:
在涉及复数计算的任务中,view_as_real 和 view_as_complex 可以用于在复数和实数之间进行转换。
- 信号处理:
在信号处理中,复数张量常用于表示频域信号,view_as_real 和 view_as_complex 可以用于频域和时域之间的转换。
- view_as_strided 的应用场景
- 自定义内存布局:
在需要自定义内存布局的场景中,view_as_strided 可以用于创建具有特定步长和形状的张量视图。
- 高效内存访问:
在需要高效访问内存的场景中,view_as_strided 可以用于优化内存访问模式。
五、Flatten
(一)torch.flatten 的基本概念
torch.flatten 的作用是将输入张量的指定维度展平为一维。它可以展平整个张量,也可以只展平部分维度。
- 语法:
torch.flatten(input, start_dim=0, end_dim=-1)
-
参数:
- input:输入张量。
- start_dim:开始展平的维度(整数),默认为 0。
- end_dim:结束展平的维度(整数),默认为 -1。
- 返回值:返回一个新的张量,具有展平后的形状。
(二)torch.flatten 的工作原理
- torch.flatten 会将指定范围内的维度展平为一维。
- 如果 start_dim=0 且 end_dim=-1,则整个张量会被展平为一维。
- 如果只展平部分维度,则其他维度保持不变。
(三)示例代码
- 示例 1:展平整个张量
import torch
# 创建一个二维张量
x = torch.arange(12).view(3, 4)
print("原始张量:\n", x)
# 使用 flatten 展平整个张量
y = torch.flatten(x)
print("展平后的张量:", y)
- 输出:
原始张量:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
展平后的张量: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
- 示例 2:展平部分维度
import torch
# 创建一个三维张量
x = torch.arange(24).view(2, 3, 4)
print("原始张量:\n", x)
# 使用 flatten 展平第二个维度到第三个维度
y = torch.flatten(x, start_dim=1, end_dim=2)
print("展平后的张量:\n", y)
- 输出:
原始张量:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
展平后的张量:
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
- 示例 3:展平特定维度
import torch
# 创建一个四维张量
x = torch.arange(24).view(2, 2, 3, 2)
print("原始张量:\n", x)
# 使用 flatten 展平第三个维度
y = torch.flatten(x, start_dim=2, end_dim=2)
print("展平后的张量:\n", y)
- 输出:
原始张量:
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]]],
[[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]]])
展平后的张量:
tensor([[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]],
[[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]]])
六、Permute
(一)torch.permute 的基本概念
torch.permute 的作用是将输入张量的维度按照指定的顺序重新排列。它类似于 NumPy 中的 transpose,但更加灵活,可以同时对多个维度进行排列。
- 语法:
torch.Tensor.permute(*dims)
- 参数:
- *dims:新的维度顺序(元组或多个整数)。
- 返回值:返回一个新的张量,具有重新排列后的维度顺序,并与原始张量共享数据存储。
(二)torch.permute 的工作原理
torch.permute 会将输入张量的维度按照指定的顺序重新排列。
新的维度顺序必须与原始张量的维度数量相同,并且每个维度索引只能出现一次。
(三) 示例代码
- 示例 1:二维张量的转置
import torch
# 创建一个二维张量
x = torch.arange(12).view(3, 4)
print("原始张量:\n", x)
# 使用 permute 转置
y = x.permute(1, 0) # 将维度 0 和 1 交换
print("转置后的张量:\n", y)
- 输出:
原始张量:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
转置后的张量:
tensor([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]])
- 示例 2:三维张量的维度重排
import torch
# 创建一个三维张量
x = torch.arange(24).view(2, 3, 4)
print("原始张量:\n", x)
# 使用 permute 重排维度
y = x.permute(2, 0, 1) # 将维度 0, 1, 2 重排为 2, 0, 1
print("重排后的张量形状:", y.shape)
print("重排后的张量:\n", y)
- 输出:
原始张量:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
重排后的张量形状: torch.Size([4, 2, 3])
重排后的张量:
tensor([[[ 0, 4, 8],
[12, 16, 20]],
[[ 1, 5, 9],
[13, 17, 21]],
[[ 2, 6, 10],
[14, 18, 22]],
[[ 3, 7, 11],
[15, 19, 23]]])
- 示例 3:四维张量的维度重排
import torch
# 创建一个四维张量
x = torch.arange(24).view(2, 2, 3, 2)
print("原始张量:\n", x)
# 使用 permute 重排维度
y = x.permute(3, 1, 2, 0) # 将维度 0, 1, 2, 3 重排为 3, 1, 2, 0
print("重排后的张量形状:", y.shape)
print("重排后的张量:\n", y)
- 输出:
原始张量:
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]]],
[[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]]])
重排后的张量形状: torch.Size([2, 2, 3, 2])
重排后的张量:
tensor([[[[ 0, 12],
[ 2, 14],
[ 4, 16]],
[[ 6, 18],
[ 8, 20],
[10, 22]]],
[[[ 1, 13],
[ 3, 15],
[ 5, 17]],
[[ 7, 19],
[ 9, 21],
[11, 23]]]])
(四) torch.permute 的应用场景
- 图像处理
在图像处理中,torch.permute 可以用于调整图像的通道顺序。例如,将 (C, H, W) 的图像张量转换为 (H, W, C)。
- 深度学习模型输入
在深度学习中,模型的输入张量通常需要特定的维度顺序。torch.permute 可以用于调整输入张量的维度顺序。
- 数据预处理
在数据预处理中,torch.permute 可以用于调整数据的维度顺序,以便进行后续的计算或操作。
如果每一维度都有特定的含义,此时想改变维度的时候用permute。
七、总结
- reshape 与 view 几乎一致, 甚至可以说reshape可以代替view;
- reshape 与 permute 的区别在于,reshape是按照顺序重新进行排列组合,permute是按照维度进行重新排列组合,如果各个维度都有特定的意义,那么permute会更合适。
- unfold是按照某一维度滑动选取数据,新增加一维,新增加的维度的大小为滑动窗口的大小,原始维度会根据滑动窗口和step的大小而变化。
- flatten也是按照顺序进行展开,且展开的是某个范围的维度,不是特定的维度。
- 不管是几维数组,主要看最后两维,倒数第二维代表行,倒数第一维代表列。