当前位置: 首页 > article >正文

深度学习之超分辨率算法——SRCNN

  • 网络为基础卷积层

  • tensorflow 1.14

  • scipy 1.2.1

  • numpy 1.16

  • 大概意思就是针对数据,我们先把图片按缩小因子照整数倍进行缩减为小图片,再针对小图片进行插值算法,获得还原后的低分辨率的图片作为标签。

  • main.py 配置文件

from model import SRCNN
from utils import input_setup
import numpy as np
import tensorflow as tf
import pprint
import os

flags = tf.app.flags
# 设置轮次
flags.DEFINE_integer("epoch", 1000, "Number of epoch [1000]")
# 设置批次
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
# 设置image大小
flags.DEFINE_integer("image_size", 33, "The size of image to use [33]")
# 设置label
flags.DEFINE_integer("label_size", 21, "The size of label to produce [21]")
# 学习率
flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of gradient descent algorithm [1e-4]")
# 图像颜色的尺寸
flags.DEFINE_integer("c_dim", 1, "Dimension of image color. [1]")
# 对输入图像进行预处理的比例因子大小
flags.DEFINE_integer("scale", 3, "The size of scale factor for preprocessing input image [3]")
# 步长
flags.DEFINE_integer("stride", 14, "The size of stride to apply input image [14]")
# 权重位置
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
# 样本目录
flags.DEFINE_string("sample_dir", "sample", "Name of sample directory [sample]")
# 训练还是测试
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [True]")
FLAGS = flags.FLAGS

# 格式化打印
pp = pprint.PrettyPrinter()

def main(_):
    #   打印参数
    pp.pprint(flags.FLAGS.__flags)

    # 没有就新建~
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Session提供了Operation执行和Tensor求值的环境;
    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)

        srcnn.train(FLAGS)
    
if __name__ == '__main__':
  tf.app.run()
from utils import (
    read_data,
    input_setup,
    imsave,
    merge
)
import time
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

try:
    xrange
except:
    xrange = range


