深度学习pytorch——Tensor维度变换(持续更新)
view()打平函数
需要注意的是打平之后的tensor是需要有物理意义的,根据需要进行打平,并且打平后总体的大小是不发生改变的。
并且一定要谨记打平会导致维度的丢失,造成数据污染,如果想要恢复到原来的数据形式,是需要靠人为记忆的。
现在给出一个tensor——a.shape=torch.Size([4, 1, 28, 28]),打平a.view(4,1*28*28),此时a.view(4,1*28*28).shape=torch.Size([4, 784])。
当然也可以向高维度:b.shape=torch.Size([4, 784]),打平b.view(4,28,28,1),此时b.view(4,28,28,1).shape=torch.Size([4, 28, 28, 1])
unsqueeze()维度增加
当使用unsqueeze()方法时,此时概念会发生改变,会为数据增加一个组别,这个组别的含义由自己定义。
语法:unsqueeze(index) 如果index为正,则在索引之前加入;如果index为负,则在索引之后加入
代码演示:
# a.shape : torch.Size([4, 1, 28, 28])
# index 为正
print(a.unsqueeze(0).shape)
# torch.Size([1, 4, 1, 28, 28])
print(a.unsqueeze(3).shape)
# torch.Size([4, 1, 28, 1, 28])
# index 为负
print(a.unsqueeze(-1).shape)
# torch.Size([4, 1, 28, 28, 1])
print(a.unsqueeze(-2).shape)
# torch.Size([4, 1, 28, 1, 28])
# 注意不要超出索引范围,否则会报错
# 增加组别具体在数据上的表现
b = torch.tensor([1.2,2.3]) # 此时是一个dim为1,size为2的tensor
print(b.unsqueeze(-1)) # 是在最里层添加了一个维度
# tensor([[1.2000],
# [2.3000]])
print(b.unsqueeze(0)) # 是在最外层添加了一个维度
# tensor([[1.2000, 2.3000]])
来个小例子:
# for example
# bias相当于给每个channel上的像素增加了一个偏置
b = torch.rand(32)
f = torch.rand(4,32,14,14)
# 现在我们要实现b+f,由于二者维度不同,不能操作(每个维度对应的size也要相同)
b = b.unsqueeze(1) # torch.Size([32, 1])
print(b.shape)
b = b.unsqueeze(2) # torch.Size([32, 1, 1])
print(b.shape)
b = b.unsqueeze(0) # torch.Size([1, 32, 1, 1])
print(b.shape)
squeeze()维度减少
语法:squeeze(index) 如果index不填写,就是将所有size都为1的都去除;index就是去除对应的维度,但是只有size=1的才能被去除。
代码演示:
# b.shape = torch.Size([1, 32, 1, 1])
print(b.squeeze().shape) # torch.Size([32]) 如果不添加任何参数,就是将所有size=1的都去除
print(b.squeeze(0).shape) # torch.Size([32, 1, 1])
print(b.squeeze(1).shape) # torch.Size([1, 32, 1, 1]) 只有size=1的才能被去除
expand()
条件:维度一致,并且只有size=1的才能扩张。
使用蓝色线画的数据必须保持一致;如果参数为-1,则意味着size保持不变。
repeat()
repeat()复制内存数据,括号内参数是copy次数。
代码演示:
print(b.repeat(4,32,1,1).shape) # torch.Size([4, 1024, 1, 1])
转置
1、.t 矩阵转置,只适用于矩阵
2、transpose(dim1,dim2)转置
语法:交换dim1,dim2两个维度。注意transpose()方法会将数据变得不连续,所以通常需要借助于
contiguous()方法,用于将数据变得连续。
数据维度顺序必须和存储顺序一致。(说实话这一句我不太懂,然后我就去问了一下chatgpt)答案:
"数据维度顺序必须和存储顺序一致"是指在使用PyTorch进行数据处理和存储时,数据的维度顺序必须与存储的顺序一致。如果数据的维度顺序与存储的顺序不一致,可能会导致数据处理错误或结果不准确。
例如,如果使用PyTorch创建一个张量(tensor)并在存储时按默认的规则进行存储,即按行优先顺序存储,那么在对该张量进行操作时,需要按照相同的维度顺序进行操作,否则可能会导致错误。
总之,这句话的意思是在PyTorch中,需要保证数据的维度顺序和存储顺序一致,以确保数据处理和存储的正确性。
代码演示:
# a.shape=torch.Size([2, 3, 5, 5])
a1 = a.transpose(1,3).contiguous().view(2,3*5*5).view(2,3,5,5)
a2 = a.transpose(1,3).contiguous().view(2,3*5*5).view(2,5,5,3).transpose(1,3)
print(a1.shape,a2.shape) # torch.Size([2, 3, 5, 5]) torch.Size([2, 3, 5, 5])
print(torch.all(torch.eq(a,a1))) # tensor(False)
print(torch.all(torch.eq(a,a2))) # tensor(True)
补充:其中all()方法是用来确定所有内容一致,eq()方法是用来比较数据一致。
说实话这个我也不是很懂,但是我去做了一下实验,将torch.eq(a,a2)和torch.eq(a,a1)都打印了出来,发现这是一个shape为torch.Size([2, 3, 5, 5])的张量,并且里面的数据都是ture或者false,然后我就明白了,原来eq()是用来比较对应的每个数据是否相同,all()是用来比较一个张量里面的所有值是否在相同。
permute()
个人认为这个方法非常强大,可以完成任意维度的交换。我们先来看一个使用transpose()方法进行维度交换:
# b.shape=torch.Size([4, 3, 28, 32])
print(b.transpose(1,3).shape) # torch.Size([4, 32, 28, 3])
print(b.transpose(1,3).transpose(1,2).shape) # torch.Size([4, 28, 32, 3])
再来看一下permute()方法:
# b.shape=torch.Size([4, 3, 28, 32])
print(b.permute(0,2,3,1).shape) # torch.Size([4, 28, 32, 3])
有没有感觉很强大。