当前位置: 首页 > article >正文

常用的损失函数pytorch实现

梯度损失: 

先利用sobel算子计算梯度,然后计算计算出来梯度的一范数。

实现:
定义Sobelxy类。

class Sobelxy(nn.Module):
    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                  [-2,0 , 2],
                  [-1, 0, 1]]
        kernely = [[1, 2, 1],
                  [0,0 , 0],
                  [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
        self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()
    def forward(self,x):
        sobelx=F.conv2d(x, self.weightx, padding=1)
        sobely=F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx)+torch.abs(sobely)

实例化类:

sobelconv = Sobelxy()

计算损失:

L_T_grad = sobelconv(L)
gradient_loss = beta * torch.sum(torch.abs(L_T_grad))

备注:计算的是L的梯度损失。

vgg感知损失:

模型的init中:

self.vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(opts.device)

在训练的迭代中调用:

vgg_loss = torch.norm(model.vgg(R) - model.vgg(hat_R), p=1)

备注:计算的是用预训练好的vgg提取的R和hat_R的高级特征损失

结构损失:SSIM

调用封装好的SSIM基本都会失败,所以要自己写。
SSIM类:

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)


类的实例化:

ssim_metric = SSIM(window_size=11, size_average=True, val_range=None)
ssim_value = ssim_metric(R, hat_R)

备注:算的是R和hat_R的结构损失。
 


http://www.kler.cn/a/389628.html

相关文章:

  • 基于ESP32+VUE+JAVA+Ngnix的一个小型固件编译系统
  • [Qt] Box Model | 控件样式 | 实现log_in界面
  • 【Python运维】用Python管理Docker容器:从`docker-py`到自动化部署的全面指南
  • TiDB 和 MySQL 的关系:这两者到底有什么不同和联系?
  • Excel 技巧10 - 如何检查输入重复数据(★★)
  • CV 图像处理基础笔记大全(超全版哦~)!!!
  • 批量清除Word Excel PPT文件打开密码
  • 让redis一直开启服务/自动启动
  • wordpress站外调用指定ID分类下的推荐内容
  • i2c-tools 4.3 for Android 9.0
  • stm32 ADC实例解析(3)-多通道采集互相干扰的问题
  • PySimpleGUI库和pymysql库
  • 探索计算机互联网的奇妙世界:从基础到前沿的无尽之旅
  • 2024 年 Java 面试正确姿势(1000+ 面试题附答案解析)
  • 算法学习第一弹——C++基础
  • Hive简介 | 体系结构
  • 青训3_1110_01 构造特定数组的逆序拼接
  • 性能飙升!时间序列+预训练强强联合,轻松迈入顶刊门槛!
  • conan2 c/c++包管理菜鸟入门
  • 使用MethodChannel与原生程序通信
  • PyQt5超详细教程终篇
  • 【Leecode】Leecode刷题之路第46天之全排列
  • InnoDB存储引擎对MVCC的实现
  • 项目管理平台盘点:2024推荐的9款优质工具
  • NLP自然语言处理:深入探索Self-Attention——自注意力机制详解
  • C语言 | Leetcode C语言题解之第551题学生出勤记录I