当前位置: 首页 > article >正文

Zero-Shot Noise2Noise: Efficient Image Denoising without any Data 笔记

CVPR 2023【已开源】| Zero-Shot Noise2Noise:嗨,模糊,再见!小网络零样本去噪实践!_哔哩哔哩_bilibili

最近,自监督神经网络显示出出色的图像去噪性能。然而,当前的无数据集方法要么计算量大,需要噪声模型,要么图像质量不足。在这项工作中,我们展示了一个简单的 2 层网络,无需任何训练数据或噪声分布知识,就可以以较低的计算成本实现高质量的图像去噪。我们的方法受到 Noise2Noise 和 Neighbor2Neighbor 的启发,并且适用于逐像素独立噪声的去噪。我们对人工、真实世界相机和显微镜噪声的实验表明,我们称为 ZS-N2N(零散粒噪声2噪声)的方法通常以更低的成本优于现有的无数据集方法,使其适用于数据可用性稀缺且受限的用例计算资源。可以在下面找到我们的实现演示,包括我们的代码和超参数。

论文标题:Zero-Shot Noise2Noise: Efficient Image Denoising without any Data

论文链接:https://arxiv.org/pdf/2303.11253.pdf

论文代码:https://colab.research.google.com/drive/1i82nyizTdszyHkaHBuKPbWnTzao8HF9b?usp=sharing

本文的工作是改进了N2N和NB2NB,仅仅使用单个噪声图片进行训练,为了避免单个图像的过拟合,只使用了一个非常浅的网络和一个明确的正则化项来避免该现象。

如今几乎所有的监督式或者非监督去噪方式,包括本文提出的方法,都依赖于一个前提,就是干净的自然图片与具有不同分布方式的噪声图片,噪声图像可以分解为一对下采样图像。基于干净图像附近高像素具有高相似性,相关性的值,而噪声像素是非结构化且独立的,对噪声图片进行下采样后的特征图具有相似的特征,但是噪声依然独立,因此,这一对噪声图片可以作为同一个场景的两个噪声观测值的近似值,其中一个观测值可以用作输入值,另一个用作目标。

本文的方法是首先将图像分解为一对下采样图像,然后训练一个带有正则化的轻量级网络,将一个下采样图像映射到另一个。将如此训练的网络应用于噪声图像产生去噪图像。

图2 图像对下采样器通过对2 × 2非重叠补丁的对角线像素求平均值,将图像分解为两个空间分辨率为一半的图像

使用UNet而不是轻量级网络会导致过拟合和更差的去噪性能。

Zero-Shot Noise2Noise: Efficient Image Denoising without any Data


这个仓库展示了如何使用我们提出的ZS-N2N(零样本Noise2Noise)算法在没有任何训练数据或噪声模型或噪声水平输入的情况下对图像进行去噪。该笔记本基于Pytorch。
该方法非常简单,只需将含噪图像与两个固定的卷积核进行卷积,以生成一对降采样图像。然后使用一致性损失训练一个简单的两层CNN,将一个降采样图像映射到另一个降采样图像。
该笔记本可以轻松地在GPU或CPU上运行。如果你想在GPU上运行,请选择设备为'cuda',否则选择'cpu'。如果你选择了'cuda',请记得通过'运行时' -> '更改运行时类型'将你的Colab运行时更改为GPU。


建议将图像的像素范围归一化到0-1,因为这是网络的超参数和初始化所测试的范围。
该网络的大小是针对256x256尺寸的图像进行优化的。如果你的图像更小或更大,你可能需要减小或增加网络的大小。这可以通过更改变量"chan_embed"来实现。


来自SIDD和PolyU的20张测试图像可以在这里找到:https://drive.google.com/file/d/1oJ2wIMPAxK353Shz5kW0Xg-KXmwUwYZR/ . 这些图像已经进行了归一化处理并转换为PyTorch张量,可以通过`torch.load('文件名')`加载。


[ ]

#Enter device here, 'cuda' for GPU, and 'cpu' for CPU
device = 'cuda'

[ ]

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

[ ]

# Load test image from URL

import requests
from io import BytesIO

# 4 test images from the Kodak24 dataset

