当前位置: 首页 > article >正文

Pytorch中矩阵乘法使用及案例

六种矩阵乘法

torch中包含许多矩阵乘法,大致可以分为以下几种:

  • *:即a * b 按位相乘,要求ab的形状必须一致,支持广播操作

  • torch.matmul():最广泛的矩阵乘法

  • @:与torch.matmul()效果一样(等价),即torch.matmul(a, b) == a @ b

  • torch.dot():两个一维向量乘法,不支持广播

  • torch.mm():两个二维矩阵的乘法,不支持广播

  • torch.bmm():两个三维矩阵乘法(批次batch粒度),且两个矩阵必须是三维的,不支持广播操作

其中,torch.matmul()中包含torch.dot()torch.mm()torch.bmm()

代码验证

torch.dot()

a = torch.tensor([2, 3])
b = torch.tensor([2, 1])

## 下面四个函数的结果是一样的  结果都是7
a.dot(b)
torch.dot(a, b)
a @ b
torch.matmul(a, b)

输出结果:
在这里插入图片描述

torch.matmul()torch.dot()的主要区别就是,当两个向量(矩阵)的维度不一致时,torch.matmul()会进行广播,而torch.dot()会报错

*

对向量ab进行按位相乘

a = torch.tensor([2, 3])
b = torch.tensor([2, 1])

a * b  # [4, 3]

torch.mm()

用于二维矩阵的相乘——第一个向量的和第二个向量的必须相等

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)

## 下面三个输出结果是一样的
torch.mm(mat1, mat2)
mat1.matmul(mat2)
mat1 @ mat2

输出结果:
在这里插入图片描述

torch.matmul()torch.mm()的主要区别就是,当两个矩阵的维度不一致时,torch.matmul()会进行广播,而torch.mm()会报错

torch.bmm()

应用于三维矩阵,要求:

  • 两个矩阵的第一个维度的大小必须相同
  • 必须满足第一个矩阵:(b × n × m),第二个矩阵:(b × m × p),即第一个矩阵的第三个维度必须和第二个矩阵的第二个维度相同
  • 输出大小:(b × n × p)

该函数相当于分别对每个batch进行二维矩阵相乘

bmat1 = torch.randn(2, 1, 4)
bmat2 = torch.randn(2, 4, 2)

## 下面三个输出是一样的
torch.bmm(bmat1, bmat2)
bmat1.matmul(bmat2)
bmat1 @ bmat2

输出结果:
在这里插入图片描述

换一种角度想,torch.bmm()就是相当于按照批次batch进行索引,然后将每个批次内的二维矩阵进行相乘

for i in range(bmat1.shape[0]):  # 索引出来批次bmat1.shape[0]
    temp =torch.mm(bmat1[i, :, :], bmat2[i, :, :])
    print(temp)

在这里插入图片描述

原文地址:https://blog.csdn.net/Iawfy22/article/details/146245349
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/584312.html

相关文章:

  • Manus(一种AI代理或自动化工具)与DeepSeek(一种强大的语言模型或AI能力)结合使用任务自动化和智能决策
  • 视频理解之Actionclip(论文宏观解读)
  • 小米路由器SSH下安装DDNS-GO
  • 【学习笔记】语言模型的发展历程
  • vue2安装scss
  • Java 断言(Assert)机制
  • 西门子S7-1200 PLC远程调试技术方案(巨控GRM532模块)
  • 【蓝桥杯速成】| 1.暴力解题
  • 音视频入门基础:RTP专题(17)——音频的SDP媒体描述
  • NVIDIA k8s-device-plugin源码分析与安装部署
  • 要登录的设备ip未知时的处理方法
  • 浅谈 JavaScript 对象:属性、方法与创建模式
  • MySql学习_基础Sql语句
  • 记录一个SQL自动执行的html页面
  • Mistral OCR:重新定义文档理解的下一代OCR技术
  • 课堂练习 1:配置虚拟主机
  • TCP协议支持全双工原因TCP发送接收数据是生产者消费者模型
  • JAVA-Thread类实现多线程
  • 【NLP 33、实践 ⑦ 基于Triple Loss作表示型文本匹配】
  • 数字化新零售与 AI 大模型,如何重塑大健康赛道?​