pytorch torch.tile用法
指定各维度分别重复多少次
tile
是 PyTorch 中用于重复张量的函数。它可以沿指定的维度重复张量的元素。以下是一个示例代码,展示 tile
的用法:
import torch
# 创建一个张量
weight_hh = torch.tensor([[1, 2], [3, 4]])
# 假设批量大小为3
bs = 3
# 使用 unsqueeze 在第0维度增加一个维度,然后使用 tile 沿第0维度重复 bs 次
w_hh_batch = weight_hh.unsqueeze(0).tile(bs, 1, 1)
print("原始张量:")
print(weight_hh)
print("增加维度并重复后的张量:")
print(w_hh_batch)
在这个示例中:
weight_hh
是一个形状为[2, 2]
的张量。weight_hh.unsqueeze(0)
在第0维度增加一个维度,使其形状变为[1, 2, 2]
。tile(bs, 1, 1)
沿第0维度重复bs
次(这里bs
为3),使其形状变为[3, 2, 2]
。
原始张量:
tensor([[1, 2],
[3, 4]])
增加维度并重复后的张量:
tensor([[[1, 2],
[3, 4]],
[[1, 2],
[3, 4]],
[[1, 2],
[3, 4]]])
这样,w_hh_batch
就是一个形状为 [3, 2, 2]
的张量,其中每个批次都包含原始的 weight_hh
张量