unfold 函数（torch.nn.Unfold）
文章目录
- 1. 原理介绍
- 2. pytorch源码验证:
1. 原理介绍
torch.nn.Unfold（函数形式为 torch.nn.functional.unfold，注意不是 Tensor.unfold）的作用是：把卷积核在输入上滑动时每个窗口覆盖到的元素提取出来，展平成列向量，再把所有列向量按列排列成矩阵（即 im2col 操作）。
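- 提取元素：按 kernel_size=(kh, kw) 在每个通道上滑动窗口（默认 stride=1、padding=0、dilation=1），取出每个窗口内的 kh×kw 个元素；
- 按列生成矩阵：同一窗口位置上所有 C 个通道的元素依次拼成一个长度为 C×kh×kw 的列向量，L 个窗口位置的列向量按列排列，输出形状为 (N, C×kh×kw, L)。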
2. pytorch源码验证:
- pytorch:
import torch.nn as nn
import torch
def unfold_demo(batch_size=1, in_channels=2, input_h=4, input_w=3, kernel_size=(2, 2)):
    """Build an arange-filled NCHW tensor and apply ``nn.Unfold`` to it.

    Args:
        batch_size: N dimension of the input tensor.
        in_channels: C dimension of the input tensor.
        input_h: H (number of rows) of each input plane.
        input_w: W (number of columns) of each input plane.
        kernel_size: (kh, kw) sliding-window size passed to ``nn.Unfold``.

    Returns:
        A tuple ``(my_matrix, my_output)``: the float input of shape
        ``(batch_size, in_channels, input_h, input_w)`` holding
        ``0, 1, 2, ...`` in row-major order, and the unfolded tensor of shape
        ``(batch_size, in_channels * kh * kw, L)``. With ``nn.Unfold``'s
        defaults (stride 1, padding 0, dilation 1), L is
        ``(input_h - kh + 1) * (input_w - kw + 1)``.
    """
    my_unfold = nn.Unfold(kernel_size=kernel_size)
    my_total = batch_size * in_channels * input_h * input_w
    # PyTorch image tensors are NCHW, so height must precede width here.
    my_shape = (batch_size, in_channels, input_h, input_w)
    my_matrix = torch.arange(my_total).reshape(my_shape).to(torch.float)
    my_output = my_unfold(my_matrix)
    return my_matrix, my_output


if __name__ == "__main__":
    # Single source of truth for the window size: used both to build the
    # Unfold module and in the printout below.
    kernel_size = (2, 2)
    my_matrix, my_output = unfold_demo(kernel_size=kernel_size)
    print(f"my_matrix.shape=\n{my_matrix.shape}")
    print(f"my_output.shape=\n{my_output.shape}")
    print(f"my_matrix=\n{my_matrix}")
    print(f"unfold_kernel={kernel_size}")
    print(f"my_output=\n{my_output}")
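- 结果（输入形状 (1, 2, 4, 3) 即 H=4、W=3；输出形状 (1, 8, 6)，其中 8 = 2 个通道 × 2×2 核，6 = 3×2 个窗口位置；例如第 1 列 [0, 1, 3, 4, 12, 13, 15, 16] 就是左上角 2×2 窗口在两个通道上的元素）: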
my_matrix.shape=
torch.Size([1, 2, 4, 3])
my_output.shape=
torch.Size([1, 8, 6])
my_matrix=
tensor([[[[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]],
[[12., 13., 14.],
[15., 16., 17.],
[18., 19., 20.],
[21., 22., 23.]]]])
unfold_kernel=(2, 2)
my_output=
tensor([[[ 0., 1., 3., 4., 6., 7.],
[ 1., 2., 4., 5., 7., 8.],
[ 3., 4., 6., 7., 9., 10.],
[ 4., 5., 7., 8., 10., 11.],
[12., 13., 15., 16., 18., 19.],
[13., 14., 16., 17., 19., 20.],
[15., 16., 18., 19., 21., 22.],
[16., 17., 19., 20., 22., 23.]]])