PyTorch框架——基于深度学习EfficientDeRain神经网络AI去雨滴图像增强系统
第一步:EfficientDeRain介绍
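EfficientDeRain 是一个针对单张图像去雨的开源项目,源自 AAAI 2021 论文《EfficientDeRain: Learning Pixel-wise Dilation Filtering for High-Efficiency Single-Image Deraining》(Qing Guo 等人,作者来自南洋理工大学、天津大学等机构),主要用于去除图像中的雨水干扰,恢复图像的真实场景。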
核心功能
图像去雨:EfficientDeRain 通过学习像素级的膨胀滤波,有效去除图像中的雨水干扰,恢复清晰图像。
高效率:项目设计考虑到了效率,能够在较短的时间内处理大量图像,适用于需要快速处理的应用场景。
可扩展性:项目提供了多种数据集的预训练模型,支持自定义数据集的训练,方便用户根据具体需求进行优化。
第二步:EfficientDeRain网络结构
该算法的原理非常简单,核心思想是把去雨视为图像的逐像素滤波问题。由于滤波是高度优化的操作,在GPU上的实现通常非常快。
看懂下面这张图,即可完全理解作者的算法思想:
图像先经过深度卷积网络,为每个像素预测卷积核参数,再用这些逐像素卷积核对原图做卷积,即可得到去雨后的图像;训练时需要成对的(有雨、无雨)图像。
作者指出,上述思路本身没有问题,但受限于逐像素卷积核的大小:如果只学习普通卷积核(即每个像素为三个通道分别预测3×3个参数,共3×(3×3)个),如上图(a)部分所示,对于雨条较宽的图像很难取得满意的效果。这是因为卷积本质上是对周围像素加权求和,理想情况下应给非雨条像素赋予较高权重;如果卷积核没有覆盖到非雨条像素,去雨效果自然不好。
为了在尺度上应对较大的雨条,作者做了改进:让神经网络预测多尺度的空洞卷积核。如(b)子图所示,网络预测4个尺度的空洞卷积核(下方代码中的空洞率为1、2、3、4),分别做空洞卷积后再将结果加权融合,得到最终的去雨图像。
所以算法的核心思路可总结为:学习多尺度空洞卷积+图像加权融合
第三步:模型代码展示
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# ----------------------------------------
# Initialize the networks
# ----------------------------------------
def weights_init(net, init_type='normal', init_gain=0.02):
    """Initialize the weights of every convolution and BatchNorm2d layer in ``net``.

    Parameters:
        net (nn.Module)   -- network to be initialized (modified in place via ``net.apply``)
        init_type (str)   -- normal | xavier | kaiming | orthogonal
        init_gain (float) -- std for ``normal``, gain for ``xavier`` / ``orthogonal``;
                             ignored by ``kaiming``

    Conv layers (any module whose class name contains "Conv" and that has a
    non-None weight) get the chosen initializer; their biases are left as-is.
    BatchNorm2d weights are drawn from N(1, 0.02) and biases set to 0; layers
    built with ``affine=False`` (weight/bias are None) are skipped.
    The paper's default is a zero-mean Gaussian with standard deviation 0.02.

    Raises:
        NotImplementedError -- if ``init_type`` is not one of the names above.
    """
    initializers = {
        'normal': lambda w: torch.nn.init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: torch.nn.init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: torch.nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: torch.nn.init.orthogonal_(w, gain=init_gain),
    }
    # Validate up front so a bad name fails even for networks without conv layers.
    if init_type not in initializers:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    init_weight = initializers[init_type]

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        if 'Conv' in classname and weight is not None:
            init_weight(weight)
        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) has weight/bias set to None: nothing to initialize.
            if weight is not None:
                torch.nn.init.normal_(weight, 1.0, 0.02)
            bias = getattr(m, 'bias', None)
            if bias is not None:
                torch.nn.init.constant_(bias, 0.0)

    # apply the initialization function <init_func> to every submodule
    print('initialize network with %s type' % init_type)
    net.apply(init_func)
# ----------------------------------------
# Kernel Prediction Network (KPN)
# ----------------------------------------
class Basic(nn.Module):
    """Three 3x3 conv + ReLU layers with optional channel / spatial attention.

    Args:
        in_ch: input channel count.
        out_ch: channel count of every conv layer's output.
        g: reduction factor of the channel-attention bottleneck (out_ch // g).
        channel_att: gate channels using concatenated global avg/max pooling.
        spatial_att: gate pixels using channel-wise mean/max maps and a 7x7 conv.
    """

    def __init__(self, in_ch, out_ch, g=16, channel_att=False, spatial_att=False):
        super(Basic, self).__init__()
        self.channel_att = channel_att
        self.spatial_att = spatial_att
        # Conv/ReLU pairs; creation order matches Sequential indices 0..5.
        layers = []
        for layer_in in (in_ch, out_ch, out_ch):
            layers.append(nn.Conv2d(in_channels=layer_in, out_channels=out_ch, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU())
        self.conv1 = nn.Sequential(*layers)
        if channel_att:
            hidden = out_ch // g
            self.att_c = nn.Sequential(
                nn.Conv2d(2 * out_ch, hidden, 1, 1, 0),
                nn.ReLU(),
                nn.Conv2d(hidden, out_ch, 1, 1, 0),
                nn.Sigmoid(),
            )
        if spatial_att:
            self.att_s = nn.Sequential(
                nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3),
                nn.Sigmoid(),
            )

    def forward(self, data):
        """Run the conv stack, then apply whichever attention gates are enabled.

        :param data: input feature map for ``conv1``
        :return: tensor with ``out_ch`` channels and the input's spatial size
        """
        features = self.conv1(data)
        if self.channel_att:
            pooled = torch.cat((F.adaptive_avg_pool2d(features, (1, 1)),
                                F.adaptive_max_pool2d(features, (1, 1))), dim=1)
            features = features * self.att_c(pooled)
        if self.spatial_att:
            avg_map = torch.mean(features, dim=1, keepdim=True)
            max_map = torch.max(features, dim=1, keepdim=True).values
            features = features * self.att_s(torch.cat((avg_map, max_map), dim=1))
        return features
class KPN(nn.Module):
    """Kernel Prediction Network used by EfficientDeRain.

    An encoder/decoder built from ``Basic`` blocks predicts a per-pixel kernel
    tensor ("core"). ``KernelConv`` applies that core to ``data`` once per
    dilation rate, and a 3x3 convolution fuses the per-rate results.

    Args:
        color: 3 channels per frame if True, otherwise 1.
        burst_length: number of frames stacked along the channel axis.
        blind_est: if False, ``data_with_est`` carries one extra frame of channels.
        kernel_size: per-pixel kernel sizes (tuple default; any iterable works).
        sep_conv: predict separable kernels instead of full K*K kernels.
        channel_att, spatial_att: enable attention in decoder blocks conv6-conv8.
        upMode: ``F.interpolate`` mode used by the decoder.
        core_bias: additionally predict a per-pixel bias.
        dilation_rates: rates at which the predicted kernels are applied; the
            default (1, 2, 3, 4) reproduces the original four-scale model.
    """

    def __init__(self, color=True, burst_length=1, blind_est=True, kernel_size=(5,), sep_conv=False,
                 channel_att=False, spatial_att=False, upMode='bilinear', core_bias=False,
                 dilation_rates=(1, 2, 3, 4)):
        super(KPN, self).__init__()
        self.upMode = upMode
        self.burst_length = burst_length
        self.core_bias = core_bias
        self.color_channel = 3 if color else 1
        self.dilation_rates = tuple(dilation_rates)
        in_channel = self.color_channel * (burst_length if blind_est else burst_length + 1)
        # Per-pixel kernel parameters: 2*sum(K) for separable kernels, sum(K^2) otherwise.
        per_pixel = 2 * sum(kernel_size) if sep_conv else int(np.sum(np.array(kernel_size) ** 2))
        out_channel = self.color_channel * per_pixel * burst_length
        if core_bias:
            out_channel += self.color_channel * burst_length
        # Encoder: conv2..conv5 each follow a 2x average pooling.
        self.conv1 = Basic(in_channel, 64, channel_att=False, spatial_att=False)
        self.conv2 = Basic(64, 128, channel_att=False, spatial_att=False)
        self.conv3 = Basic(128, 256, channel_att=False, spatial_att=False)
        self.conv4 = Basic(256, 512, channel_att=False, spatial_att=False)
        self.conv5 = Basic(512, 512, channel_att=False, spatial_att=False)
        # Decoder: upsample, concatenate the skip connection, then convolve.
        self.conv6 = Basic(512 + 512, 512, channel_att=channel_att, spatial_att=spatial_att)
        self.conv7 = Basic(256 + 512, 256, channel_att=channel_att, spatial_att=spatial_att)
        self.conv8 = Basic(256 + 128, out_channel, channel_att=channel_att, spatial_att=spatial_att)
        self.outc = nn.Conv2d(out_channel, out_channel, 1, 1, 0)
        self.kernel_pred = KernelConv(kernel_size, sep_conv, self.core_bias)
        # Each rate yields color_channel * burst_length maps for a 4-D input
        # (12 input channels for the default RGB, burst_length=1, four rates).
        self.conv_final = nn.Conv2d(in_channels=len(self.dilation_rates) * self.color_channel * burst_length,
                                    out_channels=self.color_channel, kernel_size=3, stride=1, padding=1)

    def _upsample(self, x, ref):
        """Resize ``x`` to ``ref``'s spatial size.

        This is identical to 2x upsampling when sizes are divisible by 16, and
        also works for other sizes where pooling floored odd dimensions.
        """
        return F.interpolate(x, size=tuple(ref.shape[-2:]), mode=self.upMode)

    def forward(self, data_with_est, data, white_level=1.0):
        """Predict per-pixel kernels from ``data_with_est`` and filter ``data`` with them.

        :param data_with_est: network input; same as ``data`` for blind estimation.
            Presumably (B, C, H, W) with H, W >= 16 (four 2x poolings) -- TODO confirm.
        :param data: frames handed to ``KernelConv``; assumed 4-D (B, C, H, W), since
            5-D per-rate outputs cannot be fused by the 2-D ``conv_final``.
        :param white_level: divisor forwarded to ``KernelConv``.
        :return: fused prediction with ``color_channel`` channels and the spatial size of ``data``.
        """
        conv1 = self.conv1(data_with_est)
        conv2 = self.conv2(F.avg_pool2d(conv1, kernel_size=2, stride=2))
        conv3 = self.conv3(F.avg_pool2d(conv2, kernel_size=2, stride=2))
        conv4 = self.conv4(F.avg_pool2d(conv3, kernel_size=2, stride=2))
        conv5 = self.conv5(F.avg_pool2d(conv4, kernel_size=2, stride=2))
        # Upsample to each skip tensor's size before concatenating.
        conv6 = self.conv6(torch.cat([conv4, self._upsample(conv5, conv4)], dim=1))
        conv7 = self.conv7(torch.cat([conv3, self._upsample(conv6, conv3)], dim=1))
        conv8 = self.conv8(torch.cat([conv2, self._upsample(conv7, conv2)], dim=1))
        # Per-pixel kernel parameters at full input resolution.
        core = self.outc(self._upsample(conv8, data))
        preds = [self.kernel_pred(data, core, white_level, rate=rate) for rate in self.dilation_rates]
        return self.conv_final(torch.cat(preds, dim=1))
class KernelConv(nn.Module):
"""
the class of computing prediction
"""
def __init__(self, kernel_size=[5], sep_conv=False, core_bias=False):
super(KernelConv, self).__init__()
self.kernel_size = sorted(kernel_size)
self.sep_conv = sep_conv
self.core_bias = core_bias
def _sep_conv_core(self, core, batch_size, N, color, height, width):
"""
convert the sep_conv core to conv2d core
2p --> p^2
:param core: shape: batch*(N*2*K)*height*width
:return:
"""
kernel_total = sum(self.kernel_size)
core = core.view(batch_size, N, -1, color, height, width)
if not self.core_bias:
core_1, core_2 = torch.split(core, kernel_total, dim=2)
else:
core_1, core_2, core_3 = torch.split(core, kernel_total, dim=2)
# output core
core_out = {}
cur = 0
for K in self.kernel_size:
t1 = core_1[:, :, cur:cur + K, ...].view(batch_size, N, K, 1, 3, height, width)
t2 = core_2[:, :, cur:cur + K, ...].view(batch_size, N, 1, K, 3, height, width)
core_out[K] = torch.einsum('ijklno,ijlmno->ijkmno', [t1, t2]).view(batch_size, N, K * K, color, height, width)
cur += K
# it is a dict
return core_out, None if not self.core_bias else core_3.squeeze()
def _convert_dict(self, core, batch_size, N, color, height, width):
"""
make sure the core to be a dict, generally, only one kind of kernel size is suitable for the func.
:param core: shape: batch_size*(N*K*K)*height*width
:return: core_out, a dict
"""
core_out = {}
core = core.view(batch_size, N, -1, color, height, width)
core_out[self.kernel_size[0]] = core[:, :, 0:self.kernel_size[0]**2, ...]
bias = None if not self.core_bias else core[:, :, -1, ...]
return core_out, bias
def forward(self, frames, core, white_level=1.0, rate=1):
"""
compute the pred image according to core and frames
:param frames: [batch_size, N, 3, height, width]
:param core: [batch_size, N, dict(kernel), 3, height, width]
:return:
"""
if len(frames.size()) == 5:
batch_size, N, color, height, width = frames.size()
else:
batch_size, N, height, width = frames.size()
color = 1
frames = frames.view(batch_size, N, color, height, width)
if self.sep_conv:
core, bias = self._sep_conv_core(core, batch_size, N, color, height, width)
else:
core, bias = self._convert_dict(core, batch_size, N, color, height, width)
img_stack = []
pred_img = []
kernel = self.kernel_size[::-1]
for index, K in enumerate(kernel):
if not img_stack:
padding_num = (K//2) * rate
frame_pad = F.pad(frames, [padding_num, padding_num, padding_num, padding_num])
for i in range(0, K):
for j in range(0, K):
img_stack.append(frame_pad[..., i*rate:i*rate + height, j*rate:j*rate + width])
img_stack = torch.stack(img_stack, dim=2)
else:
k_diff = (kernel[index - 1] - kernel[index]) // 2
img_stack = img_stack[:, :, k_diff:-k_diff, ...]
# print('img_stack:', img_stack.size())
pred_img.append(torch.sum(
core[K].mul(img_stack), dim=2, keepdim=False
))
pred_img = torch.stack(pred_img, dim=0)
# print('pred_stack:', pred_img.size())
pred_img_i = torch.mean(pred_img, dim=0, keepdim=False)
#print("pred_img_i", pred_img_i.size())
# N = 1
pred_img_i = pred_img_i.squeeze(2)
#print("pred_img_i", pred_img_i.size())
# if bias is permitted
if self.core_bias:
if bias is None:
raise ValueError('The bias should not be None.')
pred_img_i += bias
# print('white_level', white_level.size())
pred_img_i = pred_img_i / white_level
#pred_img = torch.mean(pred_img_i, dim=1, keepdim=True)
# print('pred_img:', pred_img.size())
# print('pred_img_i:', pred_img_i.size())
return pred_img_i
class LossFunc(nn.Module):
    """Combined KPN loss: a basic term on the final prediction plus an annealed term.

    Args:
        coeff_basic: weight of the ``LossBasic`` term.
        coeff_anneal: weight of the ``LossAnneal`` term.
        gradient_L1: forwarded to ``LossBasic``.
        alpha, beta: annealing schedule forwarded to ``LossAnneal``.
    """

    def __init__(self, coeff_basic=1.0, coeff_anneal=1.0, gradient_L1=True, alpha=0.9998, beta=100):
        super().__init__()
        self.coeff_basic = coeff_basic
        self.coeff_anneal = coeff_anneal
        self.loss_basic = LossBasic(gradient_L1)
        self.loss_anneal = LossAnneal(alpha, beta)

    def forward(self, pred_img_i, pred_img, ground_truth, global_step):
        """Compute both loss terms.

        :param pred_img_i: per-frame predictions, scored by the annealed loss
        :param pred_img: final prediction, scored by the basic loss
        :param ground_truth: target image shared by both terms
        :param global_step: training step that drives the annealing weight
        :return: tuple (coeff_basic * basic_loss, coeff_anneal * anneal_loss)
        """
        basic = self.loss_basic(pred_img, ground_truth)
        anneal = self.loss_anneal(global_step, pred_img_i, ground_truth)
        return self.coeff_basic * basic, self.coeff_anneal * anneal
class LossBasic(nn.Module):
    """MSE on pixels plus L1 on image gradients (via ``TensorGradient``)."""

    def __init__(self, gradient_L1=True):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()
        self.gradient = TensorGradient(gradient_L1)

    def forward(self, pred, ground_truth):
        """Return MSE(pred, gt) + L1(grad(pred), grad(gt))."""
        pixel_term = self.l2_loss(pred, ground_truth)
        pred_grad = self.gradient(pred)
        target_grad = self.gradient(ground_truth)
        return pixel_term + self.l1_loss(pred_grad, target_grad)
class LossAnneal(nn.Module):
    """Per-frame ``LossBasic`` averaged over dim 1, weighted by ``beta * alpha ** global_step``.

    With ``alpha < 1`` the weight decays geometrically as training proceeds.
    """

    def __init__(self, alpha=0.9998, beta=100):
        super().__init__()
        self.global_step = 0  # stored for compatibility; forward() uses its argument instead
        self.loss_func = LossBasic(gradient_L1=True)
        self.alpha = alpha
        self.beta = beta

    def forward(self, global_step, pred_i, ground_truth):
        """Average the basic loss of each slice ``pred_i[:, i]`` against ``ground_truth``.

        :param global_step: training step (int)
        :param pred_i: per-frame predictions; slices along dim 1 are scored
            (documented upstream as [batch_size, N, 3, height, width] -- TODO confirm)
        :param ground_truth: target compared against every slice
        :return: annealed loss ``beta * alpha ** global_step * mean_loss``
        """
        num_slices = pred_i.size(1)
        total = sum(self.loss_func(pred_i[:, idx, ...], ground_truth) for idx in range(num_slices))
        mean_loss = total / num_slices
        decay = self.beta * self.alpha ** global_step
        return decay * mean_loss
class TensorGradient(nn.Module):
    """Finite-difference gradient magnitude over the last two (height, width) dims.

    Horizontal and vertical differences are taken against a zero-padded copy
    (so the first row/column differ against 0). If ``L1`` is set, they are
    combined as |dx| + |dy|; otherwise as sqrt(dx^2 + dy^2).
    """

    def __init__(self, L1=True):
        super().__init__()
        self.L1 = L1

    def forward(self, img):
        height, width = img.size(-2), img.size(-1)
        # dx[..., y, x] = img[..., y, x-1] - img[..., y, x]; dy analogously along rows.
        dx = (F.pad(img, [1, 0, 0, 0]) - F.pad(img, [0, 1, 0, 0]))[..., 0:height, 0:width]
        dy = (F.pad(img, [0, 0, 1, 0]) - F.pad(img, [0, 0, 0, 1]))[..., 0:height, 0:width]
        if not self.L1:
            return torch.sqrt(torch.pow(dx, 2) + torch.pow(dy, 2))
        return torch.abs(dx) + torch.abs(dy)
if __name__ == '__main__':
    # Smoke test: run one forward pass and print the output shape.
    # Falls back to CPU so the script also runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    kpn = KPN().to(device)
    a = torch.randn(4, 3, 224, 224, device=device)
    with torch.no_grad():  # inference only; avoids keeping activations for backward
        b = kpn(a, a)
    print(b.shape)
第四步:运行
第五步:整个工程的内容
项目完整文件下载请见演示与介绍视频的简介处给出:➷➷➷
PyTorch框架——基于深度学习EfficientDeRain神经网络AI去雨滴图像增强系统_哔哩哔哩_bilibili