
A First Look at Custom PyTorch Operators

This post is a code-level walkthrough of a custom PyTorch operator: a from-scratch 2D convolution with both CPU and CUDA implementations.
It also shows how to use setup.py to build a simple C++ operator.

1. Environment setup

The overall structure is as follows:

pytorch_cpp_helper.hpp provides the headers the CPU convolution needs
pytorch_cuda_helper.hpp and common_cuda_helper.hpp are the CUDA headers
my_conv.cpp implements the convolution (CPU kernel, entry point, and Python bindings)
my_conv_cuda.cu implements the CUDA im2col kernel
my_conv.py wraps the convolution into a network layer
setup.py compiles the C++/CUDA files
test.py tests the whole thing

2. pytorch_cpp_helper.hpp

#ifndef PYTORCH_CPP_HELPER
#define PYTORCH_CPP_HELPER
// torch/extension.h provides several key facilities:
// - Tensor operations: create, read, modify, and destroy tensors.
// - Error handling: macros such as TORCH_CHECK and TORCH_INTERNAL_ASSERT
//   check a condition and throw an exception if it fails.
// - Autograd: custom differentiable ops via autograd::Function / autograd::Variable.
// - CUDA support: GPU acceleration through the ATen/CUDA (formerly THC) APIs.
#include <torch/extension.h>

#include <vector>

using namespace at;

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CPU(x) \
  TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
  CHECK_CUDA(x);            \
  CHECK_CONTIGUOUS(x)
#define CHECK_CPU_INPUT(x) \
  CHECK_CPU(x);            \
  CHECK_CONTIGUOUS(x)

#endif  // PYTORCH_CPP_HELPER
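
As a quick illustration of what these macros buy you at the Python level: feeding the operator a non-contiguous tensor trips CHECK_CPU_INPUT and surfaces as an ordinary RuntimeError. A minimal sketch, assuming the my_ops extension from section 8 has already been built:

import torch
import my_ops  # the extension built in section 8

x = torch.rand(2, 3, 8, 8).transpose(2, 3)  # a non-contiguous view
w = torch.rand(4, 3, 3, 3)
b = torch.zeros(4)
out = torch.empty(2, 4, 8, 8)
cols = torch.empty(0)
try:
    my_ops.my_conv_forward(x, w, b, out, cols,
                           kW=3, kH=3, dW=1, dH=1, padW=1, padH=1,
                           dilationW=1, dilationH=1, group=1, im2col_step=2)
except RuntimeError as e:
    print(e)  # "input must be contiguous", raised by TORCH_CHECK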

3. pytorch_cuda_helper.hpp

#ifndef PYTORCH_CUDA_HELPER
#define PYTORCH_CUDA_HELPER

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <THC/THCAtomics.cuh>  // note: THC was removed in newer PyTorch; use <ATen/cuda/Atomic.cuh> there

#include "common_cuda_helper.hpp"

using at::Half;
using at::Tensor;
using phalf = at::Half;

#define __PHALF(x) (x)

#endif  // PYTORCH_CUDA_HELPER

4. common_cuda_helper.hpp

#ifndef COMMON_CUDA_HELPER
#define COMMON_CUDA_HELPER

#include <cuda.h>

// Grid-stride loop: each thread starts at its global index and advances by the
// total thread count, so a fixed-size grid covers any n.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// 2D grid-stride loop over an n-by-m index space.
#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                             \
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
       i += blockDim.x * gridDim.x)                                 \
    for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \
         j += blockDim.y * gridDim.y)

// Block-level 2D loop: one (i, j) pair per block per iteration.
#define CUDA_2D_KERNEL_BLOCK_LOOP(i, n, j, m)          \
  for (size_t i = blockIdx.x; i < (n); i += gridDim.x) \
    for (size_t j = blockIdx.y; j < (m); j += gridDim.y)

#define THREADS_PER_BLOCK 512

// Number of blocks needed to cover N elements: ceil(N / num_threads), capped at
// 4096 (safe because the kernels use grid-stride loops).
inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) {
  int optimal_block_num = (N + num_threads - 1) / num_threads;
  int max_block_num = 4096;
  return min(optimal_block_num, max_block_num);
}

// The bilinear interpolation helpers below are carried over from the deform-conv
// source this file is adapted from (see the comment at the top of
// my_conv_cuda.cu); the plain convolution in this post never calls them.
template <typename T>
__device__ T bilinear_interpolate(const T* input, const int height,
                                  const int width, T y, T x,
                                  const int index /* index for debug only*/) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;

  if (y <= 0) y = 0;
  if (x <= 0) x = 0;

  int y_low = (int)y;
  int x_low = (int)x;
  int y_high;
  int x_high;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;
  // do bilinear interpolation
  T v1 = input[y_low * width + x_low];
  T v2 = input[y_low * width + x_high];
  T v3 = input[y_high * width + x_low];
  T v4 = input[y_high * width + x_high];
  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  return val;
}

template <typename T>
__device__ void bilinear_interpolate_gradient(
    const int height, const int width, T y, T x, T& w1, T& w2, T& w3, T& w4,
    int& x_low, int& x_high, int& y_low, int& y_high,
    const int index /* index for debug only*/) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0) y = 0;
  if (x <= 0) x = 0;

  y_low = (int)y;
  x_low = (int)x;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  return;
}
#endif  // COMMON_CUDA_HELPER
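
The two things worth internalizing here are GET_BLOCKS (ceil-division, capped at 4096 blocks) and the grid-stride loop: because each thread strides by blockDim.x * gridDim.x, even a capped grid covers any n. A pure-Python simulation of that coverage (illustrative only, not part of the project):

THREADS_PER_BLOCK = 512

def get_blocks(n, num_threads=THREADS_PER_BLOCK):
    # mirrors GET_BLOCKS: ceil(n / num_threads), capped at 4096
    return min((n + num_threads - 1) // num_threads, 4096)

n = 3_000_000  # more elements than 4096 blocks * 512 threads
blocks = get_blocks(n)
visits = [0] * n
for block_idx in range(blocks):
    for thread_idx in range(THREADS_PER_BLOCK):
        i = block_idx * THREADS_PER_BLOCK + thread_idx
        while i < n:  # CUDA_1D_KERNEL_LOOP(i, n)
            visits[i] += 1
            i += blocks * THREADS_PER_BLOCK
assert all(v == 1 for v in visits)  # every element visited exactly once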

5. my_conv.cpp

#include "pytorch_cpp_helper.hpp"

// CPU im2col (implemented later in this file)
void my_conv_im2col_cpu(Tensor data_im,
                        const int channels, const int height,
                        const int width, const int ksize_h,
                        const int ksize_w, const int pad_h, const int pad_w,
                        const int stride_h, const int stride_w,
                        const int dilation_h, const int dilation_w,
                        const int parallel_imgs, Tensor data_col);
// CUDA im2col (implemented in my_conv_cuda.cu)
void my_conv_im2col_cuda(Tensor data_im,
                        const int channels, const int height,
                        const int width, const int ksize_h,
                        const int ksize_w, const int pad_h, const int pad_w,
                        const int stride_h, const int stride_w,
                        const int dilation_h, const int dilation_w,
                        const int parallel_imgs, Tensor data_col);
// Shape checking
void my_conv_shape_check(at::Tensor input,
                         at::Tensor weight, int kH,
                         int kW, int dH, int dW, int padH, int padW,
                         int dilationH, int dilationW, int group)
{
    TORCH_CHECK(
        weight.ndimension() == 4,
        "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: %s",
        weight.ndimension());

    TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

    TORCH_CHECK(kW > 0 && kH > 0,
                "kernel size should be greater than zero, but got kH: %d kW: %d",
                kH, kW);

    TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
                "kernel size should be consistent with weight, ",
                "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
                kH, kW, weight.size(2), weight.size(3));

    TORCH_CHECK(dW > 0 && dH > 0,
                "stride should be greater than zero, but got dH: %d dW: %d", dH,
                dW);

    TORCH_CHECK(
        dilationW > 0 && dilationH > 0,
        "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
        dilationH, dilationW);

    int ndim = input.ndimension();
    int dimf = 0;
    int dimh = 1;
    int dimw = 2;

    if (ndim == 4)
    {
        dimf++;
        dimh++;
        dimw++;
    }

    TORCH_CHECK(ndim == 3 || ndim == 4,
                "3D or 4D input tensor expected but got: %s", ndim);

    long nInputPlane = weight.size(1) * group;
    long inputHeight = input.size(dimh);
    long inputWidth = input.size(dimw);
    long nOutputPlane = weight.size(0);
    long outputHeight =
        (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
    long outputWidth =
        (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

    if (outputWidth < 1 || outputHeight < 1)
        AT_ERROR(
            "Given input size: (%ld x %ld x %ld). "
            "Calculated output size: (%ld x %ld x %ld). Output size is too small",
            nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
            outputWidth);

    TORCH_CHECK(input.size(1) == nInputPlane,
                "invalid number of input planes, expected: %d, but got: %d",
                nInputPlane, input.size(1));

    TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
                "input image is smaller than kernel");
}
// my_conv_forward is the entry point of the convolution. On top of the
// parameters that PyTorch's Conv2d passes in, it takes two extra tensors,
// output and columns, which hold intermediate results and are never seen on
// the PyTorch side: output stores the convolution result, and columns stores
// the column matrix built during convolution. The implementation first unfolds
// the image into a matrix of columns ("im2col"), then performs the convolution
// as a single matrix multiplication.
// Custom convolution forward function
void my_conv_forward(Tensor input, Tensor weight, Tensor bias,
                     Tensor output, Tensor columns, int kW,
                     int kH, int dW, int dH, int padW, int padH,
                     int dilationW, int dilationH, int group,
                     int im2col_step)
{
    // Check whether the input lives on a CUDA device
    bool isCuda = false;
    if (input.device().is_cuda())
    {
        CHECK_CUDA_INPUT(input);
        CHECK_CUDA_INPUT(weight);
        CHECK_CUDA_INPUT(bias);
        CHECK_CUDA_INPUT(output);
        CHECK_CUDA_INPUT(columns);
        isCuda = true;
    }
    else
    {
        CHECK_CPU_INPUT(input);
        CHECK_CPU_INPUT(weight);
        CHECK_CPU_INPUT(bias);
        CHECK_CPU_INPUT(output);
        CHECK_CPU_INPUT(columns);
    }
    // Check shapes
    my_conv_shape_check(input, weight, kH, kW, dH, dW, padH,
                        padW, dilationH, dilationW, group);
    // Scoped guard that manages the device context
    at::DeviceGuard guard(input.device());

    int batch = 1;
    // CHW->NCHW
    if (input.ndimension() == 3)
    {
        // remember that a batch dimension was forced on
        batch = 0;
        input.unsqueeze_(0);
    }

    long batchSize = input.size(0);//N
    long nInputPlane = input.size(1);//C
    long inputHeight = input.size(2);//H
    long inputWidth = input.size(3);//W

    long nOutputPlane = weight.size(0);

    // Compute the output size:
    // O = (W + 2*P - (D*(K-1) + 1)) / S + 1
    // where W is the input size (inputWidth/inputHeight), K the kernel size
    // (kW/kH), S the stride (dW/dH), P the padding (padW/padH), and D the
    // dilation (dilationW/dilationH).
    long outputWidth =
        (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
    long outputHeight =
        (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

    // Reshape the output and allocate the column buffer
    output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                          outputHeight, outputWidth});
    columns = at::zeros(
        {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
        input.options());

    input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                        inputHeight, inputWidth});

    // Prepare the output buffer
    Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane,
                                      im2col_step * outputHeight, outputWidth},
                                     output.options());

    output_buffer = output_buffer.view(
        {output_buffer.size(0), group, output_buffer.size(1) / group,
         output_buffer.size(2), output_buffer.size(3)});

    // Main loop: process im2col_step images at a time
    for (int elt = 0; elt < batchSize / im2col_step; elt++)
    {
        // Run im2col on the appropriate device
        if (isCuda)
        {
            my_conv_im2col_cuda(input[elt], nInputPlane, inputHeight,
                            inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                            dilationW, im2col_step, columns);
        }
        else
        {
            my_conv_im2col_cpu(input[elt], nInputPlane, inputHeight,
                            inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                            dilationW, im2col_step, columns);
        }
        
        // Reshape columns and weight for grouped matrix multiplication
        columns = columns.view({group, columns.size(0) / group, columns.size(1)});
        weight = weight.view({group, weight.size(0) / group, weight.size(1),
                              weight.size(2), weight.size(3)});

        // Grouped convolution: one GEMM per group
        for (int g = 0; g < group; g++)
        {
            output_buffer[elt][g] = output_buffer[elt][g]
                                        .flatten(1)
                                        .addmm_(weight[g].flatten(1), columns[g])
                                        .view_as(output_buffer[elt][g]);
        }
        // Restore shapes
        columns =
            columns.view({columns.size(0) * columns.size(1), columns.size(2)});
        weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                              weight.size(3), weight.size(4)});
    }

    // Final reshape of the output buffer
    output_buffer = output_buffer.view(
        {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
         output_buffer.size(3), output_buffer.size(4)});

    output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
                                        im2col_step, outputHeight, outputWidth});
    output_buffer.transpose_(1, 2);

    // Copy into output and reshape
    output.copy_(output_buffer);
    output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});

    // Add the bias
    bias = bias.view({1, bias.size(0), 1, 1});
    output.add_(bias);

    // Restore the input shape
    input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});

    // Undo the forced batch dimension for 3D inputs
    if (batch == 0)
    {
        output = output.view({nOutputPlane, outputHeight, outputWidth});
        input = input.view({nInputPlane, inputHeight, inputWidth});
    }
}
/**
 * CPU kernel for the im2col step of 2D convolution. Unfolds image data into
 * columns, according to the convolution parameters, so that the convolution
 * can be computed as an efficient matrix multiplication.
 *
 * @param n number of output elements to produce.
 * @param data_im pointer to the input image data.
 * @param height, width input image size.
 * @param kernel_h, kernel_w kernel size.
 * @param pad_h, pad_w vertical / horizontal padding.
 * @param stride_h, stride_w vertical / horizontal stride.
 * @param dilation_h, dilation_w vertical / horizontal dilation.
 * @param batch_size number of images unfolded per call.
 * @param num_channels number of input channels.
 * @param height_col, width_col spatial size of the unfolded output.
 * @param data_col output pointer for the unfolded column data.
 */
template <typename T>
void my_conv_im2col_cpu_kernel(
    const int n, const T *data_im, const int height,
    const int width, const int kernel_h, const int kernel_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int batch_size,
    const int num_channels, const int height_col,
    const int width_col, T *data_col)
{
    // Iterate over every element of the output matrix
    for (int index = 0; index < n; index++)
    {
        // Locate this output element: spatial position, batch index, channel
        const int w_col = index % width_col;
        const int h_col = (index / width_col) % height_col;
        const int b_col = (index / width_col / height_col) % batch_size;  // batch
        const int c_im = (index / width_col / height_col) / batch_size;   // input channel
        const int c_col = c_im * kernel_h * kernel_w;  // first column row for this channel

        // Top-left corner of the receptive field in the input image
        const int h_in = h_col * stride_h - pad_h;
        const int w_in = w_col * stride_w - pad_w;

        // Pointer to where this element's column starts in data_col
        T *data_col_ptr =
            data_col +
            ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;

        // Pointer to this image/channel plane in data_im
        const T *data_im_ptr =
            data_im + (b_col * num_channels + c_im) * height * width;

        // Walk over every tap of the kernel window
        for (int i = 0; i < kernel_h; ++i)
        {
            for (int j = 0; j < kernel_w; ++j)
            {
                // Default to zero (covers padding)
                T val = static_cast<T>(0);

                // Position of this kernel tap in the input image
                const int h_im = h_in + i * dilation_h;
                const int w_im = w_in + j * dilation_w;

                // Read the pixel if it falls inside the image
                if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
                {
                    val = data_im_ptr[h_im * width + w_im];
                }

                // Write the value out and advance to the next column row
                *data_col_ptr = val;
                data_col_ptr += batch_size * height_col * width_col;
            }
        }
    }
}

void my_conv_im2col_cpu(Tensor data_im,
                        const int channels, const int height,
                        const int width, const int ksize_h,
                        const int ksize_w, const int pad_h, const int pad_w,
                        const int stride_h, const int stride_w,
                        const int dilation_h, const int dilation_w,
                        const int parallel_imgs, Tensor data_col)
{
    int height_col =
        (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
    int width_col =
        (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
    int num_kernels = channels * height_col * width_col * parallel_imgs;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        data_im.scalar_type(), "my_conv_im2col_cpu", [&]
        { my_conv_im2col_cpu_kernel<scalar_t>(
              num_kernels, data_im.data_ptr<scalar_t>(),
              height, width, ksize_h, ksize_w,
              pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
              parallel_imgs, channels,
              height_col, width_col, data_col.data_ptr<scalar_t>()); });
}
// my_ops is the name of the generated extension module. m.def defines the
// Python interface of a C++ function: the string "my_conv_forward" is the name
// used when calling from Python, and the second argument is the C++ function in
// this file being bound. In other words, my_conv_forward is the entry point of
// this convolution implementation.
PYBIND11_MODULE(my_ops, m)
{
      m.def("my_conv_forward", my_conv_forward, "my_conv_forward",
            py::arg("input"), py::arg("weight"), py::arg("bias"),
            py::arg("output"), py::arg("columns"), py::arg("kW"),
            py::arg("kH"), py::arg("dW"), py::arg("dH"), py::arg("padW"),
            py::arg("padH"), py::arg("dilationW"), py::arg("dilationH"),
            py::arg("group"), py::arg("im2col_step"));
}
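
Before moving on to the CUDA file, it is worth convincing yourself that "unfold into columns, then one matrix multiply" really is a convolution. Stock PyTorch can show this in a few lines, since F.unfold is exactly an im2col (a standalone sketch, independent of the extension):

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 8, 8)
w = torch.rand(4, 3, 3, 3)

cols = F.unfold(x, kernel_size=3, padding=1)  # im2col: (1, 3*3*3, 8*8)
out = (w.flatten(1) @ cols).view(1, 4, 8, 8)  # convolution as one matmul

assert torch.allclose(out, F.conv2d(x, w, padding=1), atol=1e-5)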

6. my_conv_cuda.cu

// Modified from https://github.com/open-mmlab/mmcv/blob/my_conv/mmcv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh
// Copyright (c) OpenMMLab. All rights reserved.
#include <torch/types.h>

#include "pytorch_cuda_helper.hpp"

template <typename T>
__global__ void my_conv_im2col_gpu_kernel(
    const int n, const T *data_im, const int height,
    const int width, const int kernel_h, const int kernel_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int batch_size,
    const int num_channels, const int height_col,
    const int width_col, T *data_col)
{
    CUDA_1D_KERNEL_LOOP(index, n)
    {
        // index = flat index into the output matrix (decoded as in the CPU kernel)
        const int w_col = index % width_col;
        const int h_col = (index / width_col) % height_col;
        const int b_col = (index / width_col / height_col) % batch_size;
        const int c_im = (index / width_col / height_col) / batch_size;
        const int c_col = c_im * kernel_h * kernel_w;

        const int h_in = h_col * stride_h - pad_h;
        const int w_in = w_col * stride_w - pad_w;
        T *data_col_ptr =
            data_col +
            ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
        const T *data_im_ptr =
            data_im + (b_col * num_channels + c_im) * height * width;

        for (int i = 0; i < kernel_h; ++i)
        {
            for (int j = 0; j < kernel_w; ++j)
            {
                T val = static_cast<T>(0);
                const int h_im = h_in + i * dilation_h;
                const int w_im = w_in + j * dilation_w;
                if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
                {
                    val = data_im_ptr[h_im * width + w_im];
                }
                *data_col_ptr = val;
                data_col_ptr += batch_size * height_col * width_col;
            }
        }
    }
}

void my_conv_im2col_cuda(Tensor data_im,
                         const int channels, const int height,
                         const int width, const int ksize_h,
                         const int ksize_w, const int pad_h, const int pad_w,
                         const int stride_h, const int stride_w,
                         const int dilation_h, const int dilation_w,
                         const int parallel_imgs, Tensor data_col)
{
    int height_col =
        (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
    int width_col =
        (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
    int num_kernels = channels * height_col * width_col * parallel_imgs;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        data_im.scalar_type(), "my_conv_im2col_gpu", [&]
        { my_conv_im2col_gpu_kernel<scalar_t><<<GET_BLOCKS(num_kernels),
                                                THREADS_PER_BLOCK, 0,
                                                at::cuda::getCurrentCUDAStream()>>>(
              num_kernels, data_im.data_ptr<scalar_t>(),
              height, width, ksize_h, ksize_w,
              pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
              parallel_imgs, channels,
              height_col, width_col, data_col.data_ptr<scalar_t>()); });

    AT_CUDA_CHECK(cudaGetLastError());
}
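
One detail that is easy to gloss over is the indexing at the top of the kernel: the flat thread index is decoded into (channel, batch, row, column), and data_col_ptr re-encodes a flat offset with the same layout. A standalone Python round-trip check with made-up sizes:

channels, batch_size, height_col, width_col = 3, 2, 4, 5
n = channels * batch_size * height_col * width_col
for index in range(n):
    # decode, exactly as the kernel does
    w_col = index % width_col
    h_col = (index // width_col) % height_col
    b_col = (index // width_col // height_col) % batch_size
    c_im = (index // width_col // height_col) // batch_size
    # re-encoding with the same layout must recover the index
    assert index == ((c_im * batch_size + b_col) * height_col + h_col) * width_col + w_col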

7. my_conv.py

import torch
from torch.autograd import Function
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter

import my_ops


class MyConvF(Function):
    """Forward-only autograd Function: no backward() is defined, so this op
    supports inference but not training."""

    @staticmethod
    def forward(ctx,
                input: torch.Tensor,
                weight,
                bias,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                im2col_step=32):
        if input is not None and input.dim() != 4:
            raise ValueError(
                f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.im2col_step = im2col_step

        weight = weight.type_as(input)
        ctx.save_for_backward(input, weight)

        output = input.new_empty(MyConvF._output_size(ctx, input, weight))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        cur_im2col_step = min(ctx.im2col_step, input.size(0))
        assert (input.size(0) % cur_im2col_step
                ) == 0, 'batch size must be divisible by im2col_step'

        if bias is None:
            bias = torch.zeros(weight.shape[0]).to(weight.device)

        my_ops.my_conv_forward(
            input,
            weight,
            bias,
            output,
            ctx.bufs_[0],
            kW=weight.size(3),
            kH=weight.size(2),
            dW=ctx.stride[1],
            dH=ctx.stride[0],
            padW=ctx.padding[1],
            padH=ctx.padding[0],
            dilationW=ctx.dilation[1],
            dilationH=ctx.dilation[0],
            group=ctx.groups,
            im2col_step=cur_im2col_step)
        return output

    @staticmethod
    def _output_size(ctx, input, weight):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = ctx.padding[d]
            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = ctx.stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(
                'convolution input is too small (output would be ' +
                'x'.join(map(str, output_size)) + ')')
        return output_size


my_conv = MyConvF.apply


class MyConv2d(nn.Module):

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups: int = 1,
                 bias: bool = True):
        super().__init__()
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = _pair(padding)
        dilation_ = _pair(dilation)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size_
        self.stride = stride_
        self.padding = padding_
        self.dilation = dilation_
        self.groups = groups
        # NOTE: torch.Tensor() allocates uninitialized memory; the tests in
        # section 9 copy weights in from an nn.Conv2d, so no explicit
        # initialization is done here.
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *kernel_size_))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        # Useless attributes
        self.transposed = None
        self.output_padding = None
        self.padding_mode = None

    def forward(self, input: Tensor) -> Tensor:
        return my_conv(input, self.weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)
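
A minimal usage sketch (the weights are uninitialized here, so the output values are meaningless; the tests in section 9 copy real weights from nn.Conv2d). Since MyConvF defines no backward(), the layer is inference-only:

import torch
from my_conv import MyConv2d

conv = MyConv2d(3, 4, kernel_size=3, stride=1, padding=1).to('cuda:0')
x = torch.rand(2, 3, 8, 8, device='cuda:0')
with torch.no_grad():  # forward-only: MyConvF has no backward()
    y = conv(x)
print(y.shape)  # torch.Size([2, 4, 8, 8])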

8. setup.py

from setuptools import setup
from torch.utils import cpp_extension
import os

src_root = './'
cpp_src = ['my_conv.cpp','my_conv_cuda.cu']

if __name__ == '__main__':
    include_dirs = ['./']
    cpp_path = [os.path.join(src_root, src) for src in cpp_src]

    setup(
        name='panoflow',
        ext_modules=[
            cpp_extension.CUDAExtension(
                'my_ops', cpp_path, include_dirs=include_dirs)
        ],
        cmdclass={'build_ext': cpp_extension.BuildExtension})
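
With this setup.py, an in-place build is a single command, run from the directory containing the sources; it produces a my_ops shared library that the Python files above import:

python setup.py build_ext --inplace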

9. test.py

import torch
import torch.nn as nn
from my_conv import MyConv2d

inc = 3
outc = 4
img_shape = (50, 50)

device_name = 'cuda:0'
# device_name = 'cpu'
open_bias = True


def test_one():
    ts = torch.ones([1, 1, 3, 3]).to(device_name)
    layer = nn.Conv2d(1, 1, 3, 1, 1, bias=open_bias).to(device_name)
    gt = layer(ts)
    my_layer = MyConv2d(1, 1, 3, 1, 1).to(device_name)
    my_layer.load_state_dict(layer.state_dict(), strict=False)
    res = my_layer(ts)
    res = res.to('cpu')
    gt = gt.to('cpu')
    assert torch.allclose(res, gt, rtol=1e-3, atol=1e-5)


def test_two():
    ts = torch.rand([1, inc, *img_shape]).to(device_name)
    layer = nn.Conv2d(inc, outc, 3, 1, 1, bias=open_bias).to(device_name)
    gt = layer(ts)
    my_layer = MyConv2d(inc, outc, 3, 1, 1).to(device_name)
    my_layer.load_state_dict(layer.state_dict(), strict=False)
    res = my_layer(ts)
    res = res.to('cpu')
    gt = gt.to('cpu')
    assert torch.allclose(res, gt, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
    test_one()
    test_two()
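
After building, run the tests with:

python test.py

Both tests copy the weights of an nn.Conv2d into MyConv2d and assert torch.allclose against its output (rtol=1e-3, atol=1e-5), so a silent exit means the custom convolution matches the reference; switch device_name to 'cpu' to exercise the CPU path.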

