Pytorch使用手册—自定义 C++ 和 CUDA 运算符(专题五十一)
你将学到什么
- 如何将用 C++/CUDA 编写的自定义运算符与 PyTorch 集成
- 如何使用
torch.library.opcheck
测试自定义运算符
先决条件 1. PyTorch 2.4 或更高版本 2. 对 C++ 和 CUDA 编程有基本了解
注意
本教程也适用于 AMD ROCm,无需额外修改。
PyTorch 提供了一个庞大的运算符库,这些运算符可以对张量进行操作(例如 torch.add
、torch.sum
等)。然而,您可能希望向 PyTorch 引入一个新的自定义运算符。本教程演示了如何以推荐的方式编写用 C++/CUDA 实现的自定义运算符。
在本教程中,我们将演示如何编写一个与 PyTorch 子系统结合的融合乘加(fused multiply-add)C++ 和 CUDA 运算符。该操作的语义如下:
def <