子像素卷积(Sub-Pixel Convolution,即 PixelShuffle)是一种通过重排特征图通道来实现上采样的方法:它把低分辨率特征图中每 r² 个通道重新排列到空间维度上,将形状为 (C·r², H, W) 的输入变为形状为 (C, r·H, r·W) 的高分辨率输出,其中 r 为上采样倍数。
- PyTorch实现
import torch
import torch.nn as nn
class SubPixelConv(nn.Module):
    """Sub-pixel upsampling layer: a thin wrapper around ``nn.PixelShuffle``.

    The layer has no learnable parameters. It moves groups of
    ``upscale_factor ** 2`` channels into the spatial dimensions, so an
    input of shape ``(N, C * r**2, H, W)`` becomes ``(N, C, H * r, W * r)``
    with ``r = upscale_factor``.

    Args:
        upscale_factor: Spatial upscaling factor ``r`` forwarded to
            ``nn.PixelShuffle``.
    """

    def __init__(self, upscale_factor: int) -> None:
        super().__init__()
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with its channels rearranged into spatial positions."""
        return self.pixel_shuffle(x)
# Example usage: one sample with 16 channels on an 8x8 grid is rearranged
# into 16 / 2**2 = 4 channels on a 16x16 grid.
upscale_factor = 2
input_tensor = torch.randn(1, 16, 8, 8)  # (batch, channels, height, width)
model = SubPixelConv(upscale_factor)
output_tensor = model(input_tensor)
print(output_tensor.shape)  # shape is (1, 4, 16, 16)
- CUDA优化
#include <cuda_runtime.h>
/*
 * Pixel-shuffle (sub-pixel convolution) rearrangement.
 *
 * `input` is indexed as a dense row-major [channels][height][width] array and
 * `output` as a dense row-major [channels / r^2][height * r][width * r] array,
 * where r = upscale_factor. Each thread copies at most one element:
 *
 *   output[c / r^2][h * r + (c % r^2) / r][w * r + (c % r^2) % r] = input[c][h][w]
 *
 * Threads whose flat index is >= channels * height * width do nothing, so the
 * grid may be rounded up to a whole number of blocks.
 *
 * Preconditions (not checked here): upscale_factor > 0, channels is a
 * multiple of r^2, and `output` holds channels * height * width floats.
 * All index arithmetic is done in `int`.
 */
__global__ void subpixel_conv_kernel(float* input, float* output, int channels, int height, int width, int upscale_factor) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_elements = channels * height * width;
    if (idx < total_elements) {
        // Decompose the flat input index into (c, h, w).
        int c = idx / (height * width);
        int h = (idx % (height * width)) / width;
        int w = idx % width;

        // The input channel selects the output channel and the position
        // inside the r x r output tile that (h, w) expands into.
        int new_c = c / (upscale_factor * upscale_factor);
        int offset = c % (upscale_factor * upscale_factor);
        int new_h = h * upscale_factor + offset / upscale_factor;
        int new_w = w * upscale_factor + offset % upscale_factor;

        int new_idx = new_c * (height * upscale_factor) * (width * upscale_factor) + new_h * (width * upscale_factor) + new_w;
        output[new_idx] = input[idx];
    }
}
/*
 * Host-side launcher for subpixel_conv_kernel.
 *
 * Launches one thread per input element in blocks of 1024 threads, rounding
 * the block count up. `input` and `output` are forwarded to the kernel
 * unchanged; this function performs no allocation or host/device copies, so
 * they are presumably device pointers (confirm against the caller).
 * The launch is not synchronized and launch errors are not checked here.
 */
extern "C" void subpixel_conv(float* input, float* output, int channels, int height, int width, int upscale_factor) {
    int total_elements = channels * height * width;
    // A grid with zero blocks is an invalid launch configuration, and an
    // empty tensor has nothing to copy anyway.
    if (total_elements <= 0) {
        return;
    }
    int threads_per_block = 1024;
    int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
    subpixel_conv_kernel<<<num_blocks, threads_per_block>>>(input, output, channels, height, width, upscale_factor);
}
- 调用CUDA内核
import numpy as np
import cupy as cp
# CUDA source for the pixel-shuffle kernel. It is declared extern "C" so the
# function can be looked up by its plain name. Each thread copies one element
# of a [channels, height, width] input to its position in a
# [channels / r^2, height * r, width * r] output (r = upscale_factor).
kernel_code = """
extern "C" __global__ void subpixel_conv_kernel(float* input, float* output, int channels, int height, int width, int upscale_factor) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_elements = channels * height * width;
    if (idx < total_elements) {
        int c = idx / (height * width);
        int h = (idx % (height * width)) / width;
        int w = idx % width;
        int new_c = c / (upscale_factor * upscale_factor);
        int offset = c % (upscale_factor * upscale_factor);
        int new_h = h * upscale_factor + offset / upscale_factor;
        int new_w = w * upscale_factor + offset % upscale_factor;
        int new_idx = new_c * (height * upscale_factor) * (width * upscale_factor) + new_h * (width * upscale_factor) + new_w;
        output[new_idx] = input[idx];
    }
}
"""
# Build a CuPy RawModule from the CUDA source held in kernel_code.
module = cp.RawModule(code=kernel_code)
# Look up the kernel entry point by name; subpixel_conv launches this object.
subpixel_conv_kernel = module.get_function("subpixel_conv_kernel")
def subpixel_conv(input_tensor, upscale_factor):
    """Pixel-shuffle a ``(C * r**2, H, W)`` CuPy array into ``(C, H * r, W * r)``.

    Launches ``subpixel_conv_kernel`` with one thread per input element.

    Args:
        input_tensor: 3-D CuPy array ``(channels, height, width)``. It is
            converted to C-contiguous float32 if it is not already, because
            the kernel indexes raw float32 memory in row-major order.
        upscale_factor: Spatial upscaling factor ``r``.

    Returns:
        A new float32 CuPy array of shape
        ``(channels // r**2, height * r, width * r)``.

    Raises:
        ValueError: If ``channels`` is not a multiple of ``r**2``, or the
            element count does not fit the kernel's 32-bit index arithmetic.
        ZeroDivisionError: If ``upscale_factor`` is 0.
    """
    channels, height, width = input_tensor.shape
    channels_per_pixel = upscale_factor * upscale_factor
    # Without this check the kernel would write past the end of the output.
    if channels % channels_per_pixel != 0:
        raise ValueError(
            f"channels ({channels}) must be divisible by "
            f"upscale_factor ** 2 ({channels_per_pixel})"
        )

    total_elements = channels * height * width
    # The kernel computes every index with 32-bit ints.
    if total_elements > np.iinfo(np.int32).max:
        raise ValueError(
            f"input has {total_elements} elements, which exceeds the "
            "32-bit index range of the kernel"
        )

    input_tensor = cp.ascontiguousarray(input_tensor, dtype=cp.float32)
    # The mapping is one-to-one, so the kernel writes every output element
    # and the buffer needs no zero-initialisation.
    output_tensor = cp.empty(
        (channels // channels_per_pixel, height * upscale_factor, width * upscale_factor),
        dtype=cp.float32,
    )
    if total_elements == 0:
        # Nothing to copy; a zero-block grid is not a valid launch.
        return output_tensor

    threads_per_block = 1024
    num_blocks = (total_elements + threads_per_block - 1) // threads_per_block
    # Pass explicit 32-bit scalars so the arguments match the kernel's `int`
    # parameters instead of relying on how Python ints are marshalled.
    subpixel_conv_kernel(
        (num_blocks,),
        (threads_per_block,),
        (
            input_tensor,
            output_tensor,
            np.int32(channels),
            np.int32(height),
            np.int32(width),
            np.int32(upscale_factor),
        ),
    )
    return output_tensor
# Example usage: 16 channels on an 8x8 grid are rearranged into
# 16 / 2**2 = 4 channels on a 16x16 grid.
upscale_factor = 2
input_tensor = cp.random.randn(16, 8, 8).astype(cp.float32)  # (channels, height, width)
output_tensor = subpixel_conv(input_tensor, upscale_factor)
print(output_tensor.shape)  # prints (4, 16, 16)