当前位置: 首页 > article >正文

【论文阅读笔记】IceNet算法与代码 | 低照度图像增强 | IEEE | 2021.12.25

03844589850a4de3a0f5e732bf03cf47.png

目录

1 导言

2 相关工作

A 传统方法

B 基于CNN的方法

C 交互方式

3 算法

A 交互对比度增强

1)Gamma estimation

2)颜色恢复

3)个性化初始η

B 损失函数

1)交互式亮度控制损失

2)熵损失

3)平滑损失

4)总损失

C 实现细节

4 实验

5 IceNet环境配置和运行

1 下载代码

2 运行

3 训练



e4666092054c4fcda5533f3ed22dd2d3.png

💚论文题目IceNet for Interactive Contrast Enhancement

💜论文地址:https://export.arxiv.org/pdf/2109.05838v2.pdf

💙代码地址:https://github.com/keunsoo-ko/IceNet

【摘要】本文提出了一种基于CNN的交互式对比度增强(contrast enhancement)算法IceNet,使得用户可以根据自己的喜好轻松地调整图像对比度。具体来说,用户提供一个参数来控制全局亮度和两种类型的涂鸦来使图像中的局部区域变暗或变亮。然后,给定这些注释,IceNet估计一个伽马图,用于逐像素的伽马校正。最后,通过颜色恢复,得到增强后的图像。用户可以迭代地提供注释,以获得满意的图像。IceNet还可以自动生成个性化的增强图像,如果需要,可以作为进一步调整的基础。此外,为了有效且可靠地训练IceNet,我们提出了三种可微的损失函数。大量实验表明,IceNet可以为用户提供令人满意的增强图像。

索引词:交互式对比度增强、个性化对比度增强、卷积神经网络、自适应伽马校正。

1 导言

🦋🦋🦋尽管近年来成像技术取得了很大的进步,但由于曝光不均匀、快门周期短、环境光弱等因素的影响,照片往往不能忠实地再现场景细节。这类照片表现出对比度失真、色彩褪色、强度低等现象;特别是,异常的光照条件会显著地扭曲颜色、纹理和物体,从而降低视觉体验。对比增强( CE )技术可以缓解这些问题。

🦋🦋🦋为了开发有效的CE技术,人们进行了许多尝试。一种简单而有效的方法是从输入到输出像素值的转换函数使用参数曲线。例如,一种权力法则是伽马校正中著名的参数曲线。这些参数曲线产生了很有前途的结果,但找到可靠的参数是很有挑战性的,这些参数对不同的图像都是有效的。最近,随着卷积神经网络( CNN )在低级视觉领域的成功,基于CNN的CE方法也被提出,并取得了出色的性能。

🦋🦋🦋为了实现CE的自动化,人们进行了大量的研究。然而,CE是一个主观的过程,因为人们对对比度有不同的偏好。值得注意的是,对比度在图像解读中发挥着不同的作用。例如,观看者使用对比来吸引注意力;艺术家通过对比传达强调;平面设计师通过对比来告诉眼睛该往何处去。因此,增强对比度没有唯一确定的方法。图1说明了可以有多种方法来增强同一幅图像的对比度。专业软件,如Photoshop,提供了许多根据个人喜好调整对比度的工具,但使用这些工具需要花费很多精力。在本文中,我们提出了第一个基于CNN的交互式CE算法,使用户能够方便地自适应地调整图像对比度。

图1。使用IceNet对同一幅图像进行四种不同的CE结果。

每个结果可以有不同的解释:( a )高暴露水平的温暖走廊( b )中等暴露水平的凉爽走廊( c )和( d )相同低暴露水平的黑暗走廊。在( d )中,前景表和陶瓷被高亮显示,具有局部亮化

bb5f088fcec34c31b616e0bf348d1a55.png

