基于卷积神经网络和 Swin Transformer 的图像处理模型
本项目实现了一个基于卷积神经网络和 Swin Transformer 的图像处理模型。该模型主要用于对输入图像进行特征提取和处理，以实现特定的图像任务，如图像增强、去噪等。
项目完整代码下载链接:https://download.csdn.net/download/huanghm88/89909179
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
##**项目完整代码下载链接:https://download.csdn.net/download/huanghm88/89909179**
class Convlutioanl(nn.Module):
    """5x5 convolution block: replicate-pad -> Conv2d -> BatchNorm2d -> ReLU.

    The input is padded by 2 pixels on each side (edge replication) before
    an unpadded, stride-1, 5x5 convolution, so the spatial size of the
    output matches that of the input.

    Args:
        in_channel: Channel count of the incoming feature map.
        out_channel: Channel count produced by the convolution.
    """

    def __init__(self, in_channel: int, out_channel: int) -> None:
        super().__init__()
        # Pad widths for the last two dimensions, passed verbatim to F.pad.
        self.padding = (2, 2, 2, 2)
        self.conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=5,
            stride=1,
            padding=0,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channel)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Pad, convolve, normalise and rectify ``input``.

        Assumes a batched (N, C, H, W) tensor, as BatchNorm2d is applied.
        """
        padded = F.pad(input, pad=self.padding, mode='replicate')
        return self.relu(self.bn(self.conv(padded)))
class Convlutioanl_out(nn.Module):
    """Output head: a 1x1 convolution followed by a tanh activation.

    Args:
        in_channel: Channel count of the incoming feature map.
        out_channel: Channel count of the produced output.
    """

    def __init__(self, in_channel: int, out_channel: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.tanh = nn.Tanh()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Project ``input`` with the 1x1 conv and squash it with tanh."""
        projected = self.conv(input)
        return self.tanh(projected)
class Fem(nn.Module):
    """Two-pass 3x3 convolution block that reuses one conv and one BN layer.

    Pipeline: replicate-pad(1) -> conv -> BN -> ReLU -> replicate-pad(1)
    -> conv -> BN. No activation follows the second pass.

    NOTE(review): both passes call the same ``self.conv`` and ``self.bn``,
    so the two stages share weights, and the second pass is fed a tensor
    with ``out_channel`` channels. The block therefore only works when
    ``in_channel == out_channel``.

    Args:
        in_channel: Channel count expected by the shared convolution.
        out_channel: Channel count produced by the shared convolution.
    """

    def __init__(self, in_channel: int, out_channel: int) -> None:
        super().__init__()
        # One pixel on each side keeps the spatial size under the
        # unpadded 3x3 convolution.
        self.padding = (1, 1, 1, 1)
        self.conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=3,
            stride=1,
            padding=0,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channel)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the shared pad/conv/BN stage twice, with ReLU in between."""
        # First pass, followed by the activation.
        first = self.conv(F.pad(input, pad=self.padding, mode='replicate'))
        first = self.relu(self.bn(first))
        # Second pass through the very same layers; returned unactivated.
        second = self.conv(F.pad(first, pad=self.padding, mode='replicate'))
        return self.bn(second)
class Channel_attention(nn.Module):
    """Compute per-channel attention weights from global average pooling.

    The input is pooled to 1x1 per channel, passed through a two-layer
    1x1-conv bottleneck (``channel -> channel // reduction -> channel``)
    with a ReLU in between, and squashed with a sigmoid. Only the weights
    are returned; they are not multiplied with the input here.

    Args:
        channel: Channel count of the incoming feature map.
        reduction: Divisor applied to ``channel`` for the bottleneck width.
    """

    def __init__(self, channel: int, reduction: int = 4) -> None:
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channel, bottleneck, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, channel, 1),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid-activated channel weights for ``input``."""
        pooled = self.avg_pool(input)
        return self.sigmoid(self.fc(pooled))
class Spatial_attention(nn.Module):
    """Compute a single-channel spatial attention map.

    Layer stack: 3x3 conv (``channel -> channel // reduction``) ->
    BatchNorm2d -> ReLU -> 3x3 conv (to 1 channel) -> BatchNorm2d ->
    Sigmoid. Both convolutions use ``padding=1``, so the spatial size is
    preserved. Only the map is returned; it is not applied to the input
    here.

    Args:
        channel: Channel count of the incoming feature map.
        reduction: Divisor applied to ``channel`` for the hidden width.
    """

    def __init__(self, channel: int, reduction: int = 4) -> None:
        super().__init__()
        hidden = channel // reduction
        # Order matters: the Sequential indices become state_dict keys.
        layers = [
            nn.Conv2d(channel, hidden, 3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(True),
            nn.Conv2d(hidden, 1, 3, padding=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid-activated spatial map for ``input``."""
        attention_map = self.body(input)
        return attention_map
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.relative_position_bias_table = nn