class SRCNN(object):
    # 模型初始化
    def __init__(self,
                 sess,
                 image_size=33,
                 label_size=21,
                 batch_size=128,
                 c_dim=1,
                 checkpoint_dir=None,
                 sample_dir=None):

        self.sess = sess
        # 判断灰度图
        self.is_grayscale = (c_dim == 1)

        self.image_size = image_size
        self.label_size = label_size
        self.batch_size = batch_size

        self.c_dim = c_dim

        self.checkpoint_dir = checkpoint_dir
        self.sample_dir = sample_dir

        self.build_model()

    def build_model(self):
        # tf.placeholder(
        # dtype,
        # shape = None,
        # name = None
        # )
        # 定义image,labels 输入形式 N W H C
        self.images = tf.placeholder(dtype=tf.float32, shape=[None, self.image_size, self.image_size, self.c_dim], name='images')
        self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
        # tf.Variable(initializer, name), 参数initializer是初始化参数,name是可自定义的变量名称,
        # shape为[filter_height, filter_width, in_channel, out_channels]
        # 构建模型参数
        self.weights = {
            'w1': tf.Variable(initial_value=tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
            'w2': tf.Variable(initial_value=tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
            'w3': tf.Variable(initial_value=tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
        }
        # the dim of bias== c_dim
        self.biases = {
            'b1': tf.Variable(tf.zeros([64]), name='b1'),
            'b2': tf.Variable(tf.zeros([32]), name='b2'),
            'b3': tf.Variable(tf.zeros([1]), name='b3')
        }
        # 构建模型 返回MHWC
        self.pred = self.model()

        # Loss function (MSE)
        self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
        # 保存和加载模型
        # 如果只想保留最新的4个模型,并希望每2个小时保存一次,
        self.saver = tf.train.Saver(max_to_keep=4,keep_checkpoint_every_n_hours=2)

    def train(self, config):

        if config.is_train:
            # 训练状态
            input_setup(self.sess, config)
        else:
            nx, ny = input_setup(self.sess, config)


        if config.is_train:

            data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")

        else:

            data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")


        train_data, train_label = read_data(data_dir)

        # Stochastic gradient descent with the standard backpropagation
        self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)

        tf.initialize_all_variables().run()

        counter = 0
        start_time = time.time()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        if config.is_train:
            print("Training...")

            for ep in xrange(config.epoch):
                # Run by batch images
                batch_idxs = len(train_data) // config.batch_size
                for idx in xrange(0, batch_idxs):
                    batch_images = train_data[idx * config.batch_size: (idx + 1) * config.batch_size]
                    batch_labels = train_label[idx * config.batch_size: (idx + 1) * config.batch_size]

                    counter += 1
                    _, err = self.sess.run([self.train_op, self.loss],
                                           feed_dict={self.images: batch_images, self.labels: batch_labels})

                    if counter % 10 == 0:
                        print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
                              % ((ep + 1), counter, time.time() - start_time, err))

                    if counter % 500 == 0:
                        self.save(config.checkpoint_dir, counter)

        else:
            print("Testing...")
            # print(train_data.shape)
            # print(train_label.shape)
            # print("---------")
            result = self.pred.eval({self.images: train_data, self.labels: train_label})
            # print(result.shape)
            result = merge(result, [nx, ny])
            result = result.squeeze()
            image_path = os.path.join(os.getcwd(), config.sample_dir)
            image_path = os.path.join(image_path, "test_image.png")
            imsave(result, image_path)

    def model(self):
        # input : 输入的要做卷积的图片,要求为一个张量,shape为 [ batch, in_height, in_width, in_channel ],其中batch为图片的数量,in_height 为图片高度,in_width 为图片宽度,in_channel 为图片的通道数,灰度图该值为1,彩色图为3。(也可以用其它值,但是具体含义不是很理解)
        # filter: 卷积核,要求也是一个张量,shape为 [ filter_height, filter_width, in_channel, out_channels ],其中 filter_height 为卷积核高度,filter_width 为卷积核宽度,in_channel 是图像通道数 ,和 input 的 in_channel 要保持一致,out_channel 是卷积核数量。
        # strides: 卷积时在图像每一维的步长,这是一个一维的向量,[ 1, strides, strides, 1],第一位和最后一位固定必须是1
        # padding: string类型,值为“SAME” 和 “VALID”,表示的是卷积的形式,是否考虑边界。"SAME"是考虑边界,不足的时候用0去填充周围,"VALID"则不考虑
        # use_cudnn_on_gpu:  bool类型,是否使用cudnn加速,默认为true
        # padding = “SAME”输入和输出大小关系如下:输出大小等于输入大小除以步长向上取整,s是步长大小;
        # padding = “VALID”输入和输出大小关系如下:输出大小等于输入大小减去滤波器大小加上1,最后再除以步长(f为滤波器的大小,s是步长大小)。

        conv1 = tf.nn.relu(
            tf.nn.conv2d(self.images, self.weights['w1'], strides=[1, 1, 1, 1], padding='VALID',use_cudnn_on_gpu=True) + self.biases['b1'])
        conv2 = tf.nn.relu(
            tf.nn.conv2d(conv1, self.weights['w2'], strides=[1, 1, 1, 1], padding='VALID',use_cudnn_on_gpu=True) + self.biases['b2'])
        conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1, 1, 1, 1], padding='VALID',use_cudnn_on_gpu=True) + self.biases['b3']

        return conv3

    def save(self, checkpoint_dir, step):
        model_name = "SRCNN.model"
        model_dir = "%s_%s" % ("srcnn", self.label_size)
        # 目录
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        # 不存在就新建
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # 保存
        # 参数
        '''
        sess,
        save_path,
        global_step=None,
        latest_filename=None,
        meta_graph_suffix="meta",
        write_meta_graph=True,
        write_state=True,
        strip_default_attrs=False,
        save_debug_info=False)
        '''
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        model_dir = "%s_%s" % ("srcnn", self.label_size)
        # 加载模型
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        # 通过checkpoint文件找到模型文件名
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

        if ckpt and ckpt.model_checkpoint_path:
            # 返回path最后的文件名。如果path以/或\结尾,那么就会返回空值。即os.path.split(path)的第二个元素。
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # 加载成功
            return True
        else:
            # 加载失败
            return False




  • utils.py 配置工具函数
"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""

import os
import glob
import h5py
import random
import matplotlib.pyplot as plt

from PIL import Image  # for loading images as YCbCr format
import scipy.misc
import scipy.ndimage
import numpy as np

import tensorflow as tf

try:
    xrange
except:
    xrange = range

FLAGS = tf.app.flags.FLAGS