🌸🌸对于交互式CE,我们提出了IceNet,它可以在接受用户的注释后增强图像的欠曝光区域和过曝光区域的对比度。具体来说,用户提供一个参数来控制输入图像中的全局亮度和两种类型的涂鸦,以使输入图像中的局部区域变暗或变亮。然后,我们将输入图像和标注输入到IceNet中,IceNet生成伽马图,用于逐像素的伽马校正。更具体地说,我们将涂鸦和输入图像一起送入特征提取器,产生特征图。同时,我们从全局亮度参数中产生一个驱动向量。然后,利用特征图和驱动向量,我们得到根据用户标注自适应生成伽马图。最后,我们通过颜色恢复恢复一幅增强图像,其中我们只调整输入图像的亮度分量,同时保留色度分量。此外,为了用户的方便,我们提供了一个初步增强的图像,并允许用户以交互式的方式进一步增强它。此外,为了有效且可靠地训练IceNet,我们提出了三个可微的损失函数。实验结果表明,IceNet不仅为用户提供了满意的图像,而且在定性和定量上都优于当前最先进的算法。强烈建议观看IceNet实时演示的补充视频。

💕💕💕总的来说,这项工作有以下主要贡献:

提出了第一个基于CNN的交互式CE算法,称为IceNet它能够根据用户偏好自适应地产生增强图像或自动地不进行交互

使用提出的三个可微的损失函数来训练IceNet,从而实现与用户的交互,并产生令人满意的增强图像。

通过各种实验结果,我们证明了IceNet可以为用户提供令人满意的结果,优于传统的算法。

2 相关工作

🦋🦋🦋CE,颜色增强,去雾和细节增强之间的增强目标是密切相关但又不同的。本部分简要回顾了CE技术。

A 传统方法

传统的CE技术可以分为三类:直方图方法参数曲线方法retinex方法。首先,直方图均衡化试图使输出直方图尽可能均匀。它以较低的计算复杂度有效地增加了对比度,但它可能会过度增强图像,导致对比度超张拉、噪声放大和轮廓伪影。为了克服这些问题,各种直方图方法被开发出来。其次,参数曲线方法,如伽马校正和对数映射,使用参数曲线作为输入和输出像素值之间的转换函数。其中,伽马校正不仅被广泛用于CE,而且还被用于匹配成像设备之间的不同动态范围。第三,假设一幅图像可以分解为反射和照明,一些基于Retinex理论的CE算法已经被提出。这些传统方法在某些情况下产生了很有前途的结果。然而,它们的性能通常依赖于仔细的参数调整;对于不同的输入图像,很难找到可靠有效的参数

B 基于CNN的方法

最近,许多CNN被设计用于CE。大多数基于CNN的方法都是在成对的数据集上训练的,由成对的低对比度和高对比度图像组成。每对图像从同一场景中采集,或者从一幅高对比度图像合成一幅低对比度图像。然而,由于运动物体的存在,很难对同一场景进行两次拍摄。另一方面,合成的低对比度图像可能不具有真实感。由于这些困难,一些方法采用生成对抗网络和使用未配对的数据集训练他们的网络,这些数据集由从不同场景捕获的低对比度和高对比度图像组成。然而,他们应该谨慎地选择未配对的图像。为了避免这个繁琐的过程,Guo等人提出了一种自监督学习方案,只需要低对比度图像进行训练。然而,所有这些基于CNN的方法都不具有自适应性,无法满足多样化的用户偏好。

C 交互方式

对于交互式CE,专业软件提供了工具,但是使用这些工具需要花费大量的精力和训练。为了减少这种花费,提出了简单的交互方式。Stoel等人提出了一种交互式直方图均衡化方案来增强医学图像的细节。他们的方法允许用户指定一个感兴趣区域( RoI ),并将均衡应用于该区域。格伦德兰和道奇森提出了一种交互式的音调调整方法。当用户在图像上选择关键色调时,它保留这些色调但调整其他色调,同时保持整体色调平衡。Lischinski等人和道奇森等人也提出了交互式方法,用户指定本地CE的RoI。这些交互方法都是基于传统的图像处理。相比之下,据我们所知,提出的IceNet是第一个基于CNN的交互式CE算法

3 算法

🦋🦋🦋为了使用户能够方便地调整图像对比度,我们开发了IceNet,它可以接受简单的注释,并根据用户的喜好生成增强图像。

A 交互对比度增强

图2说明了所提出的算法。

