当前位置: 首页 > article >正文

ONNX 转 TensorRT Bug 记录:IIfConditionalOutputLayer

1. 问题描述

环境:TensorRT-8.6.1.6、CUDA-11.8

报错:Error[4]: /If_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape. Shapes are [-1,384] and [-1,1,384].

复现代码:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F


class TestModel(torch.nn.Module):
    def __init__(self, mode):
        super().__init__()
        self.mode = mode
        self.conv = nn.Conv1d(512, 512, 3, 2, 1)

    def forward(self, x, mask):
        if self.mode == 1:
            return self.forward1(x, mask)
        elif self.mode == 2:
            return self.forward2(x, mask)
        elif self.mode == 3:
            return self.forward3(x, mask)
        elif self.mode == 4:
            return self.forward4(x, mask)
        else:
            raise ValueError("Invalid mode")

    def forward1(self, x, mask):
        mask = mask.unsqueeze(1)
        x = self.conv(x)
        mask = F.interpolate(mask, size=x.size(-1), mode="nearest")
        x = x * mask
        mask = mask.squeeze(1)
        return x, mask

    def forward2(self, x, mask):
        mask = mask.unsqueeze(1)
        x = self.conv(x)
        mask = F.interpolate(mask, size=384, mode="nearest")
        x = x * mask
        mask = mask.squeeze(1)
        return x, mask

    def forward3(self, x, mask):
        mask = mask.unsqueeze(1)
        x = x * mask
        mask = mask.squeeze(1)
        return x, mask

    def forward4(self, x, mask):
        mask = mask.unsqueeze(1)
        x = self.conv(x)
        mask = F.interpolate(mask, size=x.size(-1), mode="nearest")
        x = x * mask
        b = x.shape[0]
        mask = mask.reshape(b, -1)
        return x, mask


fake_input = torch.randn(1, 512, 768)
fake_mask = torch.randn(1, 768)
model1 = TestModel(1)
model2 = TestModel(2)
model3 = TestModel(3)
model4 = TestModel(4)

with torch.no_grad():
    print([x.shape for x in model1(fake_input, fake_mask)])
    print([x.shape for x in model2(fake_input, fake_mask)])
    print([x.shape for x in model3(fake_input, fake_mask)])
    print([x.shape for x in model4(fake_input, fake_mask)])


dynamic = {
    "x_input": {0: "batch"},
    "masks_input": {0: "batch"},
    "x_output": {0: "batch"},
    "masks_output": {0: "batch"}
}

save_dir = "log"
os.makedirs(save_dir, exist_ok=True)

for i, model in enumerate((model1, model2, model3, model4)):
    torch.onnx.export(
        model.cpu().eval(),
        (fake_input.cpu(), fake_mask.cpu()),
        os.path.join(save_dir, f'dynamic_{i+1}.onnx'),
        verbose=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["x_input", "masks_input"],
        output_names=["x_output", "masks_output"],
        dynamic_axes=dynamic
    )

    torch.onnx.export(
        model.cpu().eval(),
        (fake_input.cpu(), fake_mask.cpu()),
        os.path.join(save_dir, f'static_{i+1}.onnx'),
        verbose=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["x_input", "masks_input"],
        output_names=["x_output", "masks_output"],
        dynamic_axes=None
    )