#url = "https://drive.google.com/uc?export=download&id=18LcKoV4SYusF16wKqwNBJwYpTXE9myie"
#url = "https://drive.google.com/uc?export=download&id=176lM7ONjvyC83GcllCod-j1RPqjLoRoG"
#url = "https://drive.google.com/uc?export=download&id=1UIh9CwXSCf01JmAXgJo0LPtw5TUkWUU-"
url = "https://drive.google.com/uc?export=download&id=1j1OOzvGhet_GHJCaXbfisiW8uGDxI7ty"

response = requests.get(url)
path=BytesIO(response.content)

clean_img = torch.load(path).unsqueeze(0)
print(clean_img.shape) #B C H W
torch.Size([1, 3, 256, 256])

Add noise to the image


[ ]

noise_type = 'gauss' # Either 'gauss' or 'poiss'
noise_level = 25     # Pixel range is 0-255 for Gaussian, and 0-1 for Poission

def add_noise(x,noise_level):

    if noise_type == 'gauss':
        noisy = x + torch.normal(0, noise_level/255, x.shape)
        noisy = torch.clamp(noisy,0,1)

    elif noise_type == 'poiss':
        noisy = torch.poisson(noise_level * x)/noise_level

    return noisy

noisy_img = add_noise(clean_img, noise_level)

[ ]

clean_img = clean_img.to(device)
noisy_img = noisy_img.to(device)

We next define our network, which is a 2 layer CNN


[ ]

class network(nn.Module):
    def __init__(self,n_chan,chan_embed=48):
        super(network, self).__init__()

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv1 = nn.Conv2d(n_chan,chan_embed,3,padding=1)
        self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding = 1)
        self.conv3 = nn.Conv2d(chan_embed, n_chan, 1)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.conv3(x)

        return x

n_chan = clean_img.shape[1]
model = network(n_chan)
model = model.to(device)
print("The number of parameters of the network is: ",  sum(p.numel() for p in model.parameters() if p.requires_grad))
The number of parameters of the network is:  22275

图像对下采样器

接下来,我们介绍图像对下采样器,它通过在不重叠的块中平均对角像素来输出两个空间分辨率减半的下采样图像,如下图所示。
这是通过对图像与两个固定卷积核进行卷积实现的:k1=[0 0.5 0.5 0] 和 k2=[0.5 0 0 0.5],其中卷积的步长为2,并且分别应用于每个图像通道。



[ ]

def pair_downsampler(img):
    #img has shape B C H W
    c = img.shape[1]

    filter1 = torch.FloatTensor([[[[0 ,0.5],[0.5, 0]]]]).to(img.device)
    filter1 = filter1.repeat(c,1, 1, 1)

    filter2 = torch.FloatTensor([[[[0.5 ,0],[0, 0.5]]]]).to(img.device)
    filter2 = filter2.repeat(c,1, 1, 1)

    output1 = F.conv2d(img, filter1, stride=2, groups=c)
    output2 = F.conv2d(img, filter2, stride=2, groups=c)

    return output1, output2

显示含噪图像及其对应的下采样图像对。请注意,下采样图像的空间分辨率是原图的一半。


[ ]

img1, img2 = pair_downsampler(noisy_img)

img0 = noisy_img.cpu().squeeze(0).permute(1,2,0)
img1 = img1.cpu().squeeze(0).permute(1,2,0)
img2 = img2.cpu().squeeze(0).permute(1,2,0)

fig, ax = plt.subplots(1, 3,figsize=(15, 15))

ax[0].imshow(img0)
ax[0].set_title('Noisy Img')

ax[1].imshow(img1)
ax[1].set_title('First downsampled')

ax[2].imshow(img2)
ax[2].set_title('Second downsampled')

损失

损失函数由残差损失一致性损失两部分组成,具体如下:
 


其中,( y ) 是含噪声的输入图像,( D ) 是图像对的下采样器,而 (fθ ) 是网络。


[ ]

def mse(gt: torch.Tensor, pred:torch.Tensor)-> torch.Tensor:
    loss = torch.nn.MSELoss()
    return loss(gt,pred)

