当前位置: 首页 > article >正文

torch_bmm验算及代码测试

文章目录

  • 1. torch_bmm
  • 2. pytorch源码

1. torch_bmm

torch.bmm的作用是基于batch_size的矩阵乘法,torch.bmm的作用是对应batch位置的矩阵相乘,比如,

  • mat1的第1个位置和mat2的第1个位置进行矩阵相乘得到mat3的第1个位置
  • mat1的第2个位置和mat2的第2个位置进行矩阵相乘得到mat3的第2个位置
    在这里插入图片描述

2. pytorch源码

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    mat1_h = 3
    mat1_w = 4
    mat1_total = batch_size * mat1_w * mat1_h
    mat2_h = 4
    mat2_w = 5
    mat2_total = batch_size * mat2_w * mat2_h
    mat1 = torch.arange(mat1_total).reshape((batch_size, mat1_h, mat1_w))
    mat2 = torch.arange(mat2_total).reshape((batch_size, mat2_h, mat2_w))
    mat3 = torch.bmm(mat1, mat2)
    print(f"mat1=\n{mat1}")
    print(f"mat2=\n{mat2}")
    print(f"mat3=\n{mat3}")
  • 结果:
mat1=
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
mat2=
tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]]])
mat3=
tensor([[[  70,   76,   82,   88,   94],
         [ 190,  212,  234,  256,  278],
         [ 310,  348,  386,  424,  462]],

        [[1510, 1564, 1618, 1672, 1726],
         [1950, 2020, 2090, 2160, 2230],
         [2390, 2476, 2562, 2648, 2734]]])

http://www.kler.cn/a/540565.html

相关文章:

  • 【文档智能多模态】英伟达ECLAIR-端到端的文档布局提取,并集成阅读顺序方法
  • 蓝桥杯---数青蛙(leetcode第1419题)
  • 【网络安全】服务器安装Docker及拉取镜像教程
  • 只需三步!5分钟本地部署deep seek——MAC环境
  • 音频进阶学习十一——离散傅里叶级数DFS
  • 【03】 区块链分布式网络
  • 38.社区信息管理系统(基于springboothtml)
  • windows10 wsa 安卓子系统终结版
  • 网络协议课程笔记上
  • AUTOSAR 4.2.2版本中Dem 操作循环(Operation Cycle)的开启和关闭
  • Python实现GO鹅优化算法优化支持向量机SVM回归模型项目实战
  • JSON是什么
  • 从零到一:基于Rook构建云原生Ceph存储的全面指南(上)
  • The 2024 ICPC Asia East Continent Online Contest (II) (6/9/12)
  • JDK8 stream API用法汇总
  • STM32 RTC亚秒
  • 【高级架构师】多线程和高并发编程(三):锁(下)深入ReentrantReadWriteLock
  • Python——批量图片转PDF(GUI版本)
  • 2.10寒假作业
  • 反射:获取类中的成分、并对其进行操作
  • SpringCloud - Sentinel服务保护
  • 矩阵NFC碰一碰发视频的源码技术开发攻略,支持OEM
  • 【数据】Cassandra(列存储)
  • 小红书爬虫: 获取所需数据
  • JVM栈帧中|局部变量表、操作数栈、动态链接各自的任务是什么?
  • Java_多线程