当前位置: 首页 > article >正文

pytorch torch.einsum函数介绍

torch.einsum 是 PyTorch 中一个强大且灵活的张量运算函数,基于爱因斯坦求和约定进行操作。它允许用户通过简单的字符串表达式来定义复杂的张量运算,代替显式的循环或多个矩阵乘法操作。

函数签名

torch.einsum(equation, *operands) → Tensor

参数

  • equation: 一个字符串,描述了张量间的操作关系。它使用爱因斯坦求和约定,用逗号分隔不同张量的索引,使用箭头(->)定义输出的形状。
    • 左侧部分是输入张量的维度索引,逗号分隔。
    • 右侧是输出张量的维度索引。如果没有提供输出维度,函数默认对所有不重复的索引进行求和。
  • *operands: 需要操作的张量。张量的维度必须与 equation 中的描述匹配。

爱因斯坦求和约定

爱因斯坦求和约定是一种简化张量运算的表示方式。它假设对所有重复的索引进行求和。例如:

  • 'ij,jk->ik' 表示矩阵乘法,其中 i 和 k 是保留下来的维度,j 是求和的维度。

示例

1. 矩阵乘法

矩阵乘法可以通过 torch.einsum 实现:

import torch

A = torch.tensor([[1, 2],
                  [3, 4]])

B = torch.tensor([[5, 6],
                  [7, 8]])

# 等同于 torch.matmul(A, B)
result = torch.einsum('ij,jk->ik', A, B)

print(result)

在这里:

  • A 的维度索引为 ij(行和列)。
  • B 的维度索引为 jk
  • 输出的维度索引为 ik,表示保留 A 的行和 B 的列。
2. 向量内积
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# 等同于 torch.dot(x, y)
result = torch.einsum('i,i->', x, y)

print(result)

这里的 'i,i->' 表示对 i 这个索引求和,即两个向量的内积。

3. 批次矩阵乘法

对批次的矩阵执行乘法:

A = torch.randn(3, 2, 4)  # 3 个 2x4 矩阵
B = torch.randn(3, 4, 5)  # 3 个 4x5 矩阵

# 执行批次矩阵乘法
result = torch.einsum('bij,bjk->bik', A, B)

print(result.shape)  # 输出 (3, 2, 5)
  • bij 表示 A 的维度,其中 b 是批次维度,i 是行,j 是列。
  • bjk 表示 B 的维度,其中 b 是批次维度,j 是行,k 是列。
  • 输出 bik 保留批次 b 维度和矩阵乘法后的行和列。
4. 广播加法
A = torch.randn(3, 4)
B = torch.randn(4)

# 广播加法
result = torch.einsum('ij,j->ij', A, B)

print(result.shape)  # 输出 (3, 4)

这里,B 的维度被广播扩展到与 A 的第二维度匹配。

总结

  • torch.einsum 适合各种线性代数操作,例如矩阵乘法、点积、转置、广播等。
  • 使用 einsum 可以使代码更简洁,并避免繁琐的手动张量操作。
  • 它可以处理多维张量,并通过显式地定义每个维度的映射关系,适用于很多高效的深度学习运算。

torch.einsum 是一个高效、简洁的工具,可以替代多个显式的 torch.matmul 或 torch.bmm 等函数,使代码更加直观和灵活。


http://www.kler.cn/a/298047.html

相关文章:

  • vue缓存用法
  • 信息学奥赛初赛天天练-87-NOIP2014普及组-完善程序-矩阵、子矩阵、最大子矩阵和、前缀和、打擂台求最大值
  • Lombok失效:报错 找不到符号 Springboot项目
  • 【加密社】3分钟快速制作一个爬虫?不懂编程也没关系
  • YOLOv8改进 | 模块缝合 | C2f 融合RFAConv和CBAM注意力机制 【二次融合 小白必备】
  • 【SAP-ABAP】JAVA通过SAP JCO(SAP.JAR)链接SAP需要注意哪些事项(SAP ROUTER连接报错)
  • 机器学习:对数据进行降维(PCA和SVD)
  • 逻辑一键导入导出,解决企业多环境数据迁移的难题
  • 【PyTorch】使用容器(Containers)进行网络层管理(Module)
  • (六十五)第 10 章 内部排序(希尔排序)
  • 三十五、Gin注册功能实战
  • VScode 的简单使用
  • 进程
  • TortoiseGit无法安装解决方案
  • 16 训练自己语言模型
  • javase笔记12----线程2
  • 深度学习笔记15_TensorFlow实现运动鞋品牌识别
  • Oracle表操作详解使用
  • CentOS 安装Squid代理
  • 用Postman调试是英文导致系统语言变成英文,SQL语句查询不出来对应的字段,出现SAP系统里面调试是有值的,但是外部调用是没有值的!