def loss_func(noisy_img):
    noisy1, noisy2 = pair_downsampler(noisy_img)

    pred1 =  noisy1 - model(noisy1)
    pred2 =  noisy2 - model(noisy2)

    loss_res = 1/2*(mse(noisy1,pred2)+mse(noisy2,pred1))

    noisy_denoised =  noisy_img - model(noisy_img)
    denoised1, denoised2 = pair_downsampler(noisy_denoised)

    loss_cons=1/2*(mse(pred1,denoised1) + mse(pred2,denoised2))

    loss = loss_res + loss_cons

    return loss

Train, test, and denoise functions


[ ]

def train(model, optimizer, noisy_img):

    loss = loss_func(noisy_img)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

def test(model, noisy_img, clean_img):

    with torch.no_grad():
        pred = torch.clamp(noisy_img - model(noisy_img),0,1)
        MSE = mse(clean_img, pred).item()
        PSNR = 10*np.log10(1/MSE)

    return PSNR

def denoise(model, noisy_img):

    with torch.no_grad():
        pred = torch.clamp( noisy_img - model(noisy_img),0,1)

    return pred

Optimizer and Hyperparameters


[ ]

max_epoch = 3000     # training epochs
lr = 0.001           # learning rate
step_size = 1000     # number of epochs at which learning rate decays
gamma = 0.5          # factor by which learning rate decays

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

Start denoising


[ ]

for epoch in tqdm(range(max_epoch)):
    train(model, optimizer, noisy_img)
    scheduler.step()

PSNR of denoised image


[ ]

PSNR = test(model, noisy_img, clean_img)
print(PSNR)
31.17408216697431

Display clean, noisy, and denoised images.


[ ]

denoised_img = denoise(model, noisy_img)

denoised = denoised_img.cpu().squeeze(0).permute(1,2,0)
clean = clean_img.cpu().squeeze(0).permute(1,2,0)
noisy = noisy_img.cpu().squeeze(0).permute(1,2,0)

[ ]

fig, ax = plt.subplots(1, 3,figsize=(15, 15))
ax[0].imshow(clean)
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[0].set_title('Ground Truth')
ax[1].imshow(noisy)
ax[1].set_xticks([])
ax[1].set_yticks([])
ax[1].set_title('Noisy Img')
noisy_psnr = 10*np.log10(1/mse(noisy_img,clean_img).item())
ax[1].set(xlabel= str(round(noisy_psnr,2)) + ' dB')
ax[2].imshow(denoised)
ax[2].set_xticks([])
ax[2].set_yticks([])
ax[2].set_title('Denoised Img')
ax[2].set(xlabel= str(round(PSNR,2)) + ' dB')


http://www.kler.cn/a/515042.html

相关文章:

  • npm install 报错:Command failed: git checkout 2.2.0-c
  • 【玩转全栈】---基于YOLO8的图片、视频目标检测
  • yolov11 pose 推理代码
  • 微前端qiankun的基本使用(vue-element-admin作为项目模版)
  • 为什么redis会开小差?Redis 频繁异常的深度剖析与解决方案
  • 【达梦数据库】两地三中心环境总结
  • NHANES指标推荐:TyG!
  • 2.复写零
  • Vue3 中使用组合式API和依赖注入实现自定义公共方法
  • 洛谷P8195
  • c++算法贪心系列
  • 2024.1.22 安全周报
  • 大华Java开发面试题及参考答案 (下)
  • UE5 开启“Python Remote Execution“
  • 解决go.mod文件中replace不生效的问题
  • Mono里运行C#脚本31—mono_arch_create_generic_trampoline
  • YOLOv10-1.1部分代码阅读笔记-predictor.py
  • 【Linux】APT 密钥管理迁移指南:有效解决 apt-key 弃用警告
  • 如何实现亿级用户在线状态统计?
  • .NET MAUI进行UDP通信(二)
  • 吴恩达深度学习——如何实现神经网络
  • 【2024年华为OD机试】 (E卷,100分) - 预订酒店(JavaScriptJava PythonC/C++)
  • Ubuntu20彻底删除MySQL8
  • WPS计算机二级•幻灯片的基础操作
  • Qt调用ffmpeg库实时播放rtmp或rtsp视频流
  • 【玩转全栈】----Django模板的继承