当前位置: 首页 > article >正文

手动计算conv1d 及pytorch源码

文章目录

  • 1. conv1d
  • 2. pytorch 源码

1. conv1d

conv1d的作用是进行一维的卷积计算,将卷积核沿着输入矩阵进行一维卷积,具体参考如下excel
通过网盘分享的文件:conv1d.xlsx
链接: https://pan.baidu.com/s/1WIM4Pp5nwa-uP67TMP-m8Q?pwd=uti7 提取码: uti7

2. pytorch 源码

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    # in_channels: int,
    # out_channels: int,
    # kernel_size: _size_1_t,
    # stride: _size_1_t = 1,
    # padding: Union[str, _size_1_t] = 0,
    # dilation: _size_1_t = 1,
    # groups: int = 1,
    # bias: bool = True,
    # padding_mode: str = 'zeros',  # TODO: refine this type
    # device = None,
    # dtype = None
    in_channels = 3
    out_channels = 4
    kernel_size = 2
    stride = 1
    my_weight_total = out_channels * in_channels * kernel_size
    my_weight = torch.arange(my_weight_total).reshape((out_channels, in_channels, kernel_size)).to(torch.float32)
    my_conv1d = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
    my_bias = torch.arange(out_channels).to(torch.float32)
    print(my_conv1d)
    print(my_conv1d.weight.shape)
    my_conv1d.weight = nn.Parameter(my_weight)
    my_conv1d.bias = nn.Parameter(my_bias)

    for p in my_conv1d.named_parameters():
        print(p)

    batch_size = 2
    seq_len = 5
    in_total = batch_size * in_channels * seq_len
    in_matrix = torch.arange(in_total).reshape(batch_size, in_channels, seq_len).to(torch.float32)
    out_matrix = my_conv1d(in_matrix)
    conv1d_weight = my_conv1d.weight
    print(f"conv1d_weight.shape=\n{conv1d_weight.shape}")
    print(f"conv1d_weight=\n{conv1d_weight}")
    print(f"in_matrix.shape=\n{in_matrix.shape}")
    print(f"in_matrix=\n{in_matrix}")
    print(f"out_matrix.shape=\n{out_matrix.shape}")
    print(f"out_matrix=\n{out_matrix}")
  • 结果:
Conv1d(3, 4, kernel_size=(2,), stride=(1,))
torch.Size([4, 3, 2])
('weight', Parameter containing:
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.],
         [16., 17.]],

        [[18., 19.],
         [20., 21.],
         [22., 23.]]], requires_grad=True))
('bias', Parameter containing:
tensor([0., 1., 2., 3.], requires_grad=True))
conv1d_weight.shape=
torch.Size([4, 3, 2])
conv1d_weight=
Parameter containing:
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.],
         [16., 17.]],

        [[18., 19.],
         [20., 21.],
         [22., 23.]]], requires_grad=True)
in_matrix.shape=
torch.Size([2, 3, 5])
in_matrix=
tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.]],

        [[15., 16., 17., 18., 19.],
         [20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.]]])
out_matrix.shape=
torch.Size([2, 4, 4])
out_matrix=
tensor([[[ 124.,  139.,  154.,  169.],
         [ 323.,  374.,  425.,  476.],
         [ 522.,  609.,  696.,  783.],
         [ 721.,  844.,  967., 1090.]],

        [[ 349.,  364.,  379.,  394.],
         [1088., 1139., 1190., 1241.],
         [1827., 1914., 2001., 2088.],
         [2566., 2689., 2812., 2935.]]], grad_fn=<ConvolutionBackward0>)

http://www.kler.cn/a/533703.html

相关文章:

  • 【现代深度学习技术】深度学习计算 | 延后初始化自定义层
  • 更换IP属地会影响网络连接速度吗
  • 新春贺岁,共赴AGI之旅
  • 鸿蒙Harmony-Refresh 容器组件
  • 【Envi遥感图像处理】009:envi5.6设置中文界面的方法
  • 【蓝桥杯】日志统计
  • 【Mybatis Plus】JSqlParser解析sql语句
  • 子集问题(LeetCode 78 90)
  • js-对象-Array数组
  • 机理模型与数据模型融合的方式
  • 深度探索未来的搜索引擎 —— DeepSeek
  • 请解释 Java 中的 IO 和 NIO 的区别,以及 NIO 如何实现多路复用?
  • 如何在页面中弹出菜单
  • 《2025,AI重塑世界进行时》
  • 【R语言】写入数据
  • 基于PostGIS的省域空间相邻检索实践
  • C语言程序设计P6-3【应用指针进行程序设计 | 第三节】——知识要点:指针与数组
  • 【大数据技术】搭建完全分布式高可用大数据集群(Scala+Spark)
  • LLM推理--vLLM解读
  • 代码讲解系列-CV(二)——卷积神经网络
  • 动态图推理问答算法
  • 动态规划练习八(01背包问题)
  • 用 Python 绘制爱心形状的简单教程
  • 企业百科和品牌百科创建技巧
  • 【CSS】谈谈你对BFC的理解
  • 开源数据分析工具 RapidMiner