首先,通过检查输入图像I,用户提供一个曝光水平η∈[ 0、1 ],用于控制全局亮度和两种类型的涂鸦。红色和蓝色的涂抹分别意味着用户希望相应的局部区域变暗和变亮。它们在涂抹图S中分别记为- 1和1,而其他像素被赋值为0

之后,将RGB彩色图像I转换到YCbCr空间,只调整亮度分量Y,同时保留色度分量。

然后,估计了一个伽马映射Γ来对Y进行逐像素的伽马校正。

最后,通过色彩恢复,得到增强后的图像J。

6f3a6964055e4e55b8c16065825fde54.png

核心代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class IceNet(nn.Module):

	def __init__(self):
		super(IceNet, self).__init__()

		self.relu = nn.ReLU(inplace=True)
		self.e_conv1 = nn.Conv2d(2,32,3,1,1,bias=True) 
		self.e_conv2 = nn.Conv2d(32,32,3,1,1,bias=True) 
		self.e_conv3 = nn.Conv2d(32,32,3,1,1,bias=True) 
		self.e_conv4 = nn.Conv2d(32,32,3,1,1,bias=True) 
		self.e_conv5 = nn.Conv2d(64,32,3,1,1,bias=True) 
		self.e_conv6 = nn.Conv2d(64,32,3,1,1,bias=True) 
		self.e_conv7 = nn.Conv2d(64,32,3,1,1,bias=True)

		self.fc1 = nn.Linear(1, 32)
		self.fc2 = nn.Linear(32, 32)


		
	def forward(self, y, maps, e, lowlight=None, is_train=False):
		b, _, h, w = y.shape
		x_ = torch.cat([y, maps], 1)

		# generate adaptive vector according to eta
		W = self.relu(self.fc1(e))
		W = self.fc2(W)

		# feature extractor
		x1 = self.relu(self.e_conv1(x_))
		x2 = self.relu(self.e_conv2(x1))
		x3 = self.relu(self.e_conv3(x2))
		x4 = self.relu(self.e_conv4(x3))
		x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
		x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
		x_r = self.relu(self.e_conv7(torch.cat([x1,x6],1)))

		# AGEB
		x_r = F.conv2d(x_r.view(1, b * 32, h, w),
				W.view(b, 32, 1, 1), groups=b)
		x_r = torch.sigmoid(x_r).view(b, 1, h, w) * 10

		# gamma correction
		enhanced_Y = torch.pow(y,x_r)
		if is_train:
			return enhanced_Y, x_r
		else:
			# color restoration
			enhanced_image = torch.clip(enhanced_Y*(lowlight/y), 0, 1)
			return enhanced_image

1)Gamma estimation

🐬伽马校正广泛应用于CE [ 10 ]。选择一个合适的伽玛值是很重要的:小于1的小伽玛使曝光不足的区域变亮,而大于1的大伽玛使曝光过度的区域变暗。伽马值也应考虑个人偏好进行选择。

🐬我们根据用户偏好,通过特征提取器和图2中的自适应伽马估计块( AGEB )为图像中的每个像素确定一个伽马值。特征提取器由7个卷积层组成,层与层之间采用级联的跳跃连接,提高了层与层之间的信息流动性。我们将Y和S串联起来构成特征提取器的输入,从而得到特征图F。然后,给定曝光水平η和特征图F,AGEB生成伽马图Γ。更具体地说,从η出发,AGEB通过两个全连接层产生驱动向量w。然后,对每个像素x,预测像素级伽马值Γ ( x )如公式(1)

c8e9a166bbb74673a6dacbeaecbea5fd.png

其中,〈·,·〉是内积,φ ( · )是sigmoid函数.因此,我们有0 < Γ ( X ) < 10 .这对于所有像素都是重复的。

2)颜色恢复

对于颜色恢复,我们首先使用其伽马值对亮度图像Y中的每个像素进行伽马校正,得到亮度图像Z。如下述公式(2)。

2fc557cac506473f98350214862e94c8.png

其中Ymax为最大亮度值(通常为255 )。然后,为了生成增强的RGB彩色图像J,我们采用简单的颜色恢复方法。它保留了颜色比率。如下述公式(3)。

