当前位置: 首页 > article >正文

unfold函数

文章目录

  • 1. 原理介绍
  • 2. pytorch源码验证:

1. 原理介绍

torch.unfold函数的作用是将卷积出来的元素提取后按列向量排列

  • 提取元素
    在这里插入图片描述
  • 按列生成矩阵
    在这里插入图片描述

2. pytorch源码验证:

  • pytorch:
import torch.nn as nn
import torch

if __name__ == "__main__":
    run_code = 0
    my_unfold = nn.Unfold(kernel_size=(2, 2))
    batch_size = 1
    in_channels = 2
    input_h = 3
    input_w = 4
    my_total = batch_size * in_channels * input_h * input_w
    my_shape = (batch_size, in_channels, input_w, input_h)
    my_matrix = torch.arange(my_total).reshape(my_shape).to(torch.float)
    my_output = my_unfold(my_matrix)
    kernel_size = (2,2)
    print(f"my_matrix.shape=\n{my_matrix.shape}")
    print(f"my_output.shape=\n{my_output.shape}")
    print(f"my_matrix=\n{my_matrix}")
    print(f"unfold_kernel={kernel_size}")
    print(f"my_output=\n{my_output}")
  • 结果:
my_matrix.shape=
torch.Size([1, 2, 4, 3])
my_output.shape=
torch.Size([1, 8, 6])
my_matrix=
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.],
          [ 9., 10., 11.]],

         [[12., 13., 14.],
          [15., 16., 17.],
          [18., 19., 20.],
          [21., 22., 23.]]]])
unfold_kernel=(2, 2)
my_output=
tensor([[[ 0.,  1.,  3.,  4.,  6.,  7.],
         [ 1.,  2.,  4.,  5.,  7.,  8.],
         [ 3.,  4.,  6.,  7.,  9., 10.],
         [ 4.,  5.,  7.,  8., 10., 11.],
         [12., 13., 15., 16., 18., 19.],
         [13., 14., 16., 17., 19., 20.],
         [15., 16., 18., 19., 21., 22.],
         [16., 17., 19., 20., 22., 23.]]])

http://www.kler.cn/a/508991.html

相关文章:

  • Navicat 17 功能简介 | 商业智能 BI
  • C语言进阶习题【1】指针和数组(3)——一维指针指向字符数组首元素地址
  • 从 0 开始实现一个 SpringBoot + Vue 项目
  • Transformer创新模型!Transformer+BO-SVR多变量回归预测,添加气泡图、散点密度图(Matlab)
  • 如何在服务器同一个端口下根据路径区分不同的应用
  • Chrome谷歌浏览器如何能恢复到之前的旧版本
  • 什么是长连接?Netty如何设置进行长连接?
  • Docker详解与部署微服务实战
  • Ansible深度解析:如何精准区分并选用command与shell模块
  • Ruby语言的数据库交互
  • Redis 设计与实现:深入理解高性能缓存数据库
  • 【逆境中绽放:万字回顾2024我在挑战中突破自我】
  • 数据结构——堆(介绍,堆的基本操作、堆排序)
  • 【MySQL】简单解析一条SQL更新语句的执行过程
  • Zookeeper 配置文件:核心参数优化与实操指南
  • csp22前2题
  • JWT(数据结构、认证流程、加密、解密过程)、对称加密和非对称加密
  • C++从入门到实战(一)C++入门基础
  • WPS数据分析000001
  • SparkSQL函数
  • ComfyUI 矩阵测试指南:用三种方法,速优项目效果
  • 适配器模式详解:解决接口不兼容问题的灵活设计模式
  • 如何修改React 项目版本
  • 21天学通C++——11多态(引入多态的目的)
  • 用户中心项目教程(二)---umi3的使用出现的错误
  • 通过idea创建的springmvc工程需要的配置