当前位置: 首页 > article >正文

深度学习中TorchScript原理、作用浅析(Trace/Script)

之前深度学习模型训练后训练的模型想要在C++中调用时,有一种简单的方法是通过Libtorch,但是前提条件是需要将模型转换为TorchScript格式。之前并没有太多理解TorchScript,最近任务有遇到今天浅浅学习一下

TorchScript 简介

TorchScript 是 PyTorch 提供的一种 序列化和优化模型 的方式,使得 PyTorch 模型可以在 无 Python 依赖 的环境中运行,如 C++ 端或移动设备上。它可以将 动态图(即 PyTorch 的 nn.Module 代码)转换为 静态计算图,从而提高推理效率,并支持跨平台部署。

TorchScript 的两种方式

Torch.jit.trace

  • 适用于: 计算图是 固定 的,没有 if 语句、循环等动态控制流。
  • 方法: 直接传入一个输入样本,TorchScript 记录运算流程,并生成静态计算图。
    示例:
# 创建模型并进行 trace
model = SimpleModel()
example_input = torch.randn(1, 3, 640, 640)
traced_model = torch.jit.trace(model, example_input)

# 保存和加载
traced_model.save("traced_model.pt")
loaded_model = torch.jit.load("traced_model.pt")

Torch.jit.script

  • 适用于: 计算图包含动态控制流(如 if、for 等),或者 tracing 不适用的情况。
  • 方法: 直接使用 torch.jit.script 进行转换。
# 直接使用 scripting 进行转换
scripted_model = torch.jit.script(DynamicModel())

# 保存和加载
scripted_model.save("scripted_model.pt")
loaded_model = torch.jit.load("scripted_model.pt")

以个人经验通常是torch.jit.trace使用不了的情况下使用torch.jit.script;单图像模型用torch.jit.trace,多帧图像模型用torch.jit.scrript

在 TorchScript 转换过程中,如果 model 包含 DataParallel、使用 numpy、调用 OpenCV,都会导致 无法正确转换为 TorchScript。这是因为 TorchScript 需要 静态计算图,而这些情况可能导致 不符合静态计算图要求。下面详细解析为什么会发生这些问题,以及如何解决。

  1. DataParallel
  • DataParallel 不是 nn.Module 的直接子类,而是一个 封装器,它会在多个 GPU 上 复制 子模型,并在前向传播时 拆分输入并收集输出
  • torch.jit.script() 不支持 这种封装方式,因为 TorchScript 期望一个 单设备、单计算图 的模型。
  • 可以遍历让model=model.module,当然这里较为灵活,很多模型名称不一样,总之可以打印模型结构读取去掉多余的部分即可
  1. 通常情况下不会遇见numpy、opencv等问题。

model的输入输出应该是Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]的类型,而且在dict中的值,应该是同样的类型。但是对于model中间子模块的输入输出,可以是任意类型,例如dicts of Any, classes, kwargs以及python支持的都可以。对于model输入输出类型的限制是比较容易满足的,在yolov8中,有类似的例子:


class BaseModel(nn.Module):
    """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

    def forward(self, x, *args, **kwargs):
        """
        Forward pass of the model on a single scale. Wrapper for `_forward_once` method.

        Args:
            x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.

        Returns:
            (torch.Tensor): The output of the network.
        """
        if isinstance(x, dict):  # for cases of training and validating while training.
            return self.loss(x, *args, **kwargs)
        return self.predict(x, *args, **kwargs)

这也是为什么Yolov8可以支持torch.jit.script而不支持trace的一个原因。

参考

还有一些分析并没有遇到过,所以就先记录到这里,一些资料参考如下

TorchScript:跟踪与脚本 - Yuxin 的博客

torch.jit.trace与torch.jit.script - 知乎


http://www.kler.cn/a/574053.html

相关文章:

  • MySQL事务,函数,性能,索引
  • 【GoTeams】-2:项目基础搭建(下)
  • BGP服务器主要是指什么?
  • 硬件学习【1】74HC165D-并行信号输入-串行输出
  • VSCode 配置优化指南:打造极致高效的前端开发环境
  • 系统架构设计师—软件工程基础篇—软件开发方法
  • 【无标题】养老护理初级考题抽取——2大题抽1+7小题抽2-共有432种可能。
  • 【LeetCode 热题 100】438. 找到字符串中所有字母异位词 | python 【中等】
  • go语言因为前端跨域导致无法访问到后端解决方案
  • vim基本操作及常用命令
  • WPS条件格式:B列的值大于800,并且E列的值大于B列乘以0.4时,这一行的背景标红
  • 蓝桥与力扣刷题(蓝桥 数字三角形)
  • AT89S51 单片机手册解读:架构、功能与应用深度剖析
  • R语言——数据类型
  • 单例模式的五种实现方式
  • MATLAB中lookAheadBoundary函数用法
  • 【前端基础】Day 6 CSS定位
  • 护照阅读器在汽车客运站流程中的应用
  • 洛谷P1102 A-B 数对
  • OceanBase-obcp-v3考试资料梳理