深度学习02-pytorch-07-张量的拼接操作
在 PyTorch 中,张量的拼接操作主要通过 torch.cat()
和 torch.stack()
两个函数来完成。拼接操作允许你将多个张量沿着指定的维度连接在一起,构建更大的张量。以下是详细解释和举例说明:
1. torch.cat()
功能: 沿着指定的维度连接(拼接)多个张量。torch.cat()
是最常用的拼接函数,它不会增加新的维度,只是在指定维度上将张量的值连接在一起。
语法:
torch.cat(tensors, dim=0)
-
tensors
: 需要拼接的张量列表。 -
dim
: 沿着哪一个维度进行拼接。
示例:
import torch
# 创建两个形状相同的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 沿着第0个维度拼接
z0 = torch.cat((x, y), dim=0)
print(z0)
输出:
tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]])
在这个例子中,x
和 y
沿着第0维拼接(相当于竖直方向连接)。
# 沿着第1个维度拼接
z1 = torch.cat((x, y), dim=1)
print(z1)
输出:
tensor([[ 1, 2, 3, 7, 8, 9], [ 4, 5, 6, 10, 11, 12]])
在这个例子中,x
和 y
沿着第1维拼接(相当于水平方向连接)。
2. torch.stack()
功能: 沿着新维度拼接多个张量。与 torch.cat()
不同,torch.stack()
会在指定维度插入一个新的维度,并将张量叠加在该维度上。
语法:
torch.stack(tensors, dim=0)
-
tensors
: 需要叠加的张量列表。 -
dim
: 在哪一个维度上插入新的维度。
示例:
import torch
# 创建两个相同形状的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 沿着新维度叠加张量
z = torch.stack((x, y), dim=0)
print(z)
输出:
tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]])
在这个例子中,torch.stack()
插入了一个新的维度,最终的形状是 (2, 2, 3)
。
# 沿着第1个维度叠加
z1 = torch.stack((x, y), dim=1)
print(z1)
输出:
tensor([[[ 1, 2, 3], [ 7, 8, 9]], [[ 4, 5, 6], [10, 11, 12]]])
在这个例子中,新的维度插入到了第1维,最终的形状是 (2, 2, 3)
。
3. torch.chunk()
功能: 将一个张量沿着指定的维度分割成若干个小张量。 语法:
torch.chunk(tensor, chunks, dim=0)
-
tensor
: 需要被分割的张量。 -
chunks
: 分割成多少个张量。 -
dim
: 沿着哪一个维度分割。
示例:
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 沿着第0维分割为3块
chunks = torch.chunk(x, 3, dim=0)
for chunk in chunks:
print(chunk)
输出:
tensor([[1, 2, 3]]) tensor([[4, 5, 6]]) tensor([[7, 8, 9]])
4. torch.split()
功能: 与 torch.chunk()
类似,但它允许指定每个子张量的大小。
语法:
torch.split(tensor, split_size_or_sections, dim=0)
-
tensor
: 要分割的张量。 -
split_size_or_sections
: 每个子张量的大小,或按指定的切片。 -
dim
: 沿着哪一个维度分割。
示例:
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 将张量分割为每块大小为2和1
splits = torch.split(x, [2, 1], dim=0)
for split in splits:
print(split)
输出:
tensor([[1, 2, 3], [4, 5, 6]]) tensor([[7, 8, 9]])
5. torch.unbind()
功能: 沿着指定维度将张量解开为多个子张量。 语法:
torch.unbind(tensor, dim=0)
-
tensor
: 要解开的张量。 -
dim
: 沿着哪个维度解开。
示例:
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 沿着第0维度解开
unbinded = torch.unbind(x, dim=0)
for t in unbinded:
print(t)
输出:
tensor([1, 2, 3]) tensor([4, 5, 6])
在这个例子中,张量 x
沿着第0维被解开为两个子张量。
总结
-
torch.cat()
是最常用的拼接方法,用于沿着指定维度拼接多个张量。 -
torch.stack()
可以插入新的维度,叠加多个张量。 -
torch.chunk()
和torch.split()
用于将张量分割成多个子张量。 -
torch.unbind()
用于沿指定维度解开张量为多个张量。
这些操作允许你灵活地操作张量的维度,方便进行数据预处理和模型设计。