numpy中的transpose()和pytorch中的permute()
它们都用于重新排列张量的维度,但在实现细节和使用方式上有所不同
numpy.transpose
numpy.transpose 函数用于重新排列数组的维度。它接受一个元组作为参数,表示新的维度顺序。
numpy.transpose(a, axes=None)
a:输入的数组。
axes:一个整数元组,表示新的维度顺序。如果 None,则默认反转所有维度。
- 示例
import numpy as np
# 创建一个形状为 (2, 3, 4) 的数组
x = np.random.randn(2, 3, 4)
# 重新排列维度为 (2, 0, 1)
y = np.transpose(x, (2, 0, 1))
print(x.shape) # 输出: (2, 3, 4)
print(y.shape) # 输出: (4, 2, 3)
在这个例子中,np.transpose(x, (2, 0, 1)) 将数组 x 的维度重新排列为 (2, 0, 1),结果数组 y 的形状变为 (4, 2, 3)。
torch.permute
torch.permute 函数用于重新排列张量的维度。它接受一个元组作为参数,表示新的维度顺序。
torch.permute(input, dims)
input:输入的张量。
dims:一个整数元组,表示新的维度顺序。
- 示例
import torch
# 创建一个形状为 (2, 3, 4) 的张量
x = torch.randn(2, 3, 4)
# 重新排列维度为 (2, 0, 1)
y = torch.permute(x, (2, 0, 1))
print(x.shape) # 输出: torch.Size([2, 3, 4])
print(y.shape) # 输出: torch.Size([4, 2, 3])
在这个例子中,torch.permute(x, (2, 0, 1)) 将张量 x 的维度重新排列为 (2, 0, 1),结果张量 y 的形状变为 (4, 2, 3)。
区别总结
-
功能相似:
numpy.transpose 和 torch.permute 都用于重新排列数组或张量的维度。 -
默认行为:
numpy.transpose:如果没有指定 axes 参数,默认反转所有维度。
torch.permute:必须显式指定新的维度顺序。 -
使用方式:
numpy.transpose:通常用于 numpy 数组。
torch.permute:通常用于 torch 张量。
总结
-
numpy.transpose 和 torch.permute 在功能上相似,都用于重新排列数组或张量的维度。
-
numpy.transpose 默认反转所有维度,而 torch.permute 必须显式指定新的维度顺序。