PyTorch implementations of loss functions commonly used in low-light enhancement
Source: 低照度图像增强模型——损失函数整理汇总 (sunshihanshu, CSDN blog)
Most of the material below is adapted from that article; it is not original.
Gradient loss:
First compute the gradients with the Sobel operator, then take the L1 norm of the resulting gradients.
Implementation:
Define the Sobelxy class.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Sobelxy(nn.Module):
    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]]
        kernely = [[1, 2, 1],
                   [0, 0, 0],
                   [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        # Register the kernels as non-trainable buffers so that
        # .cuda() / .to(device) on the module moves them as well
        self.register_buffer('weightx', kernelx)
        self.register_buffer('weighty', kernely)

    def forward(self, x):
        # x is expected to be single-channel: (N, 1, H, W)
        sobelx = F.conv2d(x, self.weightx, padding=1)
        sobely = F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx) + torch.abs(sobely)
Instantiate the class:
sobelconv = Sobelxy().cuda()  # move the module (and its kernels) to the GPU
Compute the loss:
L_T_grad = sobelconv(L)
gradient_loss = beta * torch.sum(torch.abs(L_T_grad))  # beta is a weighting hyperparameter
Note: this computes the gradient loss of L.
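For supervised training, a common variant (not in the source article) matches the gradients of the output against those of the ground truth instead of penalizing the gradients of a single image. A minimal sketch; out_image and gt_image are hypothetical placeholders:
# hedged sketch of a supervised gradient-matching variant;
# out_image / gt_image are hypothetical single-channel tensors
grad_loss = F.l1_loss(sobelconv(out_image), sobelconv(gt_image))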
VGG perceptual loss:
import torch
import torch.nn.functional as F
from torchvision import models

# Load a pretrained VGG model (newer torchvision versions use the
# `weights=` argument instead of `pretrained=True`)
vgg = models.vgg19(pretrained=True).features
# Set the model to evaluation mode and freeze its weights
vgg.eval()
for p in vgg.parameters():
    p.requires_grad = False

def perceptual_loss(out_img, gt_img):
    # Compare the input and target images in VGG feature space
    input_features = vgg(out_img)
    target_features = vgg(gt_img)
    loss = F.mse_loss(input_features, target_features)
    return loss
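Note that torchvision's pretrained VGGs were trained on ImageNet-normalized inputs, so normalizing both images before feature extraction usually makes the perceptual loss better behaved. A minimal sketch (this normalization step is an addition, not part of the source article):
# ImageNet channel statistics used by torchvision's pretrained VGGs
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def normalize_for_vgg(img):
    # img is assumed to be an RGB tensor in [0, 1], shape (N, 3, H, W)
    return (img - mean.to(img.device)) / std.to(img.device)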
# To load the model weights from a local path instead:
def load_vgg(weights_path: str):
    vgg = models.vgg16(pretrained=False)   # build VGG16 without downloading weights (VGG19 works analogously)
    state_dict = torch.load(weights_path)  # load the weight file
    vgg.load_state_dict(state_dict)        # copy the weights into the model
    vgg = vgg.features                     # keep only the feature-extraction part
    return vgg

vgg = load_vgg(r'E:\预训练权重\vgg16-397923af.pth')
vgg.eval()  # set the model to evaluation mode

# perceptual_loss is then defined exactly as above and used the same way:
# loss = perceptual_loss(out_img, gt_img)
Structural loss: SSIM
Calling pre-packaged SSIM implementations almost always fails, so it is written by hand here.
The SSIM class:
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range
        # Assume 1 channel for SSIM until the first forward pass
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            # Rebuild the window if the channel count or dtype changed
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel
        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
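The class depends on create_window and ssim helpers that the original snippet does not include. A minimal sketch of these helpers, following the widely used pytorch-ssim implementation (assumed here, since the source does not show them):
import torch
import torch.nn.functional as F
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
                          for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # one depthwise Gaussian kernel per channel
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window

def ssim(img1, img2, window, window_size, size_average=True, val_range=None):
    # dynamic range of the images: 1.0 for [0, 1] inputs, 255 for [0, 255]
    L = 1.0 if val_range is None else val_range
    window = window.to(img1.device).type(img1.dtype)  # guard against device/dtype mismatch
    channel = img1.size(1)
    pad = window_size // 2
    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)
    mu1_sq, mu2_sq, mu1_mu2 = mu1 * mu1, mu2 * mu2, mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channel) - mu1_mu2
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
               ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean() if size_average else ssim_map.mean([1, 2, 3])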
Instantiating the class:
ssim_metric = SSIM(window_size=11, size_average=True, val_range=None)
ssim_value = ssim_metric(R, hat_R)
Note: this computes the structural similarity between R and hat_R. SSIM is a similarity measure (higher is better), so to minimize it as a loss use 1 - ssim_value.
L1 loss / mean absolute error (MAE)
Measures the error between the model-enhanced output image and the ground-truth (GT) image.
# Variant 1: functional form
import torch.nn.functional as F
F.l1_loss(out_image, gt_image)
# Variant 2: module form
import torch.nn as nn
nn.L1Loss()(out_image, gt_image)
# Variant 3: written by hand
import torch
def MAELoss(out_image, gt_image):
    return torch.mean(torch.abs(out_image - gt_image))
Smoothness loss (TV loss)
class TVLoss(nn.Module):
    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, x, weight_map=None):
        self.h_x = x.size()[2]
        self.w_x = x.size()[3]
        self.batch_size = x.size()[0]
        if weight_map is None:
            self.TVLoss_weight = (1, 1)
        else:
            # down-weight the TV penalty where the weight map has strong gradients
            self.TVLoss_weight = self.compute_weight(weight_map)
        count_h = self._tensor_size(x[:, :, 1:, :])
        count_w = self._tensor_size(x[:, :, :, 1:])
        h_tv = (self.TVLoss_weight[0] * torch.abs(x[:, :, 1:, :] - x[:, :, :self.h_x - 1, :])).sum()
        w_tv = (self.TVLoss_weight[1] * torch.abs(x[:, :, :, 1:] - x[:, :, :, :self.w_x - 1])).sum()
        return (h_tv / count_h + w_tv / count_w) / self.batch_size

    def _tensor_size(self, t):
        return t.size()[1] * t.size()[2] * t.size()[3]

    def compute_weight(self, img):
        gradx = torch.abs(img[:, :, 1:, :] - img[:, :, :self.h_x - 1, :])
        grady = torch.abs(img[:, :, :, 1:] - img[:, :, :, :self.w_x - 1])
        TVLoss_weight_x = torch.div(1, torch.exp(gradx))
        TVLoss_weight_y = torch.div(1, torch.exp(grady))
        return TVLoss_weight_x, TVLoss_weight_y
# Concise version (squared differences)
class L_TV(nn.Module):
    def __init__(self):
        super(L_TV, self).__init__()

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = (x.size()[2] - 1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2).sum()
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2).sum()
        return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
# The version used in DSLR
def tv_loss(img):
    """
    Compute total variation loss.
    Inputs:
    - img: tensor of shape (N, C, H, W) holding an input image.
    Returns:
    - loss: scalar tensor holding the total variation of img,
      averaged over the batch.
    """
    b, c, h, w = img.size()
    w_variance = torch.sum(torch.pow(img[:, :, :, :-1] - img[:, :, :, 1:], 2)) / b
    h_variance = torch.sum(torch.pow(img[:, :, :-1, :] - img[:, :, 1:, :], 2)) / b
    loss = (h_variance + w_variance) / 2
    return loss
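A quick usage sketch comparing the three TV variants on a random batch (hypothetical, just to show the call signatures):
# hypothetical sanity check of the three TV variants
img = torch.rand(4, 3, 64, 64)
print(TVLoss()(img))   # weighted, absolute-difference variant
print(L_TV()(img))     # squared-difference variant
print(tv_loss(img))    # DSLR variant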
Unsupervised smoothness loss:
# RetinexNet
import torch
import torch.nn.functional as F

class Smoothloss():
    def __init__(self):
        super(Smoothloss, self).__init__()
        # 2x2 difference kernels, built once; hard-coded to the GPU as in the original
        self.smooth_kernel_x = torch.FloatTensor([[0, 0], [-1, 1]]).view((1, 1, 2, 2)).cuda()
        self.smooth_kernel_y = torch.transpose(self.smooth_kernel_x, 2, 3)

    def gradient(self, input_tensor, direction):
        if direction == "x":
            kernel = self.smooth_kernel_x
        elif direction == "y":
            kernel = self.smooth_kernel_y
        grad_out = torch.abs(F.conv2d(input_tensor, kernel,
                                      stride=1, padding=1))
        return grad_out

    def ave_gradient(self, input_tensor, direction):
        # local average of the gradient magnitude
        return F.avg_pool2d(self.gradient(input_tensor, direction),
                            kernel_size=3, stride=1, padding=1)

    def smooth(self, I_img, R_img):
        # convert the reflectance to the Y (luma) channel of YUV space
        R_img = 0.299 * R_img[:, 0, :, :] + 0.587 * R_img[:, 1, :, :] + 0.114 * R_img[:, 2, :, :]
        R_img = torch.unsqueeze(R_img, dim=1)
        # penalize illumination gradients where the reflectance is flat
        return torch.mean(self.gradient(I_img, "x") * torch.exp(-10 * self.ave_gradient(R_img, "x")) +
                          self.gradient(I_img, "y") * torch.exp(-10 * self.ave_gradient(R_img, "y")))
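A hypothetical usage sketch, assuming I_img is the predicted illumination map (N, 1, H, W) and R_img the reflectance (N, 3, H, W), both on the GPU:
# hypothetical usage; I_img and R_img come from a Retinex-style decomposition
smooth_fn = Smoothloss()
loss_smooth = smooth_fn.smooth(I_img, R_img)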
Color constancy loss (Color Constancy Loss), suitable for unsupervised training:
class L_color(nn.Module):
    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x):
        b, c, h, w = x.shape
        # per-image mean of each RGB channel
        mean_rgb = torch.mean(x, [2, 3], keepdim=True)
        mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = torch.pow(mr - mg, 2)
        Drb = torch.pow(mr - mb, 2)
        Dgb = torch.pow(mb - mg, 2)
        # note: k has shape (b, 1, 1, 1); take its mean for a scalar loss
        k = torch.pow(torch.pow(Drg, 2) + torch.pow(Drb, 2) + torch.pow(Dgb, 2), 0.5)
        return k
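Since the forward pass returns a per-image tensor rather than a scalar, averaging it gives the final loss. A hypothetical usage sketch, with enhanced as the model output:
# hypothetical usage; `enhanced` is the (N, 3, H, W) model output
loss_col = torch.mean(L_color()(enhanced))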
The color loss from DIV, suitable for supervised training:
def angle(a, b):
    vector = torch.mul(a, b)
    up = torch.sum(vector)
    down = torch.sqrt(torch.sum(torch.square(a))) * torch.sqrt(torch.sum(torch.square(b)))
    # clamp to avoid NaNs from acos when floating-point error pushes |cos| past 1
    theta = torch.acos(torch.clamp(up / down, -1.0, 1.0))  # in radians
    return theta

def color_loss(out_image, gt_image):
    # color loss: encourage color consistency between the enhanced image and the GT; inputs are (b, c, h, w)
    loss = torch.mean(angle(out_image[:, 0, :, :], gt_image[:, 0, :, :]) +
                      angle(out_image[:, 1, :, :], gt_image[:, 1, :, :]) +
                      angle(out_image[:, 2, :, :], gt_image[:, 2, :, :]))
    return loss
Exposure loss (Exposure Loss)
class L_exp(nn.Module):
    def __init__(self, patch_size=16, mean_val=0.6):  # if pixel values are in 0~255, mean_val should be 255 * 0.6
        super(L_exp, self).__init__()
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        b, c, h, w = x.shape
        # average over channels, then over non-overlapping patches
        x = torch.mean(x, 1, keepdim=True)
        mean = self.pool(x)
        # penalize deviation of each patch mean from the target exposure level
        # (subtracting the Python float directly keeps the loss device-agnostic)
        loss = torch.mean(torch.pow(mean - self.mean_val, 2))
        return loss
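A hypothetical usage sketch, with enhanced as the model output in [0, 1]:
# hypothetical usage; target exposure level 0.6 over 16x16 patches
l_exp = L_exp(patch_size=16, mean_val=0.6)
loss_exp = l_exp(enhanced)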