def read_data(path):
    """
    Read h5 format data file

    Args:
      path: file path of desired file
      data: '.h5' file format that contains train data values
      label: '.h5' file format that contains train label values
    """
    with h5py.File(path, 'r') as hf:
        data = np.array(hf.get('data'))
        label = np.array(hf.get('label'))
        return data, label


def preprocess(path, scale=3):
    """
    Preprocess single image file
      (1) Read original image as YCbCr format (and grayscale as default)
      (2) Normalize
      (3) Apply image file with bicubic interpolation

    Args:
      path: file path of desired file
      input_: image applied bicubic interpolation (low-resolution)
      label_: image with original resolution (high-resolution)
    """
    # 读取灰度图
    image = imread(path, is_grayscale=True)
    label_ = modcrop(image, scale)

    # Must be normalized
    # 归一化
    image = image / 255.
    label_ = label_ / 255.
    # zoom:类型为float或sequence,沿轴的缩放系数。 如果float,每个轴的缩放是相同的。 如果sequence,zoom应包含每个轴的一个值。
    # output:放置输出的数组,或返回数组的dtype
    # order:样条插值的顺序,默认为3.顺序必须在0-5范围内。
    # prefilter: bool, optional 。参数预滤波器确定输入是否在插值之前使用spline_filter进行预过滤(对于 > 1
    # 的样条插值所必需的)。 如果为False,则假定输入已被过滤。 默认为True。
    input_ = scipy.ndimage.interpolation.zoom(input=label_,zoom=(1. / scale), prefilter=False)
    input_ = scipy.ndimage.interpolation.zoom(input=input_,zoom=(scale / 1.), prefilter=False)

    return input_, label_


def prepare_data(sess, dataset):
    """
    Args:
      dataset: choose train dataset or test dataset
      For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
    dataset:
        "Train" or "Test":to choose the data is train or test
    """
    if FLAGS.is_train:
        filenames = os.listdir(dataset)
        #  获取数据目录
        data_dir = os.path.join(os.getcwd(), dataset)
        data = glob.glob(os.path.join(data_dir, "*.bmp"))
    else:
        # 获取测试集路径
        data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
        data = glob.glob(os.path.join(data_dir, "*.bmp"))
    # 返回文件目录
    return data


def make_data(sess, data, label):
    """
    Make input data as h5 file format
    Depending on 'is_train' (flag value), savepath would be changed.
    """
    if FLAGS.is_train:
        savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
    else:
        savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')

    with h5py.File(savepath, 'w') as hf:
        hf.create_dataset('data', data=data)
        hf.create_dataset('label', data=label)


def imread(path, is_grayscale=True):
    """
    Read image using its path.
    Default value is gray-scale, and image is read by YCbCr format as the paper said.
    """
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)
    else:
        return scipy.misc.imread(path, mode='YCbCr').astype(np.float)


def modcrop(image, scale=3):
    """
    To scale down and up the original image, first thing to do is to have no remainder while scaling operation.

    We need to find modulo of height (and width) and scale factor.
    Then, subtract the modulo from height (and width) of original image size.
    There would be no remainder even after scaling operation.
    要缩小和放大原始图像,首先要做的是在缩放操作时没有剩余。
    我们需要找到高度(和宽度)和比例因子的模。
    然后,从原始图像的高度(和宽度)中减去模。
    即使经过缩放操作,也不会有余数。
    """
    if len(image.shape) == 3:
        # 取整
        h, w, _ = image.shape
        h = h - np.mod(h, scale)
        w = w - np.mod(w, scale)
        image = image[0:h, 0:w, :]
    else:
        h, w = image.shape
        h = h - np.mod(h, scale)
        w = w - np.mod(w, scale)
        image = image[0:h, 0:w]
    return image


