PyTorch implementations of common loss functions
Gradient loss:
First compute the image gradients with the Sobel operator, then take the L1 norm of the resulting gradient map.
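Written as a formula (this notation is mine, not from the original notes), with S_x and S_y the two Sobel kernels, * denoting 2-D convolution, and beta the loss weight:

\mathcal{L}_{\mathrm{grad}} = \beta \,\big\|\, |S_x * L| + |S_y * L| \,\big\|_{1}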
Implementation:
Define the Sobelxy class.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Sobelxy(nn.Module):
    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]]
        kernely = [[1, 2, 1],
                   [0, 0, 0],
                   [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        # Register the fixed kernels as buffers: they follow .cuda()/.to(device)
        # with the module and are never updated by the optimizer.
        self.register_buffer('weightx', kernelx)
        self.register_buffer('weighty', kernely)

    def forward(self, x):
        # x: single-channel image batch of shape (N, 1, H, W)
        sobelx = F.conv2d(x, self.weightx, padding=1)
        sobely = F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx) + torch.abs(sobely)
Instantiate the class and move it to the GPU (the buffers move with the module):
sobelconv = Sobelxy().cuda()
Compute the loss:
L_T_grad = sobelconv(L)
gradient_loss = beta * torch.sum(torch.abs(L_T_grad))
Note: this computes the gradient loss of L.
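A minimal self-contained sketch of the three steps above (the batch shape, the device handling, and the value of beta are placeholders, not from the original notes):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sobelconv = Sobelxy().to(device)

L = torch.rand(4, 1, 256, 256, device=device)   # dummy single-channel batch standing in for L
beta = 0.1                                      # assumed loss weight

L_T_grad = sobelconv(L)
gradient_loss = beta * torch.sum(torch.abs(L_T_grad))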
VGG perceptual loss:
In the model's __init__:
self.vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(opts.device)
Call it inside the training loop:
vgg_loss = torch.norm(model.vgg(R) - model.vgg(hat_R), p=1)
Note: this is the loss between the high-level features of R and hat_R extracted by a pretrained VGG.
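A minimal sketch of the setup around these two lines (the eval-mode call, the weight freezing, and the dummy tensors are my additions, not from the original notes): the VGG weights are frozen so that gradients flow back into R and hat_R but not into the feature extractor.

import torch
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # stands in for opts.device

vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device)
vgg.eval()                       # fixed feature extractor; no dropout / BN updates
for p in vgg.parameters():
    p.requires_grad_(False)      # do not train VGG; gradients still reach R and hat_R

R = torch.rand(2, 3, 224, 224, device=device, requires_grad=True)   # dummy stand-in for R
hat_R = torch.rand(2, 3, 224, 224, device=device)                   # dummy stand-in for hat_R

vgg_loss = torch.norm(vgg(R) - vgg(hat_R), p=1)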
Structural loss: SSIM
Calling a packaged, off-the-shelf SSIM almost always fails here, so it has to be written by hand.
The SSIM class:
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range
        # Assume 1 channel for SSIM until the first forward pass
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        # Reuse the cached window only if channel count, dtype and device all match;
        # otherwise rebuild it for the incoming images.
        if (channel == self.channel and self.window.dtype == img1.dtype
                and self.window.device == img1.device):
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel
        return ssim(img1, img2, window=window, window_size=self.window_size,
                    size_average=self.size_average, val_range=self.val_range)
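The class above relies on two helpers, create_window and ssim, that are not shown in these notes. Below is a minimal sketch of typical implementations (a Gaussian-window SSIM with sigma = 1.5 and the usual C1/C2 constants; treating val_range=None as a [0, 1] dynamic range is my assumption):

import math
import torch
import torch.nn.functional as F

def gaussian(window_size, sigma):
    # 1-D Gaussian kernel, normalized to sum to 1
    gauss = torch.Tensor([math.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
                          for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel=1):
    # Outer product of two 1-D Gaussians -> 2-D window, replicated per channel
    # so it can be applied as a depthwise (grouped) convolution
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

def ssim(img1, img2, window, window_size, size_average=True, val_range=None):
    L = val_range if val_range is not None else 1.0   # dynamic range of the inputs
    channel = img1.size(1)
    pad = window_size // 2

    # Local means, variances and covariance via Gaussian filtering
    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)
    mu1_sq, mu2_sq, mu1_mu2 = mu1 * mu1, mu2 * mu2, mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
               ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean() if size_average else ssim_map.mean(dim=(1, 2, 3))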
Instantiate the class:
ssim_metric = SSIM(window_size=11, size_average=True, val_range=None)
ssim_value = ssim_metric(R, hat_R)
Note: this measures the structural similarity between R and hat_R.
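SSIM is a similarity (1 for structurally identical images), so when it is used as a training loss it is usually subtracted from 1. A minimal sketch; the weight gamma is a placeholder, not from the original notes:

gamma = 1.0                            # assumed loss weight
ssim_loss = gamma * (1 - ssim_value)   # 0 when R and hat_R are structurally identical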