pytorch自定义算子导出onnx
文章目录
- 1、为什么要自定义算子?
- 2、如何自定义算子
- 3、自定义算子导出onnx
- 4、example
- 1、重写一个pytorch 自定义算子(实现自定义激活函数)
- 2、现有算子上封装pytorch 自定义算子(实现动态放大超分辨率模型)
1、为什么要自定义算子?
1、没有现成可用的算子,需要根据自己的接口重写。
2、现有的算子接口不兼容,需要在原有的算子上进行封装。
2、如何自定义算子
继承torch.autograd.Function类,实现其forward() 和 backward()方法,就可以成为一个pytorch自定义算子。就可以在模型训练推理中完成前向推理和反向传播。
forward() 函数的第一个参数必须是ctx, 后面是输入。
在工程部署上,一般为了加快计算,自定义算子需要用cuda 实现forward()、backward()kernel 函数。
3、自定义算子导出onnx
实现其symbolic 静态方法,当我们调用torch.onnx.export()时,就可以导出onnx 算子。
symbolic是符号函数,通常在其内部返回一个g.op()对象。g.op() 把一个 PyTorch 算子映射成一个或多个 ONNX 算子,或者是自定义的 ONNX 算子。
symbolic函数的第一个参数必须是g, 后面是和forward()对应的输入。
g.op() 做算子映射,g.op 的参数:
1、第一个参数为算子名字,
2、后面参数与forward() 输入对应,
3、往后可以是一些算子自带常量和属性值。常量视为输入,属性值需要用 字段_s/i/f = 默认值表示。_s 表示字符串,_i 表示 int64, _f 表示 float32。常量用类似 g.op(“Constant”, value_t=torch.tensor([3, 2, 1], dtype=torch.float32))表示
4、example
1、重写一个pytorch 自定义算子(实现自定义激活函数)
实现自己的激活函数MYSELU 算子。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import torch.autograd
#继承torch.autograd.Function
class MYSELUImpl(torch.autograd.Function):
@staticmethod
def symbolic(g, x, p):
return g.op("MYSELU", x, p, # 表示onnx算子的名称为MYSELU,参数与forward()对应
# 给算子传一个常数参数
g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
attr1_s="这是字符串属性", # s表示字符串
attr2_i=[1, 2, 3], # i表示整数
attr3_f=222 # f表示浮点数
)
@staticmethod
def forward(ctx, x, p): # 前行推理
return x * 1 / (1 + torch.exp(-x))
class MYSELU(nn.Module):
def __init__(self, n):
super().__init__()
self.param = nn.parameter.Parameter(torch.arange(n).float())
def forward(self, x):
return MYSELUImpl.apply(x, self.param) #推理调用
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 3, padding=1)
self.myselu = MYSELU(3)
self.conv.weight.data.fill_(1)
self.conv.bias.data.fill_(0)
def forward(self, x):
x = self.conv(x)
x = self.myselu(x)
return x
2、现有算子上封装pytorch 自定义算子(实现动态放大超分辨率模型)
实现动态放大超分辨率模型。我们希望实现:
forward(self, x, upscale_factor)
这样一个接口,x 为图像输入,upscale_factor为动态放大倍数。
pytorch 现有放大算子有nn.Upsample 和 interpolate, 但是nn.Upsample 在初始化阶段固化了放大倍数,而 PyTorch 的 interpolate 插值算子可以在运行阶段选择放大倍数。
class SuperResolutionNet(nn.Module):
def forward(self, x, upscale_factor):
x = interpolate(x,
scale_factor=upscale_factor.item(),
mode='bicubic',
align_corners=False)
...
# Inference
# Note that the second input is torch.tensor(3)
torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy()
...
with torch.no_grad():
torch.onnx.export(model, (x, torch.tensor(3)),
"srcnn2.onnx",
opset_version=11,
input_names=['input', 'factor'],
output_names=['output'])
尝试使用以上方法导出onnx 时,虽然没有报错能成功导出onnx,但是有TraceWarning 的警告,说明导出onnx有追踪失败。这是由于我们使用了 torch.Tensor.item() 把数据从 Tensor 里取出来,而导出 ONNX 模型时这个操作是无法被记录的,只好报了一条 TraceWarning。
因此我们需要自定义算子,让onnx在追踪时刻能work。我们看到nn.Upsample 和 interpolate在转onnx时都映射到了Resize 操作。所以自定义算子在Resize 操作上进行封装即可。
Resize 操作有三个输入,x, roi, scale, 我们就是要动态输入scale。展开 scales,可以看到 scales 是一个长度为 4 的一维张量,其内容为 [1, 1, 3, 3],
如果我们能够自己生成一个 ONNX 的 Resize 算子,让 scales 成为一个可变量而不是常量,就像它上面的 X 一样,那这个超分辨率模型就能动态缩放了。
import torch
from torch import nn
from torch.nn.functional import interpolate
import torch.onnx
import cv2
import numpy as np
class NewInterpolate(torch.autograd.Function):
@staticmethod
def symbolic(g, input, scales):
return g.op("Resize",
input,
g.op("Constant",
value_t=torch.tensor([], dtype=torch.float32)),
scales,
coordinate_transformation_mode_s="pytorch_half_pixel",
cubic_coeff_a_f=-0.75,
mode_s='cubic',
nearest_mode_s="floor")
@staticmethod
def forward(ctx, input, scales):
scales = scales.tolist()[-2:]
return interpolate(input,
scale_factor=scales,
mode='bicubic',
align_corners=False)
class StrangeSuperResolutionNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
self.relu = nn.ReLU()
def forward(self, x, upscale_factor):
x = NewInterpolate.apply(x, upscale_factor)
out = self.relu(self.conv1(x))
out = self.relu(self.conv2(out))
out = self.conv3(out)
return out
以上自定义了Resize 算子,将scale 作为算子的一个输入,最后还是调用interpolate。但是scale已经变成自定义输入参数。
参数映射如下: