深度学习中TorchScript原理、作用浅析(Trace/Script)
之前深度学习模型训练后训练的模型想要在C++中调用时,有一种简单的方法是通过Libtorch,但是前提条件是需要将模型转换为TorchScript格式。之前并没有太多理解TorchScript,最近任务有遇到今天浅浅学习一下
TorchScript 简介
TorchScript 是 PyTorch 提供的一种 序列化和优化模型 的方式,使得 PyTorch 模型可以在 无 Python 依赖 的环境中运行,如 C++ 端或移动设备上。它可以将 动态图(即 PyTorch 的 nn.Module
代码)转换为 静态计算图,从而提高推理效率,并支持跨平台部署。
TorchScript 的两种方式
Torch.jit.trace
- 适用于: 计算图是 固定 的,没有 if 语句、循环等动态控制流。
- 方法: 直接传入一个输入样本,TorchScript 记录运算流程,并生成静态计算图。
示例:
# 创建模型并进行 trace
model = SimpleModel()
example_input = torch.randn(1, 3, 640, 640)
traced_model = torch.jit.trace(model, example_input)
# 保存和加载
traced_model.save("traced_model.pt")
loaded_model = torch.jit.load("traced_model.pt")
Torch.jit.script
- 适用于: 计算图包含动态控制流(如 if、for 等),或者 tracing 不适用的情况。
- 方法: 直接使用 torch.jit.script 进行转换。
# 直接使用 scripting 进行转换
scripted_model = torch.jit.script(DynamicModel())
# 保存和加载
scripted_model.save("scripted_model.pt")
loaded_model = torch.jit.load("scripted_model.pt")
以个人经验通常是torch.jit.trace使用不了的情况下使用torch.jit.script;单图像模型用torch.jit.trace,多帧图像模型用torch.jit.scrript
在 TorchScript 转换过程中,如果 model
包含 DataParallel
、使用 numpy
、调用 OpenCV
,都会导致 无法正确转换为 TorchScript。这是因为 TorchScript 需要 静态计算图,而这些情况可能导致 不符合静态计算图要求。下面详细解析为什么会发生这些问题,以及如何解决。
- DataParallel
DataParallel
不是nn.Module
的直接子类,而是一个 封装器,它会在多个 GPU 上 复制 子模型,并在前向传播时 拆分输入并收集输出。torch.jit.script()
不支持 这种封装方式,因为 TorchScript 期望一个 单设备、单计算图 的模型。- 可以遍历让model=model.module,当然这里较为灵活,很多模型名称不一样,总之可以打印模型结构读取去掉多余的部分即可
- 通常情况下不会遇见numpy、opencv等问题。
model的输入输出应该是Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]
的类型,而且在dict中的值,应该是同样的类型。但是对于model中间子模块的输入输出,可以是任意类型,例如dicts of Any, classes, kwargs以及python支持的都可以。对于model输入输出类型的限制是比较容易满足的,在yolov8中,有类似的例子:
class BaseModel(nn.Module):
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""
def forward(self, x, *args, **kwargs):
"""
Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
Args:
x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
Returns:
(torch.Tensor): The output of the network.
"""
if isinstance(x, dict): # for cases of training and validating while training.
return self.loss(x, *args, **kwargs)
return self.predict(x, *args, **kwargs)
这也是为什么Yolov8可以支持torch.jit.script而不支持trace的一个原因。
参考
还有一些分析并没有遇到过,所以就先记录到这里,一些资料参考如下
TorchScript:跟踪与脚本 - Yuxin 的博客
torch.jit.trace与torch.jit.script - 知乎