当前位置: 首页 > article >正文

低光增强常用的损失函数pytorch实现

低照度图像增强模型——损失函数整理汇总_低照度sunshihanshu-CSDN博客

大多数搬运自以上文章,非原创。 

梯度损失: 

先利用sobel算子计算梯度,然后计算计算出来梯度的一范数。

实现:
定义Sobelxy类。

class Sobelxy(nn.Module):
    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                  [-2,0 , 2],
                  [-1, 0, 1]]
        kernely = [[1, 2, 1],
                  [0,0 , 0],
                  [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
        self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()
    def forward(self,x):
        sobelx=F.conv2d(x, self.weightx, padding=1)
        sobely=F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx)+torch.abs(sobely)

实例化类:

sobelconv = Sobelxy()

计算损失:

L_T_grad = sobelconv(L)
gradient_loss = beta * torch.sum(torch.abs(L_T_grad))

备注:计算的是L的梯度损失。

vgg感知损失:

import torch
import torch.nn.functional as F
from torchvision import models
# 加载预训练的VGG模型
vgg = models.vgg19(pretrained=True).features
# 将模型设置为评估模式
vgg.eval()
def perceptual_loss(out_img, gt_img):
    # 计算输入和目标图像在VGG特征图上的差异
    input_features = vgg(out_img)
    target_features = vgg(gt_img)
    loss = F.mse_loss(input_features, target_features)
    return loss


# 如果要从指定路径加载模型权重
def load_vgg(weights_path : str):
    vgg = models.vgg16(pretrained=False)         # 加载 VGG16 / VGG19 模型的特征提取部分,不使用预训练权重
    state_dict = torch.load(weights_path)  	 			 # 加载权重
    vgg.load_state_dict(state_dict)   					 # 将加载的权重加载到模型中
    vgg = vgg.features
    return vgg

vgg = load_vgg(r'E:\预训练权重\vgg16-397923af.pth')
vgg.eval()			# 将模型设置为评估模式

def perceptual_loss(out_img, gt_img):
    # 计算输入和目标图像在 VGG 特征图上的差异
    input_features = vgg(out_img)
    target_features = vgg(gt_img)
    loss = F.mse_loss(input_features, target_features)
    return loss
perceptual_loss(a,b)

结构损失:SSIM

调用封装好的SSIM基本都会失败,所以要自己写。
SSIM类:

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)


类的实例化:

ssim_metric = SSIM(window_size=11, size_average=True, val_range=None)
ssim_value = ssim_metric(R, hat_R)

备注:算的是R和hat_R的结构损失。

L1损失函数 / 平均绝对误差(MAE)

衡量经过模型增强后的输出图像与GT图像之间的误差。

# 第一种:
import torch.nn.functional as F
F.l1_loss(out_image , gt_image)


# 第二种:
import torch.nn as nn
nn.L1Loss()(out_image , gt_image)


# 第三种:
def MAELoss(out_image, gt_image):
    return torch.mean(torch.abs(out_image - gt_image))

平滑度损失TV_loss

class TVLoss(nn.Module):
    def __init__(self):
        super(TVLoss,self).__init__()

    def forward(self,x,weight_map=None):
        self.h_x = x.size()[2]
        self.w_x = x.size()[3]
        self.batch_size = x.size()[0]
        if weight_map is None:
            self.TVLoss_weight=(1, 1)
        else:
            self.TVLoss_weight = self.compute_weight(weight_map)

        count_h = self._tensor_size(x[:,:,1:,:])
        count_w = self._tensor_size(x[:,:,:,1:])

        h_tv = (self.TVLoss_weight[0]*torch.abs((x[:,:,1:,:]-x[:,:,:self.h_x-1,:]))).sum()
        w_tv = (self.TVLoss_weight[1]*torch.abs((x[:,:,:,1:]-x[:,:,:,:self.w_x-1]))).sum()
        # print(self.TVLoss_weight[0],self.TVLoss_weight[1])
        return (h_tv/count_h+w_tv/count_w)/self.batch_size

    def _tensor_size(self,t):
        return t.size()[1]*t.size()[2]*t.size()[3]

    def compute_weight(self, img):
        gradx = torch.abs(img[:, :, 1:, :] - img[:, :, :self.h_x-1, :])
        grady = torch.abs(img[:, :, :, 1:] - img[:, :, :, :self.w_x-1])
        TVLoss_weight_x = torch.div(1,torch.exp(gradx))
        TVLoss_weight_y = torch.div(1, torch.exp(grady))

        return TVLoss_weight_x, TVLoss_weight_y

