squeeze()
squeeze
函数是 PyTorch 中的一个函数,用于从张量(Tensor)中去除所有长度为 1 的维度。这在处理神经网络模型的输出时非常有用,因为模型的输出可能包含一些不必要的单一维度,例如在批处理大小为 1 的情况下,输出可能会有一个额外的批次维度。
简单来说就是从张量(Tensor)中去除所有长度为 1 的维度。
实例:
import torch
x = torch.randn(1, 2, 1, 2, 2)
print(x.shape) # 输出: torch.Size([1, 2, 1, 2, 2])
y = torch.squeeze(x)
print(y.shape) # 输出: torch.Size([2, 2, 2])
z = torch.squeeze(x, dim=1)
print(z.shape) # 输出: torch.Size([1, 2, 1, 2, 2])
v= torch.squeeze(x, dim=0)
print(v.shape) # 输出: torch.Size([1, 2, 2, 2])
输出:
#x:
torch.Size([1, 2, 1, 2, 2])
#y:不指定维度就去除所有长度为 1 的维度
torch.Size([2, 2, 2])#z:指定维度长度不为1则无效
torch.Size([1, 2, 1, 2, 2])#v:指定维度长度为1生效
torch.Size([2, 1, 2, 2])