当前位置: 首页 > article >正文

深度学习之超分辨率算法——SRGAN

  • 更新版本

  • 实现了生成对抗网络在超分辨率上的使用

  • 更新了损失函数,增加先验函数
    (原文此处有配图,转载时图片缺失)

  • SRResNet实现

import torch
import torchvision
from torch import nn


class ConvBlock(nn.Module):
	"""Residual block: two Conv-BN-PReLU stages plus an identity skip connection."""

	def __init__(self, kernel_size=3, stride=1, n_inchannels=64):
		super(ConvBlock, self).__init__()

		def conv_bn_prelu():
			# One Conv -> BatchNorm -> PReLU stage; channel count is preserved.
			return [
				nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels,
						  kernel_size=(kernel_size, kernel_size),
						  stride=(stride, stride), bias=False, padding=(1, 1)),
				nn.BatchNorm2d(n_inchannels),
				nn.PReLU(),
			]

		self.sequential = nn.Sequential(*conv_bn_prelu(), *conv_bn_prelu())

	def forward(self, x):
		# Identity skip connection: output = x + F(x).
		return x + self.sequential(x)


class Head_Conv(nn.Module):
	"""Entry stage: a large 9x9 convolution lifting 3 RGB channels to 64
	feature maps, followed by PReLU."""

	def __init__(self):
		super(Head_Conv, self).__init__()
		self.sequential = nn.Sequential(
			# padding 4 == 9 // 2 keeps the spatial size unchanged.
			nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9),
					  stride=(1, 1), padding=(4, 4)),
			nn.PReLU(),
		)

	def forward(self, x):
		out = self.sequential(x)
		return out


class PixelShuffle(nn.Module):
	"""Upsampling stage: a conv expands channels by upscale_factor**2, then
	nn.PixelShuffle rearranges those channels into a spatially larger map."""

	def __init__(self, n_channels=64, upscale_factor=2):
		super(PixelShuffle, self).__init__()
		expanded = n_channels * upscale_factor ** 2
		self.sequential = nn.Sequential(
			# padding 1 == 3 // 2 keeps the spatial size unchanged.
			nn.Conv2d(in_channels=n_channels, out_channels=expanded,
					  kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
			nn.BatchNorm2d(expanded),
			nn.PixelShuffle(upscale_factor=upscale_factor),
		)

	def forward(self, x):
		out = self.sequential(x)
		return out


class Hidden_block(nn.Module):
	"""Conv + BatchNorm stage (no activation), used in SRResNet after the
	residual stack and before the long skip connection."""

	def __init__(self):
		super(Hidden_block, self).__init__()
		self.sequential = nn.Sequential(
			# padding 1 == 3 // 2 keeps the spatial size unchanged.
			nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3),
					  stride=(1, 1), padding=(1, 1)),
			nn.BatchNorm2d(64),
		)

	def forward(self, x):
		out = self.sequential(x)
		return out


class TailConv(nn.Module):
	"""Output stage: a 9x9 conv mapping 64 feature maps back to 3 RGB
	channels, squashed to [-1, 1] by Tanh."""

	def __init__(self):
		super(TailConv, self).__init__()
		self.sequential = nn.Sequential(
			# padding 4 == 9 // 2 keeps the spatial size unchanged.
			nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9),
					  stride=(1, 1), padding=(4, 4)),
			nn.Tanh(),
		)

	def forward(self, x):
		out = self.sequential(x)
		return out


class SRResNet(nn.Module):
	"""SRResNet generator backbone.

	Pipeline: head conv -> n_blocks residual ConvBlocks -> conv+BN
	(Hidden_block) -> long skip connection from the head -> two 2x
	pixel-shuffle upsamplers (4x total) -> tail conv with Tanh.

	:param n_blocks: number of residual ConvBlocks in the trunk
	"""

	def __init__(self, n_blocks=16):
		super(SRResNet, self).__init__()
		self.head = Head_Conv()
		self.resnet = nn.Sequential(
			*[ConvBlock(kernel_size=3, stride=1, n_inchannels=64) for _ in range(n_blocks)]
		)
		self.hidden = Hidden_block()
		# NOTE: keeps the original (misspelled) attribute name "pixelShuufe"
		# so existing checkpoints / state_dicts keep loading.
		self.pixelShuufe = nn.Sequential(
			*[PixelShuffle(n_channels=64, upscale_factor=2) for _ in range(2)]
		)
		self.tail_conv = TailConv()

	def forward(self, x):
		head_out = self.head(x)
		# BUG FIX: the original forward never applied self.hidden; SRResNet
		# passes the residual-trunk output through a conv+BN stage before
		# adding the long skip connection from the head.
		trunk_out = self.hidden(self.resnet(head_out))
		out = head_out + trunk_out
		out = self.pixelShuufe(out)
		return self.tail_conv(out)