e838e855b3b44d028a4b027637cfeb95.png

3)个性化初始η

为了方便用户使用,我们提供了一个初步增强的图像,并允许用户通过交互来进一步增强它。为此,我们使用一个二次多项式来估计初始曝光水平η 。如下述公式(4)。

78b2397200234e5382c8fae7655e23e7.png

其中,y = 1/N∑x Y ( x )为输入图像的平均亮度,N为像素个数。系数a,b和c由观测值利用最小二乘法得到,其中ηi是用户选择的以平均亮度yi对第i幅输入图像进行增强的曝光水平。当M > 3时,我们进行这种个性化。

B 损失函数

为了训练IceNet,我们使用了三个可微的损失:交互式亮度控制损失Libc、熵损失Lent和平滑度损失Lsmo。下面我们来描述这些损失。

1)交互式亮度控制损失

给定曝光水平η和涂鸦图S,构造目标亮度图T,定义交互式亮度控制损失Libc如公式(5)所示。

101b69328bb04eecb2ab514639fe5406.png

因此,式中的增强亮度图像Z 在公式( 2 )符合用户偏好。

为了得到目标亮度T,我们首先在输入亮度Y的基础上加入涂色S。如公式(6)所示。

04ac55909fcc40d2a700f7aeeecee38d.png

其中,λ是控制涂抹影响的参数,固定为5。通过添加S,我们控制了局部的亮度水平。接下来,我们将Y '归一化到[ 0、1 ]的范围内,得到Y ~。然后,我们使用双边伽马调整方案来提高黑暗和明亮区域细节的可见性。

c952c0fb60db4ad3be2d0349ba67cc5a.png

其中γ = 5 .值得注意的是,暗区和亮区主要通过Eq( 7 )和( 8 )得到增强。为了保留暗区和亮区的细节,Eq . ( 9 )综合了两者的结果。用户可以通过曝光等级η来控制G的整体亮度。图3说明了η的影响。

图3 ( a )给出了α = 0 . 2、0 . 4、0 . 6和0 . 8时,Y ̄( x )到G ( x )的变换函数。所提出的变换函数在最小值和最大值附近具有陡峭的斜率因此它们使IceNet能够有效地增强欠曝光和过曝光区域。图3 ( b )—( e )是增强图像的例子。我们发现,随着η的增大,增强后的图像J变得更加明亮

d0bd209aa45343fcb55c813f03857875.png

接下来,对于每个像素x,我们计算一个局部最大值,并按比例缩放,如公式(10)所示。

1eefb1c7953d4e03a40bc8bc159ebe30.png

其中Ω ( x )是关于x的15 × 15窗口.在公式( 5 )中,我们用T代替Ymax × G来抑制相邻像素之间过于不同的目标值,实现更可靠的增强。

2)熵损失

当图像的像素值分布越广泛时,图像的熵越大,传递的信息越多。特别地,最大熵是通过均匀分布来实现的。因此,我们采用熵损失,通过均衡输出图像的直方图来增加全局对比度。然而,直方图不能直接使用,因为其不可区分性。为了克服这个问题,我们设计了一个软直方图。虽然一个像素只对普通直方图的单个bin有贡献,但它以可微的方式对软直方图中的多个bin有贡献。

对于软直方图,我们使用映射函数来表示像素x对第i个bin的贡献,如公式(11)所示。

11635013c96a426e9b317885bcd1f889.png

图4 ( a )给出了σ = 5,10,20时的软映射函数。值得注意的是,σ控制着映射函数的宽度和高度之间的权衡。随着σ的增大,映射函数越来越窄但越来越高

c72bb6e81e1a44a5a156a9c1c164bbdb.png

在这项工作中,我们设定σ = 10。接下来,我们通过总结所有像素的贡献来制作软直方图h。如公式(12)所示。

19c766eea39b4fdb99dd77e07a3aabb6.png

然后,我们定义Lent为熵的逆,如公式(13)所示。

f3c36b76be564206af599ccc49b94be0.png

3)平滑损失

为了鼓励方程中伽马映射Γ的相邻值之间的平滑变化。我们引入了光滑性损失,如公式(14)所示。

