当前位置: 首页 > article >正文

深度学习系列72:torch-tensorrt入门

1. 安装

坑非常多,清华源阿里源都不行。使用官网源下载,这里的121可以改成你需要的东西:
python -m pip install torch torch-tensorrt tensorrt --extra-index-url https://download.pytorch.org/whl/cu121

2. 原理

我们来看一个实例:这是一个用于支持 torchscript 到 TensorRT 转换的项目。上面的代码用于将 addmm 运算展开成数个算子,方便后续映射 TensorRT 算子。

void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
  // TensorRT implicitly adds a flatten layer in front of FC layers if necessary
  // 用于匹配的模式
  std::string addmm_pattern = R"IR(
    graph(%b, %x, %w, %beta, %alpha):
      %out: Tensor = aten::addmm(%b, %x, %w, %beta, %alpha)
      return (%out))IR";
  // 用于替换的模式
  std::string mm_add_pattern = R"IR(
    graph(%b, %x, %w, %beta, %alpha):
      %mm: Tensor = aten::matmul(%x, %w)
      %bias: Tensor = aten::mul(%b, %beta)
      %out: Tensor = aten::add(%bias, %mm, %alpha)
      return (%out))IR";

  // 创建子图重写器并注册匹配模式和替换模式
  torch::jit::SubgraphRewriter unpack_addmm;
  unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
  // 遍历graph,完成重写
  unpack_addmm.runOnGraph(graph);
  LOG_GRAPH("Post unpack addmm: " << *graph);
}

3. 简单例子

import torch
def origin_func(x):
    x = x**2
    x = x**3
    return x

x = torch.rand(1, 2, 3, 4)
jit_model = torch.jit.trace(origin_func, x)
print(jit_model.graph)

# 匹配用的子图定义,注意常量必须为[value=2]属性
pattern = """
    graph(%x):
        %const_2 = prim::Constant[value=2]()
        %out = aten::pow(%x, %const_2)
        return (%out)
"""
# 替换用的子图定义
replacement = """
    graph(%x):
        %out = aten::mul(%x, %x)
        return (%out)
"""
torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement,jit_model.graph)
print(jit_model.graph)

http://www.kler.cn/a/287144.html

相关文章:

  • 初学人工智不理解的名词3
  • uniapp 自定义加载组件,全屏加载,局部加载 (微信小程序)
  • windows系统中实现对于appium的依赖搭建
  • 植物明星大乱斗15
  • 《图神经网络:简介》
  • Pytest 学习 @allure.severity 标记用例级别的使用
  • uniapp 生成H5 返回上一页 事件不执行
  • Python入门案例01
  • 20240829软考架构-------软考91-95答案解析
  • 视联动力数字科技新成果闪耀2024数博会
  • 科普小课堂:中等硬度的床垫,合适的睡姿,通过日常力量练习提升自身能力以支撑脊柱形态。
  • 【drools】intelj修改JDK版本、进行maven test
  • 业务资源管理模式语言04
  • 【Python-办公自动化】批量修改EXCEL指定内容
  • 牛客周赛 Round 57(A,B,C,D,E,F,G)
  • @Tanstack/vue-query 的使用介绍
  • jQuery基础——开发插件
  • template<typename ... _Args>可变参数模板
  • LiveQing视频点播流媒体RTMP推流服务用户手册-分屏展示:单分屏、四分屏、九分屏、十六分屏、轮巡播放、分组管理、记录加载
  • 001集——CAD—C#二次开发入门——开发环境基本设置
  • 超详细步骤——Keil MDK-ARM 如何修改工程名字
  • 外媒:《黑神话》成功后 中国加大对游戏行业的关注度
  • 触摸传感器的工作原理
  • Windows TCP/IP IPv6 DDos远程蓝屏复现及修复(CVE-2024-38063)
  • MFC生成dll的区别
  • Linux2-Linux基础命令