class Generator(nn.Module):
	"""SRGAN generator: a thin wrapper delegating to SRResNet."""

	def __init__(self):
		super(Generator, self).__init__()
		self.model = SRResNet()

	def forward(self, x):
		"""
		:param x: low-resolution input image batch (lr_img)
		:return: super-resolved output of the wrapped SRResNet
		"""
		return self.model(x)


class Discriminator(nn.Module):
	"""SRGAN discriminator: a strided-conv feature extractor followed by an
	MLP head that outputs the probability that the input image is real."""

	def __init__(self):
		super(Discriminator, self).__init__()
		layers = [
			nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3),
					  stride=(1, 1), padding=(1, 1)),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3),
					  stride=(2, 2), padding=(1, 1)),
			nn.BatchNorm2d(64),
			nn.LeakyReLU(),
		]
		# Three deeper stages: channels double each stage; every stage has a
		# stride-1 then a stride-2 conv (padding 0), each + BatchNorm + LeakyReLU.
		for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
			for stride in (1, 2):
				layers += [
					nn.Conv2d(in_channels=in_ch if stride == 1 else out_ch,
							  out_channels=out_ch, kernel_size=(3, 3),
							  stride=(stride, stride), padding=(0, 0)),
					nn.BatchNorm2d(out_ch),
					nn.LeakyReLU(),
				]
		# Pool to a fixed 6x6 so the linear head works for any input size.
		layers.append(nn.AdaptiveAvgPool2d((6, 6)))
		self.hidden = nn.Sequential(*layers)
		self.out_layer = nn.Sequential(
			nn.Linear(512 * 6 * 6, 1024),
			nn.LeakyReLU(negative_slope=0.2, inplace=True),
			nn.Linear(1024, 1),
			nn.Sigmoid(),
		)

	def forward(self, x):
		features = self.hidden(x)
		flat = features.reshape(features.shape[0], -1)
		return self.out_layer(flat)

SRGAN模型的生成器与判别器的实现


class Generator(nn.Module):
	"""SRGAN generator (duplicate of the definition above in the article):
	delegates all work to an SRResNet instance."""

	def __init__(self):
		super(Generator, self).__init__()
		self.model = SRResNet()

	def forward(self, x):
		"""
		:param x: low-resolution input image batch (lr_img)
		:return: super-resolved output of the wrapped SRResNet
		"""
		return self.model(x)


class Discriminator(nn.Module):
	"""SRGAN discriminator (duplicate of the definition above in the
	article): conv feature extractor + MLP head yielding P(input is real)."""

	def __init__(self):
		super(Discriminator, self).__init__()
		feature_layers = [
			nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3),
					  stride=(1, 1), padding=(1, 1)),
			nn.LeakyReLU(),
			nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3),
					  stride=(2, 2), padding=(1, 1)),
			nn.BatchNorm2d(64),
			nn.LeakyReLU(),
		]
		# (in_channels, out_channels, stride) for the remaining conv stages;
		# all use kernel 3 and padding 0, each followed by BN + LeakyReLU.
		specs = [(64, 128, 1), (128, 128, 2), (128, 256, 1),
				 (256, 256, 2), (256, 512, 1), (512, 512, 2)]
		for in_ch, out_ch, s in specs:
			feature_layers.append(nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
											kernel_size=(3, 3), stride=(s, s),
											padding=(0, 0)))
			feature_layers.append(nn.BatchNorm2d(out_ch))
			feature_layers.append(nn.LeakyReLU())
		# Pool to a fixed 6x6 so the linear head works for any input size.
		feature_layers.append(nn.AdaptiveAvgPool2d((6, 6)))
		self.hidden = nn.Sequential(*feature_layers)
		self.out_layer = nn.Sequential(
			nn.Linear(512 * 6 * 6, 1024),
			nn.LeakyReLU(negative_slope=0.2, inplace=True),
			nn.Linear(1024, 1),
			nn.Sigmoid(),
		)

	def forward(self, x):
		feats = self.hidden(x)
		return self.out_layer(feats.reshape(feats.shape[0], -1))