exec="/home/sfy/SFY/camera/TensorRT-8.6.1.6/bin/trtexec"
dir="log"
rm "$dir"/*.txt "$dir"/*.plan

: <<'EOF'
# onnx-simplifier 移除冗余节点
for file in "$dir"/*.onnx; do
  if [[ ! -f "$file" ]]; then
    echo "No .onnx files found in $dir."
    continue
  fi
  fn=$(basename "$file")
  python -m onnxsim "$file" "$dir/${fn%.onnx}_simp.onnx"
done
EOF

for file in "$dir"/*.onnx; do
  if [[ ! -f "$file" ]]; then
    echo "No .onnx files found in $dir."
    continue
  fi

  fn=$(basename "$file")
  if [[ $fn == dynamic* ]]; then
    $exec \
      --onnx="$file" \
      --saveEngine="$dir/${fn%.onnx}.plan" \
      --minShapes=x_input:1x512x768,masks_input:1x768 \
      --optShapes=x_input:2x512x768,masks_input:2x768 \
      --maxShapes=x_input:4x512x768,masks_input:4x768 \
      --verbose \
      > "$dir/${fn%.onnx}.txt" 2>&1

  elif [[ $fn == static* ]]; then
    $exec \
      --onnx="$file" \
      --saveEngine="$dir/${fn%.onnx}.plan" \
      --verbose \
      > "$dir/${fn%.onnx}.txt" 2>&1
  else
    continue
  fi
done

  上述代码对源代码做了简化,仅剥离出造成问题的部分,用于复现和测试 Bug,输入可以看作 1 个 Batch 有 768 张图像序列,每个图像用 512 维特征向量表示;masks 原本是 bool 类型代表 768 张图像是真实数据还是填充数据。

  经过测试发现 Bug 由动态模式、F.interpolatesqueeze 组合引发,因此代码推理阶段分为以下 4 种模式:
(1)F.interpolate(mask, size=x.size(-1), mode="nearest") + mask.squeeze(1) 报错
(2)F.interpolate(mask, size=384, mode="nearest") + mask.squeeze(1) 动态模式报错,静态模式通过
(3)mask.squeeze(1) 通过
(4)F.interpolate(mask, size=x.size(-1), mode="nearest") + mask.reshape(b, -1)mask.view(b, -1) 通过

2. 解决方法

  用 reshapeview 代替 squeeze

3. 原因分析

  总结:squeeze 操作需要判断对应维度是否等于 1,而 F.interpolate 改变了张量形状、动态模式引入了维度数值的不确定性,这些使得 onnx 无法确定该维度是否等于 1(即使看上去可以推导出数值),导致添加了 If 节点,而 If 节点在不同状态下分别执行 SqueezeIdentity 导致输出形状不统一。采用 Reshape 操作可以直接规避此问题。

  下面是不同模式 ONNX 的结构图对比,以及分析细节。

mode1

在这里插入图片描述在这里插入图片描述
  mode1 在插值时使用 x.size(-1),在动态模式下需要通过 Shape 等一系列节点来获取维度信息;在静态模式下 x 所有初始维度固定,把 x.size(-1) 看作固定常量。

mode2

在这里插入图片描述在这里插入图片描述
  动态模式下 masks 经过 Resize 维度信息是不确定的,导致需要判断 masks.shape[1] == 1 才能执行 masks.squeeze(1),问题在于当不等于 1 时执行的是 Identity 导致输出形状不一致。

  比较令人费解的是对比 mode1 和 mode2 静态模式,可以发现 Resize 之前的结构完全相同,但是 mode1 在 Resize 之后仍引入了 If 节点导致异常。查到的解释是 squeeze 操作比较保守,mode1 原本 Resize 的维度是动态的依赖 x.size(-1),即使通过推导将动态转变为静态,但仍保留了动态逻辑(If 节点)。
  启用脚本中 onnx-simplifier 移除冗余节点的部分可以去除 mode1 静态模式中的 If 节点,但此方法对动态模式无效。

mode3

在这里插入图片描述在这里插入图片描述
  masks 维度的动态(不确定性)由 Resize 引入。

mode4

在这里插入图片描述在这里插入图片描述
  不使用 squeeze 便不需要判断维度是否符合要求,直接规避此问题。


http://www.kler.cn/a/451000.html

相关文章:

  • H3C交换机远程登录基本配置
  • 修改el-select下拉框高度;更新:支持动态修改
  • golangci-lint安装与Goland集成
  • 精准提升:从94.5%到99.4%——目标检测调优全纪录
  • springboot项目对数据库密码加解密
  • ECharts散点图-气泡图,附视频讲解与代码下载
  • 鸿蒙-什么是ArkTS
  • 【C++】模板与泛型编程(一):定义模板,类模板
  • vue3 + MapTalks实现2.5D地图的绘制
  • SQL Server数据库多主模式解决方案
  • 面试小札:Java后端闪电五连鞭_11
  • prometheus监控windows主机
  • Springboot基于Web的高校志愿者服务管理系统81559
  • Git安装及基础学习
  • Blazor 中调用 JavaScript
  • 20241224在ubuntu20.04.6下的终端分屏软件terminator的安装以及使用
  • 网络安全词云图与技术浅谈
  • deepin 安装 zookeeper
  • Git:查看分支、创建分支、合并分支
  • 【漫话机器学习系列】020.正则化强度的倒数C(Inverse of regularization strength)
  • 【CAE开发SDK】CEETRON Envision:适用于桌面端、Web端的数据可视化与分析
  • 【蓝桥杯每日一题】分糖果——DFS
  • Ftp目录整个下载
  • 如何保护你的 iOS 应用免受逆向工程攻击
  • 明厨亮灶系统
  • C++简明教程(9)(多文件编程)