# 简洁版
class L_TV(nn.Module):
    def __init__(self):
        super(L_TV,self).__init__()

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h =  (x.size()[2]-1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
        return 2*(h_tv/count_h+w_tv/count_w)/batch_size


# DSLR中的
def tv_loss(img):
    """
    Compute total variation loss.
    Inputs:
    - img: PyTorch Variable of shape (1, 3, H, W) holding an input image.
    - tv_weight: Scalar giving the weight w_t to use for the TV loss.
    Returns:
    - loss: PyTorch Variable holding a scalar giving the total variation loss
      for img weighted by tv_weight.
    """
    b,c,h,w_ = img.size()
    w_variance = torch.sum(torch.pow(img[:,:,:,:-1] - img[:,:,:,1:], 2))/b
    h_variance = torch.sum(torch.pow(img[:,:,:-1,:] - img[:,:,1:,:], 2))/b
    loss = (h_variance + w_variance) / 2
    return loss

 无监督平滑度损失:

 

# RetinexNet
import torch
import torch.nn.functional as F
class Smoothloss():
	def __init__(self):
        super(Smoothloss, self).__init__()
        
	def gradient(self, input_tensor, direction):
	    self.smooth_kernel_x = torch.FloatTensor([[0, 0], [-1, 1]]).view((1, 1, 2, 2)).cuda()
	    self.smooth_kernel_y = torch.transpose(self.smooth_kernel_x, 2, 3)
	
	    if direction == "x":
	        kernel = self.smooth_kernel_x
	    elif direction == "y":
	        kernel = self.smooth_kernel_y
	    grad_out = torch.abs(F.conv2d(input_tensor, kernel,
	                                  stride=1, padding=1))
	    return grad_out
	    
	def ave_gradient(self, input_tensor, direction):
	    return F.avg_pool2d(self.gradient(input_tensor, direction),
	                        kernel_size=3, stride=1, padding=1)
	
	def smooth(self, I_img, R_img):
	    R_img = 0.299*R_img[:, 0, :, :] + 0.587*R_img[:, 1, :, :] + 0.114*R_img[:, 2, :, :]  # 转换到YUV空间中的Y通道
	    R_img = torch.unsqueeze(R_img, dim=1)
	    return torch.mean(self.gradient(I_img, "x") * torch.exp(-10 * self.ave_gradient(R_img, "x")) +
	                      self.gradient(I_img, "y") * torch.exp(-10 * self.ave_gradient(R_img, "y")))

颜色损失(Color Constancy Loss) 适合无监督

 

class L_color(nn.Module):
 
    def __init__(self):
        super(L_color, self).__init__()
 
    def forward(self, x ):
 
        b,c,h,w = x.shape
 
        mean_rgb = torch.mean(x,[2,3],keepdim=True)
        mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = torch.pow(mr-mg,2)
        Drb = torch.pow(mr-mb,2)
        Dgb = torch.pow(mb-mg,2)
        k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5)
        return k

DIV中的颜色损失:适合有监督 

 

def angle(a, b):
    vector = torch.mul(a, b)
    up     = torch.sum(vector)
    down   = torch.sqrt(torch.sum(torch.square(a))) * torch.sqrt(torch.sum(torch.square(b)))
    theta  = torch.acos(up/down) # 弧度制
    return theta
def color_loss(out_image, gt_image): # 颜色损失  希望增强前后图片的颜色一致性 (b,c,h,w)
    loss = torch.mean(angle(out_image[:,0,:,:],gt_image[:,0,:,:]) + 
                      angle(out_image[:,1,:,:],gt_image[:,1,:,:]) +
                      angle(out_image[:,2,:,:],gt_image[:,2,:,:]))
    return loss

曝光损失(Exposure Loss)

class L_exp(nn.Module):
    def __init__(self,patch_size = 16,mean_val = 0.6):   # 如果图像像素值在0~255,mean_valy应该是255*0.6
        super(L_exp, self).__init__()
        # print(1)
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val
        
    def forward(self, x ):
        b,c,h,w = x.shape
        x = torch.mean(x,1,keepdim=True)
        mean = self.pool(x)
 
        loss = torch.mean(torch.pow(mean- torch.FloatTensor([self.mean_val] ).cuda(),2))
        return loss

 


http://www.kler.cn/a/395286.html

相关文章:

  • 【postman】怎么通过curl看请求报什么错
  • 进入未来城:第五周游戏指南
  • goframe开发一个企业网站 统一返回响应码 18
  • 万字长文解读深度学习——生成对抗网络GAN
  • [CKS] K8S Dockerfile和yaml文件安全检测
  • MySQL(5)【数据类型 —— 字符串类型】
  • 「QT」高阶篇 之 d-指针 的用法
  • javascript用来干嘛的?赋予网站灵魂的语言
  • axios平替!用浏览器自带的fetch处理AJAX(兼容表单/JSON/文件上传)
  • 百度世界2024|李彦宏:智能体是AI应用的最主流形态,即将迎来爆发点
  • 应用jar包使用skywalking8(Tongweb7嵌入式p11版本 by lqw)
  • uniapp 如何使用vuex store (亲测)
  • 游戏引擎学习第二天
  • 深入理解 Spring Boot 中的 Starters
  • vue3+ant design vue实现日期等选择器点击右上角叉号默认将值变为null,此时会影响查询等操作~
  • 【C++】隐含的“This指针“
  • GIT将源码推送新分支
  • 第十四章 Spring之假如让你来写AOP——雏形篇
  • 二分查找--快速地将搜索空间减半
  • 大语言模型在序列推荐中的应用
  • MinIo在Ubantu和Java中的整合
  • 某军工变压器企业:通过集团级工业IOT平台,实现数字化转型
  • yakit远程连接(引擎部署在vps上)
  • PyAEDT:Ansys Electronics Desktop API 简介
  • Apache Doris:快速入门与实践
  • word转markdown的方法(pandoc)