2ea670ff6260439ab655375cc13ce99c.png

式中,‖·‖F为Frobenius范数。

4)总损失

总体损失定义为三个损失的加权和,如公式(15)所示。

f1876dd037f34ce5ac0c75981f2e1848.png

式中,Went和Wsmo为权重。因此,对提出的IceNet进行训练,使3个损失函数最小。首先,Libc使IceNet能够控制全局和局部亮度。其次,Lent鼓励一个平坦的直方图,这样可以增加全局对比度。第三,Lsmo平滑伽马图。

图4 ( b ) ~ ( e )分别给出了各损失的功效。

cb65f302fc424b3a936dda51dd12645c.png

核心代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.models.vgg import vgg16
import numpy as np
from numpy.testing import assert_almost_equal

class L_ent(nn.Module):
    def __init__(self, bins, min, max, sigma):
        super(L_ent, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        self.centers = float(min) + self.delta * (torch.arange(bins).float().cuda() + 0.5)

    def forward(self, y):
        b, _, h, w = y.shape
        y = y.reshape(b, 1, -1)
        c = self.centers.reshape(1, -1, 1).repeat(b, 1, 1)
        x = y - c
        x = torch.sigmoid(self.sigma * (x + self.delta/2)) - torch.sigmoid(self.sigma * (x - self.delta/2))

        hist = torch.sum(x, 2)
        p = hist / (h * w) + 1e-6
        d = torch.sum((-p * torch.log(p)))
        return 1/d

class L_int(nn.Module):

    def __init__(self):
        super(L_int, self).__init__()

    def forward(self, x, mean_val, labels):
        b,c,h,w = x.shape
        x = torch.mean(x,1,keepdim=True)
        d = torch.mean(torch.pow(x- labels,2))

        return d
        
class L_smo(nn.Module):
    def __init__(self):
        super(L_smo,self).__init__()

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h =  (x.size()[2]-1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
        return 2*(h_tv/count_h+w_tv/count_w)/batch_size

C 实现细节

🐬🐬🐬每个卷积层或全连接层的输出通道数为32。在每个卷积层中,执行零填充和ReLU激活。由于采用了8的小批量,因此不采用批归一化处理。我们用高斯随机数初始化IceNet中的所有参数。然后,我们通过Adam优化器以初始学习率0.001更新参数。采用与文献[ 40 ]相同的2002张训练图像。训练图像在SICE [ 49 ]的Part1子集中随机选取,包括欠曝光、正常曝光和过曝光图像。训练在RTX 2080Ti GPU上迭代50次,仅耗时约30分钟。所有训练图像大小调整为512 × 512。为了模拟用户标注,从[ 0.2、0.8]中随机选择一个曝光水平η,分别在随机位置生成0 ~ 5次的红色和蓝色涂鸦。对于这个仿真,每一个划线都是一个半径为10个像素的圆。式( 15 )中,Went = 10,Wsmo = 20。

4 实验

图6 . IceNet与传统算法的定性比较。传统的增强方法在奇数行,而基于CNN的方法在偶数行。

daa2c42adfcd4356bf702563f0fee823.png

5 IceNet环境配置和运行

1 下载代码

git clone https://github.com/keunsoo-ko/IceNet.git

2 运行

原版demo_interactive.py代码如下:

import model
import os
from PIL import Image
from skimage.color import rgb2ycbcr
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

#scribble size
points = 10

# initialize
l_drawing, r_drawing = False, False

def onChange(pos):
    pass

def draw_circle(event,x,y,flags,param):
    global l_drawing, r_drawing

    if event == cv2.EVENT_RBUTTONDOWN:
        r_drawing = True
        l_drawing = False

    elif event == cv2.EVENT_LBUTTONDOWN:
        l_drawing = True
        r_drawing = False

    elif event == cv2.EVENT_MOUSEMOVE:
        if l_drawing == True:
            cv2.circle(inputs, (x,y), 5, (0,0,255), -1)
            cv2.circle(scribble, (x,y), points, 1, -1)
        elif r_drawing == True:
            cv2.circle(inputs, (x,y), 5, (255,0,0), -1)
            cv2.circle(scribble, (x,y), points, -1, -1)

    elif event == cv2.EVENT_LBUTTONUP:
        l_drawing = False
        cv2.circle(inputs, (x,y), 5, (0,0,255), -1)
        cv2.circle(scribble, (x,y), points, 1, -1)

    elif event == cv2.EVENT_RBUTTONUP:
        r_drawing = False
        cv2.circle(inputs, (x,y), 5, (255,0,0), -1)
        cv2.circle(scribble, (x,y), points, -1, -1)

os.environ['CUDA_VISIBLE_DEVICES']="2"

# load image
img = Image.open('/home/ksko/Desktop/Low-light/CVPR2022/BMVC_code/data/test_data/NPE/night (50).jpg')
img = np.asarray(img)

# rgb2y -> Tensor
ycbcr = rgb2ycbcr(img)
y = ycbcr[..., 0] / 255.
y = torch.from_numpy(y).float()
y = y[None, None].cuda()

# rgb -> Tensor
lowlight = torch.from_numpy(img).float()
lowlight = lowlight.permute(2,0,1)
lowlight = lowlight.cuda().unsqueeze(0) / 255.


IceNet = model.IceNet().cuda()
IceNet.load_state_dict(torch.load('model/icenet.pth'))

resume = True


while(resume):
    inputs = img.copy() / 255.
    scribble = np.zeros(inputs.shape[:2])
    
    drawing = False

    cv2.namedWindow('image', cv2.WINDOW_AUTOSIZE | cv2.WINDOW_GUI_NORMAL)
    cv2.setMouseCallback('image', draw_circle)
    cv2.createTrackbar("threshold", "image", 0, 100, onChange)
    cv2.setTrackbarPos("threshold", "image", 60)
    while(1):
        global_e = cv2.getTrackbarPos("threshold", "image") / 100.
        # annotations
        s = torch.from_numpy(scribble)[None, None].float().cuda()
        eta = torch.Tensor([global_e]).float().cuda()
        # feedforward
        enhanced_image = IceNet(y, s, eta, lowlight)
        output = enhanced_image[0].permute(1, 2, 0).cpu().detach().numpy()

        cv2.imshow('image', np.concatenate([inputs, output], 1)[..., ::-1])
        k = cv2.waitKey(1) & 0xFF
	# To reset , push key "1"
        if k == 49:
            resume = True
            break
	# To save results, push key "2"
        if k == 50:
            cv2.imwrite('results/eta_%02d.png'%(global_e*100), output[..., ::-1]*255.)
	# To end demo, push key "Esc"
        if k == 27:
            resume = False
            break
    cv2.destroyAllWindows()

适配CPU版的demo_interactive.py代码如下:

import model
import os
from PIL import Image
from skimage.color import rgb2ycbcr
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

#scribble size
points = 10

# initialize
l_drawing, r_drawing = False, False

def onChange(pos):
    pass

def draw_circle(event,x,y,flags,param):
    global l_drawing, r_drawing

    if event == cv2.EVENT_RBUTTONDOWN:
        r_drawing = True
        l_drawing = False

    elif event == cv2.EVENT_LBUTTONDOWN:
        l_drawing = True
        r_drawing = False

    elif event == cv2.EVENT_MOUSEMOVE:
        if l_drawing == True:
            cv2.circle(inputs, (x,y), 5, (0,0,255), -1)
            cv2.circle(scribble, (x,y), points, 1, -1)
        elif r_drawing == True:
            cv2.circle(inputs, (x,y), 5, (255,0,0), -1)
            cv2.circle(scribble, (x,y), points, -1, -1)

    elif event == cv2.EVENT_LBUTTONUP:
        l_drawing = False
        cv2.circle(inputs, (x,y), 5, (0,0,255), -1)
        cv2.circle(scribble, (x,y), points, 1, -1)

    elif event == cv2.EVENT_RBUTTONUP:
        r_drawing = False
        cv2.circle(inputs, (x,y), 5, (255,0,0), -1)
        cv2.circle(scribble, (x,y), points, -1, -1)

os.environ['CUDA_VISIBLE_DEVICES']="2"

# load image
img = Image.open('img/bus.png')
# 期待输入的图像有3个通道,而不包含透明度信息
# 可以将图像转换为RGB格式,去除Alpha通道
img = img.convert("RGB")
img = np.asarray(img)

# rgb2y -> Tensor
ycbcr = rgb2ycbcr(img)
y = ycbcr[..., 0] / 255.
y = torch.from_numpy(y).float()
y = y[None, None].cpu()

# rgb -> Tensor
lowlight = torch.from_numpy(img).float()
lowlight = lowlight.permute(2,0,1)
lowlight = lowlight.cpu().unsqueeze(0) / 255.


IceNet = model.IceNet().cpu()
IceNet.load_state_dict(torch.load('model/icenet.pth', map_location=torch.device('cpu')))

resume = True


while(resume):
    inputs = img.copy() / 255.
    scribble = np.zeros(inputs.shape[:2])
    
    drawing = False

    cv2.namedWindow('image', cv2.WINDOW_AUTOSIZE | cv2.WINDOW_GUI_NORMAL)
    cv2.setMouseCallback('image', draw_circle)
    cv2.createTrackbar("threshold", "image", 0, 100, onChange)
    cv2.setTrackbarPos("threshold", "image", 60)
    while(1):
        global_e = cv2.getTrackbarPos("threshold", "image") / 100.
        # annotations
        s = torch.from_numpy(scribble)[None, None].float().cpu()
        eta = torch.Tensor([global_e]).float().cpu()
        # feedforward
        enhanced_image = IceNet(y, s, eta, lowlight)
        output = enhanced_image[0].permute(1, 2, 0).cpu().detach().numpy()

        cv2.imshow('image', np.concatenate([inputs, output], 1)[..., ::-1])
        k = cv2.waitKey(1) & 0xFF
	# To reset , push key "1"
        if k == 49:
            resume = True
            break
	# To save results, push key "2"
        if k == 50:
            cv2.imwrite('results/eta_%02d.png'%(global_e*100), output[..., ::-1]*255.)
	# To end demo, push key "Esc"
        if k == 27:
            resume = False
            break
    cv2.destroyAllWindows()

运行demo_interactive.py。

在demo中,按“1”键重置涂鸦图,按“2”键保存增强后的图像,按“Esc”键完成演示。

效果图如下:

4c073d99c502443a9c165ef95691584a.png

3 训练

python Train.py

 

至此,本文的内容就结束啦💕💕💕。

 

 


http://www.kler.cn/a/460140.html

相关文章:

  • Echarts+vue电商平台数据可视化——webSocket改造项目
  • 【OpenCV】使用Python和OpenCV实现火焰检测
  • Python - 游戏:飞机大战;数字华容道
  • 单片机常用外设开发流程(1)(IMX6ULL为例)
  • 数据可视化分析详解
  • LLM(十二)| DeepSeek-V3 技术报告深度解读——开源模型的巅峰之作
  • 我是用git pull每次都要输入账号密码
  • 数据安全技巧:使用私钥认证结合内网穿透实现安全高效的服务器管理
  • 应用层2——FTP文件传输协议
  • QT作业4
  • 一文大白话讲清楚CSS元素的水平居中和垂直居中
  • 浅显易懂的 git 入门
  • 电池放电仪在各领域的作用
  • android——屏幕适配
  • 【Flink运行时架构】系统构架
  • 咚次游戏加速1.1.4.2 | 免费PC游戏加速器,支持1473款游戏加速
  • 如何初始化css样式?为什么要初始化css?
  • python-LeetCode-两数之和
  • Spring Boot缓存预热实战指南
  • .net core 的字符串处理
  • 三大行业案例:AI大模型+Agent实践全景
  • 美畅物联丨视频上云网关获取视频流地址供第三方调用的方法
  • 【论文阅读笔记】IC-Light
  • 电子电器架构 ---什么是智能电动汽车上的逆变器?
  • 极品飞车6的游戏手柄设置
  • 【Leetcode 每日一题 - 扩展】面试题 04.10. 检查子树