15. 数据维度转换 -- torch.reshape
数据维度转换 torch.reshape()
1. 为什么要使用 reshape()
函数
- 对于不同的网络结构如:一维卷积核、二维卷积核等,对输入数据维度要求并不相同,
reshape()
函数提供了非常方便的数据维度转换功能 torch.reshape()
提供了数据维度转换功能,在使用对数据维度有一定限制的网络结构时,一定要注意维度问题!!
2. 维度问题实例
-
nn.Conv2d()
接受的数据的维度必须是 (N, C, H, W)四维 或者 (C, H, W)三维的Tensor数据,对于下面的样例,就是因为数据维度问题而报错torch.manual_seed(0) input = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # torch.Size([3, 3]) conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(2, 2),) output = conv(input) # RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [3, 3]
(N, C, H, W):
N表示多少数据数量,C表示数据的通道数,H表示数据的行数,W表示数据的列数 <