```
- 针对VGG19 的层数截取
```python
class TruncatedVGG19(nn.Module):
	"""
	Truncated VGG19, used as a fixed feature extractor for the perceptual
	(VGG feature-space) MSE loss in SRGAN.

	The network is cut at the j-th convolution (counting from the previous
	maxpool) before the i-th maxpool layer, keeping the activation that
	follows that convolution.
	"""

	def __init__(self, i, j):
		"""
		:param i: truncate before the i-th pooling layer
		:param j: keep up to the j-th convolution after the (i-1)-th pooling layer
		:raises AssertionError: if (i, j) does not match the VGG19 structure
		"""
		super(TruncatedVGG19, self).__init__()

		# Load the pretrained VGG19 (downloads weights on first use).
		vgg19 = torchvision.models.vgg19(pretrained=True)
		maxpool_counter = 0
		conv_count = 0
		truncate_at = 0
		# Walk the feature extractor, counting maxpools and the convs seen
		# since the most recent maxpool.
		for layer in vgg19.features.children():
			truncate_at += 1

			if isinstance(layer, nn.Conv2d):
				conv_count += 1
			if isinstance(layer, nn.MaxPool2d):
				maxpool_counter += 1
				# BUG FIX: the original assigned to a never-read name
				# `conv_counter`, so `conv_count` was never reset after a
				# maxpool and the break condition was wrong for i > 1.
				conv_count = 0

			# Stop at the j-th conv after the (i-1)-th maxpool
			# (i.e. before the i-th maxpool).
			if maxpool_counter == i - 1 and conv_count == j:
				break

		# Validate that (i, j) actually exists in VGG19.
		assert maxpool_counter == i - 1 and conv_count == j, "当前 i=%d 、 j=%d 不满足 VGG19 模型结构" % (
			i, j)

		# Keep layers up to and including the activation after conv j
		# (hence the +1), matching the standard SRGAN perceptual loss.
		self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])

	def forward(self, input):
		"""Return VGG feature maps of shape (N, channels, w, h) for *input*."""
		output = self.truncated_vgg19(input)

		return output
```



http://www.kler.cn/a/447261.html

相关文章:

  • 【TF-IDF】Hugging Face Model Recommendation System
  • 最新ubuntu20.04安装docker流畅教程
  • fastdds:idl
  • JS信息收集(小迪网络安全笔记~
  • web3跨链桥协议-Nomad
  • Mamba安装环境和使用,anaconda环境打包
  • 【ETCD】【源码阅读】深入分析 applierV3backend.Apply`方法源码
  • 设计模式之桥接模式:抽象与实现之间的分离艺术
  • C语言 排序
  • sqlite基础
  • 安卓native层抓trace
  • ASP.NET Core - 依赖注入 自动批量注入
  • gesp(三级)(8)洛谷:B3926:[GESP202312 三级] 单位转换
  • HarmonyOS(71) 自定义事件分发之TouchTestStrategy使用说明
  • 电脑提示ntdll.d缺失是什么原因?不处理的话会怎么样?ntdll.dll文件缺失快速解决方案来啦!
  • 数据云平台的可观测性
  • 【Leecode】Leecode刷题之路第87天之扰乱字符串
  • SKETCHPAD——允许语言模型生成中间草图,在几何、函数、图算法和游戏策略等所有数学任务中持续提高基础模型的性能
  • ip_output函数
  • 音视频学习(二十六):http-flv
  • Docker搭建kafka环境
  • 线性分类器(KNN,SVM损失,交叉熵损失,softmax)
  • 微信小程序-生成骨架屏
  • nbcio-vue版本第一次登录出现404问题
  • Docker安全性与最佳实践
  • Hive其五,使用技巧,数据查询,日志以及复杂类型的使用