当前位置: 首页 > article >正文



  1. 理解子像素卷积

  2. 子像素卷积的实现

import torch
import torch.nn as nn

class SubPixelConv(nn.Module):
    """Sub-pixel convolution rearrangement layer (thin ``nn.PixelShuffle`` wrapper).

    Rearranges a tensor of shape (N, C * r**2, H, W) into (N, C, H * r, W * r),
    where ``r`` is ``upscale_factor``. The layer has no learnable parameters;
    it only moves elements from the channel axis into the spatial axes.

    Args:
        upscale_factor: Factor ``r`` by which height and width are enlarged.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        # Registered as a child module under the attribute name ``pixel_shuffle``.
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        """Return ``x`` with channel blocks of size r**2 shuffled into space."""
        shuffled = self.pixel_shuffle(x)
        return shuffled

# Example usage: upscale one 16-channel 8x8 feature map by a factor of 2.
upscale_factor = 2
input_tensor = torch.randn(1, 16, 8, 8)  # (batch=1, channels=16, height=8, width=8)
model = SubPixelConv(upscale_factor)
output_tensor = model(input_tensor)
# 16 channels / 2**2 -> 4 channels; each spatial side 8 * 2 -> 16.
print(output_tensor.shape)  # torch.Size([1, 4, 16, 16])
  3. CUDA优化
#include <cuda_runtime.h>

// Pixel-shuffle (sub-pixel convolution rearrangement) for a single image.
// Reads `input` as a row-major (channels, height, width) buffer and writes
// `output` as a row-major (channels / r^2, height * r, width * r) buffer,
// with r = upscale_factor:
//   output[c, h*r + i, w*r + j] = input[c*r*r + i*r + j, h, w]
// Each thread moves one input element. Threads with
// idx >= channels * height * width do nothing, so the grid may be rounded up
// to whole blocks.
// NOTE(review): channels is assumed to be divisible by r^2. That is not
// checked here; otherwise new_c can reach channels / r^2 and index past an
// output buffer sized for (channels / r^2) channels.
__global__ void subpixel_conv_kernel(float* input, float* output, int channels, int height, int width, int upscale_factor) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int plane = height * width;
    int total_elements = channels * plane;

    if (idx < total_elements) {
        // Decompose the flat input index into (c, h, w).
        int c = idx / plane;
        int h = (idx % plane) / width;
        int w = idx % width;

        int r = upscale_factor;
        int r2 = r * r;
        int out_h = height * r;
        int out_w = width * r;

        // Input channel c maps to output channel new_c and to position
        // `offset` inside the r x r output patch (row-major within the patch).
        int new_c = c / r2;
        int offset = c % r2;
        int new_h = h * r + offset / r;
        int new_w = w * r + offset % r;

        int new_idx = new_c * out_h * out_w + new_h * out_w + new_w;
        output[new_idx] = input[idx];
    }
}

// Host entry point: launches subpixel_conv_kernel with one thread per input
// element (1024 threads per block, grid rounded up to cover every element).
// `input` and `output` must be device-accessible pointers. Presumably `output`
// holds (channels / r^2) * (height * r) * (width * r) floats with
// r = upscale_factor — TODO confirm against callers.
// The kernel launch is asynchronous. This function neither synchronizes nor
// checks for launch errors; callers should use cudaGetLastError /
// cudaDeviceSynchronize.
extern "C" void subpixel_conv(float* input, float* output, int channels, int height, int width, int upscale_factor) {
    int total_elements = channels * height * width;
    if (total_elements <= 0) {
        // A zero-block launch is an invalid configuration; nothing to copy.
        return;
    }
    int threads_per_block = 1024;
    int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;

    subpixel_conv_kernel<<<num_blocks, threads_per_block>>>(input, output, channels, height, width, upscale_factor);
}
  4. 调用CUDA内核
import numpy as np
import cupy as cp

# CUDA source for the pixel-shuffle kernel, meant to be compiled at runtime
# with cupy.RawModule. `extern "C"` disables C++ name mangling so the kernel
# can be fetched by its plain name. The kernel's scalar parameters are 32-bit
# `int`, so callers must pass int32 scalars. The string is closed here and
# the braces are balanced, so both the Python and the CUDA source parse.
kernel_code = """
extern "C" __global__ void subpixel_conv_kernel(float* input, float* output, int channels, int height, int width, int upscale_factor) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_elements = channels * height * width;

    if (idx < total_elements) {
        int c = idx / (height * width);
        int h = (idx % (height * width)) / width;
        int w = idx % width;

        int new_c = c / (upscale_factor * upscale_factor);
        int offset = c % (upscale_factor * upscale_factor);
        int new_h = h * upscale_factor + offset / upscale_factor;
        int new_w = w * upscale_factor + offset % upscale_factor;

        int new_idx = new_c * (height * upscale_factor) * (width * upscale_factor) + new_h * (width * upscale_factor) + new_w;
        output[new_idx] = input[idx];
    }
}
"""

# Compile the CUDA source in `kernel_code` with CuPy. The kernel is then
# looked up by its unmangled (extern "C") name.
module = cp.RawModule(code=kernel_code)
subpixel_conv_kernel = module.get_function("subpixel_conv_kernel")

def subpixel_conv(input_tensor, upscale_factor):
    """Rearrange a (C * r**2, H, W) CuPy array into (C, H * r, W * r).

    Single-image (no batch axis) counterpart of ``torch.nn.PixelShuffle``,
    executed by the ``subpixel_conv_kernel`` CUDA kernel.

    Args:
        input_tensor: 3-D CuPy array of shape (channels, height, width). It is
            converted to a C-contiguous float32 array first, because the kernel
            reads raw ``float*`` memory in row-major order.
        upscale_factor: Spatial upscaling factor ``r`` (a positive integer).

    Returns:
        A new float32 CuPy array of shape
        (channels // r**2, height * r, width * r).

    Raises:
        ValueError: If ``upscale_factor`` is not positive, or if ``channels``
            is not divisible by ``upscale_factor ** 2``. Without this check the
            kernel would write past the end of the output buffer.
    """
    channels, height, width = input_tensor.shape
    if upscale_factor <= 0:
        raise ValueError(f"upscale_factor must be positive, got {upscale_factor}")
    factor_sq = upscale_factor * upscale_factor
    if channels % factor_sq != 0:
        raise ValueError(
            f"channels ({channels}) must be divisible by "
            f"upscale_factor**2 ({factor_sq})"
        )

    # Normalise dtype and memory layout. This is a no-op if the input already
    # is a C-contiguous float32 array.
    input_tensor = cp.ascontiguousarray(input_tensor, dtype=cp.float32)
    # The mapping is a bijection, so every output element is written exactly
    # once and no zero-fill is needed.
    output_tensor = cp.empty(
        (channels // factor_sq, height * upscale_factor, width * upscale_factor),
        dtype=cp.float32,
    )

    total_elements = channels * height * width
    if total_elements == 0:
        # A zero-block launch is an invalid CUDA configuration.
        return output_tensor

    threads_per_block = 1024
    num_blocks = (total_elements + threads_per_block - 1) // threads_per_block
    # CuPy passes Python ints to raw kernels as 64-bit `long long`, but the
    # kernel declares 32-bit `int` parameters. Pass explicit int32 scalars so
    # the arguments line up.
    subpixel_conv_kernel(
        (num_blocks,),
        (threads_per_block,),
        (
            input_tensor,
            output_tensor,
            np.int32(channels),
            np.int32(height),
            np.int32(width),
            np.int32(upscale_factor),
        ),
    )
    return output_tensor

# Example usage: one 16-channel 8x8 float32 image (no batch axis), upscaled by 2.
upscale_factor = 2
input_tensor = cp.random.randn(16, 8, 8).astype(cp.float32)
output_tensor = subpixel_conv(input_tensor, upscale_factor)
# 16 channels / 2**2 -> 4 channels; each spatial side 8 * 2 -> 16.
print(output_tensor.shape)  # (4, 16, 16)



  • vscode 中快捷生成模板快捷键
  • C#-使用VisualStudio编译C#工程
  • 【区块链】以太坊
  • 【go】函数类型的作用
  • 多线程(超详细) (ε≡٩(๑>₃<)۶ 一心向学)
  • 数据库技术
  • C++内建函数对象
  • QT基础十四、绘图
  • RabbitMQ (Java)学习笔记
  • 高安全可靠MCU芯片AS32X601应用解析
  • 【北上广深杭大厂AI算法面试题】计算机视觉篇...YOLOv4 相较于YOLOv3有哪些改进?速度更快还是更慢,为什么?(二)
  • 关于Django框架的面试题及其解析
  • 嵌入式2-按键
  • Python网络爬虫之BeautifulSoup库的使用流程和方法
  • 力扣hot100_二叉树(5)_python版本
  • 实验5 逻辑回归
  • 梯度下降法以及随机梯度下降法
  • 作业9 (2023-05-05 数组的定义和初始化)
  • 富文本编辑器(Rich Text Editor,RTE)
  • 矩阵交换行(信息学奥赛一本通-1119)