def input_setup(sess, config):
    """
    Read image files and make their sub-images and saved them as a h5 file format.
    """
    # Load data path
    if config.is_train:

        data = prepare_data(sess, dataset="Train")
    else:
        data = prepare_data(sess, dataset="Test")

    sub_input_sequence = []
    sub_label_sequence = []
    # 计算padding
    padding = abs(config.image_size - config.label_size) / 2  # 6

    if config.is_train:
        for i in xrange(len(data)):
            # TODO 获取原图和低分辨率还原标签
            input_, label_ = preprocess(data[i], config.scale)
            if len(input_.shape) == 3:
                h, w, _ = input_.shape
            else:
                h, w = input_.shape

            for x in range(0, h - config.image_size + 1, config.stride):
                for y in range(0, w - config.image_size + 1, config.stride):
                    sub_input = input_[x:x + config.image_size, y:y + config.image_size]  # [33 x 33]
                    sub_label = label_[x + int(padding):x + int(padding) + config.label_size,
                                y + int(padding):y + int(padding) + config.label_size]  # [21 x 21]

                    # Make channel value
                    sub_input = sub_input.reshape([config.image_size, config.image_size, 1])
                    sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

                    sub_input_sequence.append(sub_input)
                    sub_label_sequence.append(sub_label)
    else:
        input_, label_ = preprocess(data[1], config.scale)
        if len(input_.shape) == 3:
            h, w, _ = input_.shape
        else:
            h, w = input_.shape

        # Numbers of sub-images in height and width of image are needed to compute merge operation.
        nx = ny = 0
        for x in range(0, h - config.image_size + 1, config.stride):
            # 保存索引
            nx += 1
            ny = 0
            for y in range(0, w - config.image_size + 1, config.stride):
                ny += 1
                sub_input = input_[x:x + config.image_size, y:y + config.image_size]  # [33 x 33]
                sub_label = label_[x + int(padding):x + int(padding) + config.label_size,
                            y + int(padding):y + int(padding) + config.label_size]  # [21 x 21]

                sub_input = sub_input.reshape([config.image_size, config.image_size, 1])
                sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

                sub_input_sequence.append(sub_input)
                sub_label_sequence.append(sub_label)

    """
    len(sub_input_sequence) : the number of sub_input (33 x 33 x ch) in one image
    (sub_input_sequence[0]).shape : (33, 33, 1)
    """
    # Make list to numpy array. With this transform
    arrdata = np.asarray(sub_input_sequence)  # [?, 33, 33, 1]
    arrlabel = np.asarray(sub_label_sequence)  # [?, 21, 21, 1]
    make_data(sess, arrdata, arrlabel)

    if not config.is_train:
        return nx, ny


def imsave(image, path):
    return scipy.misc.imsave(path, image)


def merge(images, size):
    # 合并图片
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 1))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j * h:j * h + h, i * w:i * w + w, :] = image

    return img




  • 原图
    在这里插入图片描述

  • 效果图

  • 在这里插入图片描述


http://www.kler.cn/a/449979.html

相关文章:

  • 010 Qt_输入类控件(LineEdit、TextEdit、ComboBox、SpinBox、DateTimeEdit、Dial、Slider)
  • 申请腾讯混元的API Key并且使用LobeChat调用混元AI
  • 使用strimzi-kafka-operator 的mirrormake2(mm2)迁移kafka集群,去掉目标集群的topic默认前缀
  • 基于Spring Boot的九州美食城商户一体化系统
  • Oracle Database 21c Express Edition数据库 和 Sqlplus客户端安装配置
  • 华院计算参与项目再次被《新闻联播》报道
  • Visual Studio 、 MSBuild 、 Roslyn 、 .NET Runtime、SDK Tools之间的关系
  • 【Java基础面试题022】什么是Java内部类?有什么作用?
  • Qt笔记-Qt Creator开发环境搭建
  • C#(委托)2
  • 放弃机器学习框架,如何用Python做物体检测?
  • 监控MySQL数据表变化:Binlog的重要性及实践
  • 自建MD5解密平台-续
  • mysql中局部变量_MySQL中变量的总结
  • 【YashanDB知识库】Oracle pipelined函数在YashanDB中的改写
  • 蓝桥杯练习生第四天
  • Blazor 直接读取并显示HTML 文件内容
  • VSCode如何修改默认扩展路径和用户文件夹目录到D盘
  • 关于mac—address
  • linux安装宝塔面板及git
  • 基于Spring Boot的个性化推荐外卖点餐系统
  • HarmonyOS(72)事件拦截处理详解
  • Certifying LLM Safety against Adversarial Prompting
  • 网络管理 详细讲解
  • 网络安全(一)主动攻击之DNS基础和ettercap实现DNS流量劫持
  • BOE(京东方)“向新2025”年终媒体智享会落地成都 持续创新引领产业步入高价值增长新纪元