当前位置: 首页 > article >正文

【深度学习模型移植】用torch普通算子组合替代torch.einsum方法

     首先不得不佩服大模型的强大之处,在算法移植过程中遇到einsum算子在ONNX中不支持,因此需要使用普通算子替代。参考TensorRT - 使用torch普通算子组合替代torch.einsum爱因斯坦求和约定算子的一般性方法。可以写出简单的替换方法,但是该方法会导致训练时还是推理都很慢,并且会消耗大量显存,造成显存溢出的问题。。因此采用提问文心一言,没想到居然真的回答正确了。当然替换需要验证,不是全对的。
1.einsum(delta, A, ‘b l d_in, d_in n -> b l d_in n’) 的替换,以下两个方法均可以

deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
deltaA = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
deltaA = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)

2.einsum(x, C[:, i, :], ‘b d_in n, b n -> b d_in’),以下两个方法均可以

    
    y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
    y = (x*C[:, i, :].unsqueeze(dim=1)).sum(dim=2)
    y = torch.matmul(C[:, i, :], x.transpose(-1, -2)).squeeze(1)

3.einsum(delta, B, u, ‘b l d_in, b l n, b l d_in -> b l d_in n’),以下两个方法均可以

deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
deltaB_u1 = delta.unsqueeze(dim=3)*B.unsqueeze(dim=2)*u.unsqueeze(dim=3)

下述方法是提问文心一言的办法,注意需要将答案的结果和einsum的结果进行对比,采用np.testing.assert_allclose(deltaB_u.numpy(),deltaB_u1.numpy(),rtol=1e-05,atol=1e-05)和print(deltaA.equal(deltaA_manual))均可以。

import torch
import numpy as np
from einops import rearrange, repeat, einsum
# 给定的张量
delta = torch.ones([1, 3, 2])
A = torch.ones([2, 4])
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
deltaA1 = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
deltaA_manual = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)
np.testing.assert_allclose(deltaA.numpy(),deltaA1.numpy(),rtol=1e-05,atol=1e-05)

# 扩展 delta 的维度,以便它可以与 A 进行广播(broadcast)
# 这里我们使用 unsqueeze 和 repeat_interleave 来扩展维度
delta_expanded = delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1)
# 执行逐元素的乘法,然后取指数
deltaA_manual = torch.exp(delta_expanded * A)

# 注意:deltaA_manual 的形状是 [1, 3, 2, 4],这与 einsum 的输出形状一致
print(deltaA.equal(deltaA_manual))
print(deltaA1.equal(deltaA_manual))

请添加图片描述
请添加图片描述
请添加图片描述


http://www.kler.cn/a/271824.html

相关文章:

  • react-native网络调试工具Reactotron保姆级教程
  • 【计算机网络】host文件
  • leetcode刷题记录(一百)——121. 买卖股票的最佳时机
  • 15天基础内容-5
  • 验证二叉搜索树(力扣98)
  • erase() 【删数函数】的使用
  • 圈子社交系统-多人语音-交友-陪玩-活动报名-商城-二手论坛-源码交付,支持二开!
  • C++ 中的虚函数和多态性
  • docker实战(2)
  • 软考76-上午题-【面向对象技术3-设计模式】-创建型设计模式01
  • Oracle 部署及基础使用
  • Matlab/simulink基于模糊PID智能控制的温度控制系统建模仿真
  • 一起玩儿3D打印机——04 Marlin固件的配置(一)
  • Docker出现容器名称重复如何解决
  • 驾驭Docker镜像海洋:Nexus一站式仓库管理解决方案深度解析及实战指南
  • Hadoop学习3:问题解决
  • 【PyQt错误集 - 1】:PyQt调用多线程导致窗口异常退出的问题分析(进程已结束,退出代码 -1073741819 (0xC0000005))
  • 小蓝的漆房——算法思路
  • Blocks —— 《Objective-C高级编程 iOS与OS X多线程和内存管理》
  • 通过对话式人工智能实现个性化用户体验
  • 论文阅读——GeoChat(cvpr2024)
  • Linux运维相关基础知识
  • 030—pandas 对数据透视并将多层索引整合为一列
  • Sass学习记录
  • 有参转录组分析 |基因组信息下载和FQ数据过滤教程
  • ts版本微信小程序在wxml保存文件不刷新页面的解决办法