当前位置: 首页 > article >正文

常用的损失函数pytorch实现

梯度损失: 

先利用sobel算子计算梯度,然后计算计算出来梯度的一范数。

实现:
定义Sobelxy类。

class Sobelxy(nn.Module):
    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                  [-2,0 , 2],
                  [-1, 0, 1]]
        kernely = [[1, 2, 1],
                  [0,0 , 0],
                  [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
        self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()
    def forward(self,x):
        sobelx=F.conv2d(x, self.weightx, padding=1)
        sobely=F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx)+torch.abs(sobely)

实例化类:

sobelconv = Sobelxy()

计算损失:

L_T_grad = sobelconv(L)
gradient_loss = beta * torch.sum(torch.abs(L_T_grad))

备注:计算的是L的梯度损失。

vgg感知损失:

模型的init中:

self.vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(opts.device)

在训练的迭代中调用:

vgg_loss = torch.norm(model.vgg(R) - model.vgg(hat_R), p=1)

备注:计算的是用预训练好的vgg提取的R和hat_R的高级特征损失

结构损失:SSIM

调用封装好的SSIM基本都会失败,所以要自己写。
SSIM类:

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)


类的实例化:

ssim_metric = SSIM(window_size=11, size_average=True, val_range=None)
ssim_value = ssim_metric(R, hat_R)

备注:算的是R和hat_R的结构损失。
 


http://www.kler.cn/a/389628.html

相关文章:

  • 【测试框架篇】单元测试框架pytest(1):环境安装和配置
  • golang分布式缓存项目 Day1 LRU 缓存淘汰策略
  • PostgreSQL中的COPY命令:高效数据导入与导出
  • GEE 数据集——美国gNATSGO(网格化国家土壤调查地理数据库)完整覆盖了美国所有地区和岛屿领土的最佳可用土壤信息
  • Android 下内联汇编,Android Studio 汇编开发
  • 【贪心算法】No.1---贪心算法(1)
  • 批量清除Word Excel PPT文件打开密码
  • 让redis一直开启服务/自动启动
  • wordpress站外调用指定ID分类下的推荐内容
  • i2c-tools 4.3 for Android 9.0
  • stm32 ADC实例解析(3)-多通道采集互相干扰的问题
  • PySimpleGUI库和pymysql库
  • 探索计算机互联网的奇妙世界:从基础到前沿的无尽之旅
  • 2024 年 Java 面试正确姿势(1000+ 面试题附答案解析)
  • 算法学习第一弹——C++基础
  • Hive简介 | 体系结构
  • 青训3_1110_01 构造特定数组的逆序拼接
  • 性能飙升!时间序列+预训练强强联合,轻松迈入顶刊门槛!
  • conan2 c/c++包管理菜鸟入门
  • 使用MethodChannel与原生程序通信
  • PyQt5超详细教程终篇
  • 【Leecode】Leecode刷题之路第46天之全排列
  • InnoDB存储引擎对MVCC的实现
  • 项目管理平台盘点:2024推荐的9款优质工具
  • NLP自然语言处理:深入探索Self-Attention——自注意力机制详解
  • C语言 | Leetcode C语言题解之第551题学生出勤记录I