当前位置: 首页 > article >正文

计算机视觉语义分割——U-Net(Convolutional Networks for Biomedical Image Segmentation)

计算机视觉语义分割——U-Net(Convolutional Networks for Biomedical Image Segmentation)

文章目录

  • 计算机视觉语义分割——U-Net(Convolutional Networks for Biomedical Image Segmentation)
    • 摘要
    • Abstract
    • 一、U-Net
      • 1. 基本思想
      • 2. 收缩路径(Contracting path)
      • 3. 扩展路径(Expanding path)
      • 4. U-Net代码实践
        • 4.1 数据集选择
        • 4.2 加载数据集
        • 4.3 U-Net模型搭建
        • 4.4 训练U-Net
        • 4.5 预测
      • 5. U-Net与FCN的关系
      • 6. U-Net与生成模型
        • 6.1 Stable Diffusion中的U-Net
        • 6.2 U-Net在生成模型中的改进
        • 6.3 U-Net在生成模型中训练目标的变化
    • 总结

摘要

本周的工作聚焦于U-Net网络在语义分割任务中的研究与实现。U-Net作为一种经典的全卷积神经网络,因其编码器-解码器的对称结构和跳跃连接机制,广泛应用于生物医学图像分割。本次工作详细解析了U-Net的网络结构,包括收缩路径、扩展路径,以及其与FCN的关系。此外,还探讨了U-Net在生成模型中的改进及其在Stable Diffusion中的应用。通过搭建基于PyTorch的U-Net网络,完成了数据预处理、模型训练及预测,并在医学图像分割任务中取得了良好的效果。实验表明,U-Net可以有效减少信息丢失,提升分割精度,为后续生成模型中的应用奠定了基础。

Abstract

This week’s work focused on the study and implementation of the U-Net network in semantic segmentation tasks. U-Net, as a classic fully convolutional neural network, is widely used in biomedical image segmentation due to its symmetrical encoder-decoder structure and skip connection mechanism. This report provides a detailed analysis of the U-Net architecture, including its contracting path, expanding path, and its relationship with FCN. Additionally, improvements to U-Net in generative models and its application in Stable Diffusion were explored. By building a PyTorch-based U-Net network, tasks such as data preprocessing, model training, and prediction were completed, achieving promising results in medical image segmentation. The experiments demonstrate that U-Net effectively reduces information loss and improves segmentation accuracy, laying a solid foundation for its application in generative models.

一、U-Net

1. 基本思想

img

U-Net由 Olaf Ronneberger 等人为生物医学图像分割而开发。该架构包含两条路径。第一条路径是收缩路径(也称为编码器),用于捕获图像中的上下文。编码器只是卷积层和最大池化层的传统堆栈。第二条路径是扩展路径(也称为解码器),用于使用转置卷积实现精确定位。因此,它是一个端到端的全卷积网络 (FCN),即它只包含卷积层,不包含任何密集层,因此它可以接受任何大小的图像。

2. 收缩路径(Contracting path)

收缩路径使用卷积层池化层的组合来提取和捕获图像中的特征,同时减少其空间维度。现在让我们来看看下面收缩路径中的5 个block中的每一个

img

Block 1:

  1. 一张尺寸为572² 的输入图像被输入到 U-Net 中。该输入图像仅包含1 个灰度通道
  2. 然后将两个3x3 卷积层 (无填充)应用于输入图像,每个层后跟一个ReLU层。同时将通道数增加到**64 个,**以便捕获更高级别的特征。
  3. 然后应用步长为2 的2x2 最大池化层。这会将特征图下采样为其大小的一半,即284²

Block 2:

  1. 与Block 1 一样,两个3x3 卷积层(无填充)应用于Block 1 的输出,每个层后又跟有一个ReLU层。在每个新块中,特征通道数量加倍,现在为128
  2. 接下来,再次将2x2 最大池化层应用于生成的特征图,将空间尺寸减少一半140²

Block 3:

  1. 与Block 1和 Block 2中的步骤相似,最后将特征图下采样至68²

Block 4:

  1. 与Block 1和 Block 2中的步骤相似,最后将特征图下采样至32²

Block 5:

  1. 在收缩路径的最后一个块中特征通道数在每个块中翻倍,达到1024个
  2. 此块还包含两个3x3 卷积层(未填充),每个层后面都有一个ReLU层。但是,出于对称目的,我只包含一个层,并将第二层包含在扩展路径中。

在提取出复杂的特征之后,特征图就会进入到扩展路径的运算阶段。

3. 扩展路径(Expanding path)

扩展路径使用卷积转置卷积操作来组合学习到的特征并对输入特征图进行上采样,直到生成分割图。扩展路径与收缩路径非常相似,下面将逐步了解每个Block。

img

注意:上图的灰色箭头代表“跳跃连接”,跳跃连接用于将图像直接从收缩路径发送到扩展路径,而无需经过所有块。这样可以保留和学习高级和低级特征,从而减少收缩路径期间发生的任何信息丢失。

Block 5

  1. 从收缩路径继续,应用第二个3x3 卷积(未填充),其后接着一个ReLU层
  2. 然后应用2x2 卷积(上卷积)层,将空间维度上采样两倍,并将通道数减半为512

Block 4

  1. 然后使用跳跃连接将收缩路径中的相应特征图连接起来,将特征通道数加倍至1024请注意,必须裁剪此连接以匹配扩展路径的尺寸。
  2. 应用两个3x3 卷积层(未填充)每个层后面都有一个ReLU层,将通道数减少到512
  3. 之后,应用2x2 卷积(上卷积)层,将空间维度上采样两倍,并将通道数减半为256

Block 3

  1. 与Block 5和 Block 4中的步骤相似,最后将特征图上采样两倍至200²,并将通道数减半为128

Block 2

  1. 与Block 5和 Block 4中的步骤相似,最后将特征图上采样两倍至392²,并将通道数减半为64

Block 1

  1. 在扩展路径的最后一块中,连接跳跃连接后共有128 个通道。
  2. 接下来,在特征图上应用两个3x3 卷积层(未填充)中间的ReLU层将特征通道数减少到64
  3. 最后,使用1x1 卷积层,然后是激活层(二元分类中的Sigmoid函数),将通道数减少到所需的类别数。在本例中,为 2 个类别,因为二元分类通常用于医学成像。

在扩展路径上对特征图进行上采样后,应该生成分割图,其中每个像素都被单独分类。

4. U-Net代码实践

4.1 数据集选择

在这里插入图片描述在这里插入图片描述

收集的30张果蝇的电镜细胞图,分辨率为512x512,上图第一张为原图,第二张则为标签图。

训练集图片:

image-20250118212927287

训练集标签:

image-20250118212946353

4.2 加载数据集

数据加载要做哪些处理,是根据任务和数据集而决定的,对于我们的分割任务,不用做太多处理,但由于数据量很少,仅30张,我们可以使用一些数据增强方法,来扩大我们的数据集。

Pytorch 给我们提供了一个方法,方便我们加载数据,我们可以使用这个框架,去加载我们的数据。看下伪代码:

# ================================================================== #
#                Input pipeline for custom dataset                 #
# ================================================================== #

# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0 

# You can then use the prebuilt data loader. 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)

这是一个标准的模板,我们就使用这个模板,来加载数据,定义标签,以及进行数据增强。

创建一个dataset.py文件,编写代码如下:

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random

class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # 初始化函数,读取所有data_path下的图片
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))

    def augment(self, image, flipCode):
        # 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转
        flip = cv2.flip(image, flipCode)
        return flip
        
    def __getitem__(self, index):
        # 根据index读取图片
        image_path = self.imgs_path[index]
        # 根据image_path生成label_path
        label_path = image_path.replace('image', 'label')
        # 读取训练图片和标签图片
        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # 将数据转为单通道的图片
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # 处理标签,将像素值为255的改为1
        if label.max() > 1:
            label = label / 255
        # 随机进行数据增强,为2时不做处理
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)
        return image, label

    def __len__(self):
        # 返回训练集大小
        return len(self.imgs_path)

    
if __name__ == "__main__":
    isbi_dataset = ISBI_Loader("data/train/")
    print("数据个数:", len(isbi_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=2, 
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)

使用dataset.py文件加载数据的输出结果为:

image-20250118211222608

4.3 U-Net模型搭建

如果完全按照论文的结构,模型输出的尺寸会稍微小于图片输入的尺寸,如果使用论文的网络结构需要在结果输出后,做一个 resize 操作。为了省去这一步,我们可以修改网络,使网络的输出尺寸正好等于图片的输入尺寸。

创建unet_parts.py文件,编写如下代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

创建unet_model.py文件,编写如下代码:

"""利用unet_parts.py中的部件组装形成完整的Unet网络"""

import torch.nn.functional as F
from .unet_parts import *

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)

这样调整过后,网络的输出尺寸就与图片的输入尺寸相同了。

4.4 训练U-Net

我们今天的任务,只需要分割出细胞边缘,也就是一个很简单的二分类任务,所以我们可以使用BCEWithLogitsLoss

BCEWithLogitsLoss是 Pytorch 提供的用来计算二分类交叉熵的函数:

img

这个公式就是 Logistic 回归的损失函数,它利用的是 Sigmoid 函数阈值在[0,1]这个特性来进行分类的。

from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch

def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    # 加载训练集
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size, 
                                               shuffle=True)
    # 定义RMSprop算法
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # 定义Loss算法
    criterion = nn.BCEWithLogitsLoss()
    # best_loss统计,初始化为正无穷
    best_loss = float('inf')
    # 训练epochs次
    for epoch in range(epochs):
        # 训练模式
        net.train()
        # 按照batch_size开始训练
        for image, label in train_loader:
            optimizer.zero_grad()
            # 将数据拷贝到device中
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            # 使用网络参数,输出预测结果
            pred = net(image)
            # 计算loss
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            # 保存loss值最小的网络参数
            if loss < best_loss:
                best_loss = loss
                torch.save(net.state_dict(), 'best_model.pth')
            # 更新参数
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道1,分类为1。
    net = UNet(n_channels=1, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 指定训练集地址,开始训练
    data_path = "data/train/"
    train_net(net, device, data_path)

训练的输出LOSS为:

D:\Pytorch\Anaconda3\envs\d2l\python.exe E:/Unet/train.py
Loss/train 0.7607325911521912
Loss/train 0.7042417526245117
Loss/train 0.6359086036682129
Loss/train 0.5970470905303955
Loss/train 0.5956597328186035
Loss/train 0.5168007016181946
Loss/train 0.509236216545105
Loss/train 0.4616139233112335
Loss/train 0.4775354862213135
Loss/train 0.44973015785217285
Loss/train 0.4584605395793915
Loss/train 0.4683246314525604
Loss/train 0.42488470673561096
Loss/train 0.4471128582954407
Loss/train 0.4211571216583252
..............................
Loss/train 0.08856159448623657
Loss/train 0.09392605721950531
Loss/train 0.10895076394081116
Loss/train 0.10066412389278412
Loss/train 0.10104762017726898
Loss/train 0.09606537222862244
Loss/train 0.09012659639120102
Loss/train 0.10549496114253998
Loss/train 0.0879889726638794
Loss/train 0.10525484383106232
Loss/train 0.10839355736970901

Process finished with exit code 0
4.5 预测

模型训练好了,我们可以用它在测试集上看下效果。在工程根目录创建 predict.py 文件,编写如下代码:

import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet

if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道,分类为1。
    net = UNet(n_channels=1, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 加载模型参数
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # 测试模式
    net.eval()
    # 读取所有图片路径
    tests_path = glob.glob('data/test/*.png')
    # 遍历所有图片
    for test_path in tests_path:
        # 保存结果地址
        save_res_path = test_path.split('.')[0] + '_res.png'
        # 读取图片
        img = cv2.imread(test_path)
        # 转为灰度图
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # 转为batch为1,通道为1,大小为512*512的数组
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # 转为tensor
        img_tensor = torch.from_numpy(img)
        # 将tensor拷贝到device中,只用cpu就是拷贝到cpu中,用cuda就是拷贝到cuda中。
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # 预测
        pred = net(img_tensor)
        # 提取结果
        pred = np.array(pred.data.cpu()[0])[0]
        # 处理结果
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        # 保存图片
        cv2.imwrite(save_res_path, pred)

运行完后,在data/test目录下,看到预测结果:

image-20250118211959195

5. U-Net与FCN的关系

U-Net是一种基于FCN的图像分割模型,它使用了FCN的思路和结构,并在其基础上添加了编码器-解码器结构,以提高分割的准确性和效率。

U-Net与FCN不同的是,U-Net的网络结构为对称的“U”型。具体来说,U-Net的下采样过程与FCN类似,特征图经过两次卷积操作以及一次池化操作,尺寸缩小为原来的一半,维度扩张为两倍。但是U-Net的上采样过程与FCN不同,相比起FCN的直接从底层特征图进行32倍或16倍等的上采样,U-Net的上采样过程与下采样过程对称,特征图经过两次卷积以及一次转置卷积操作后,尺寸扩大到原来两倍,维度为原来一半。

同时,U-Net的跳跃连接也与FCN不同。U-Net的跳跃连接将下采样时同尺寸的特征图裁剪后,与上采样时的特征图在维度上融合在一起,并作为解码器的下一个模块的输入。而不是FCN的将特征图相加后,直接上采样至原图的尺寸。

6. U-Net与生成模型

参考:深入浅出完整解析Stable Diffusion中U-Net的前世今生与核心知识 - 知乎

Unet为何可以作为生成模型(Diffusion)的backbone?/为何在AIGC时代,U-Net成为了Stable Diffusion这个划时代模型的关键结构?

这主要归结为U-net的四个特质:

  1. U-Net中Encoder模块的压缩特质。作为Encoder模块最初的应用,输入的图像经过下采样,抽取出比原图小得多的高维特征,相当于进行了压缩操作。这和Stable diffusion的latent逻辑不谋而合。
  2. U-Net中Decoder模块的去噪特质,作为Decoder模块最初的应用。
  3. U-Net整体结构上的简洁、稳定和高效,使得其在Stable Diffusion中能够从容的迭代去噪声,能够撑起Stable Diffusion的整个图像生成逻辑。
  4. Encoder-Decoder结构的强兼容性,让U-Net不管是在分割领域,还是在生成领域,都能和Transformer等新生代模型的从容融合。
6.1 Stable Diffusion中的U-Net

img

Stable Diffusion中的U-Net包含约860M的参数,在float32的精度下,约占3.4G的存储空间。

在上图中可以看到,U-Net是Stable Diffusion中的核心模块。U-Net主要在“扩散”循环中对高斯噪声矩阵进行迭代降噪,并且每次预测的噪声都由文本和timesteps进行引导,将预测的噪声在随机高斯噪声矩阵上去除,最终将随机高斯噪声矩阵转换成图片的隐特征。

在U-Net执行“扩散”循环的过程中,Content Embedding始终保持不变,而Time Embedding每次都会发生变化。每次U-Net预测的噪声都在Latent特征中减去,并且将迭代后的Latent作为U-Net的新输入。

6.2 U-Net在生成模型中的改进

Stable Diffusion中的U-Net,在Encoder-Decoder结构的基础上,增加了Time Embedding模块,Spatial Transformer(Cross Attention)模块和self-attention模块。

(1)Time Embedding模块

首先,什么是Time Embedding呢?

Time Embedding(时间嵌入)是一种在时间序列数据中用于表示时间信息的技术。时间序列数据是指按照时间顺序排列的数据,例如股票价格、天气数据、传感器数据等。时间嵌入的目的是将时间作为一个特征进行编码,以便在深度学习模型中更好地学习时间相关性特征。

Time Embedding的基本思想是将时间信息映射到一个连续的向量空间,使得时间之间的关系可以被模型学习和利用。

Time Embedding的使用可以帮助深度学习模型更好地理解时间相关性,从而提高模型的性能。比如在Stable Diffusion中,将Time Embedding引入U-Net中,帮助其在扩散过程中从容预测噪声。

Stable Diffusion需要迭代多次对噪音进行逐步预测,使用Time Embedding就可以将time编码到网络中,从而在每一次迭代中让U-Net更加合适的噪声预测。

讲完Time Embedding的核心基础知识,我们再解析一下Stable Diffusion中U-Net的Time Embeddings模块是如何构造的:

img

可以看到,Time Embeddings模块 + Encoder模块中原本的卷积层,组成了一个Residual Block结构。它包含两个卷积层,一个Time Embedding和一个skip Connection。而这里的全连接层将Time Embedding变换为和Latent Feature一样的维度。最后通过两者的加和完成time的编码。

(2)Spatial Transformer(Cross Attention)模块

在Stable Diffusion中,使用了Spatial Transformer来表示类Cross Attention模块。

Cross Attention是一种多头注意力机制,它可以在两个不同的输入序列之间建立关联,并且可以将其中一个输入序列的信息传递给另一个输入序列。

在计算机视觉中,Cross Attention可以用于将图像与文本之间的关联建立。例如,在图像字幕生成任务中,Cross Attention可以将图像中的区域与生成的文字之间建立关联,以便生成更准确的描述。

Stable Diffusion中使用Cross Attention模块控制文本信息和图像信息的融合交互,通俗来说,控制U-Net把噪声矩阵的某一块与文本里的特定信息相对应。

讲完Cross Attention的核心基础知识,我们再解析一下Stable Diffusion中U-Net的Cross Attention模块是如何构造的:

img

可以看到,Latent Feature和Context Embedding作为输入,将两者进行Cross Attenetion操作,将图像信息和文本信息进行了融合,整体上是一个经典的Transformer流程。

6.3 U-Net在生成模型中训练目标的变化

参考:The Illustrated Stable Diffusion – Jay Alammar – Visualizing machine learning one concept at a time.

在Stable Diffusion中,U-Net在不断的训练过程中主要学会了一件事,那就是去噪!去噪!去噪!

想要让U-Net能够高效去噪,并获得图像的隐特征,我们就要让U-Net知道什么是噪声数据。

于是我们在训练的预处理过程中,向训练集有策略地加入噪声。

这个加噪策略主要包括设定不同级别的噪声,比如说0-100共101个强度的噪声,在每个Batch中,随机加入1-n个101强度序列中的噪声,生成噪声图片。

img

加噪+噪声强度+加噪次数+原数据集,构成了Stable Diffusion中U-Net训练数据的基石。

img

有了数据预处理的大逻辑,在训练过程中,U-Net需要在已知噪声强度的条件下,不断学习提升从噪声图片中计算出噪声的能力。

需要注意的是,Stable Diffusion中的U-Net并不直接输出无噪声的原数据,而是去预测原数据上所加过的噪声

img

Stable Diffusion中U-Net的训练过程

如上图所示,Stable Diffusion中U-Net的训练一共分四步:

  1. 从训练集中选取一张加噪过的图片和噪声强度,比如上图的加噪街道图和噪声强度3。
  2. 将数据输入U-Nnet,并且预测噪声矩阵。
  3. 将预测的噪声矩阵和实际噪声矩阵(Label)进行误差的计算。
  4. 通过反向传播更新U-Net的参数。

在推理阶段中,我们将U-Net预测的噪声不断在噪声图片中减去就能恢复出图片的隐特征了。

当我们完成了U-Net在Stable Diffusion中的训练,如果我们再将噪声强度和噪声图输入U-Net,那么U-Net就能较准确地预测出有加在原素材上的噪声:

img

Stable Diffusion中U-Net预测噪声

有了U-Net对噪声的强预测能力,在Stable Diffusion的推理过程中,我们就可以使用U-Net循环预测噪声,并在噪声图上逐步减去这些被预测出来的噪声,从而得到一个我们想要的高质量的图像隐特征,去噪流程如下图所示:

img

总结

本次周报深入探讨了U-Net网络在医学图像分割和生成模型中的应用。通过对U-Net结构的逐层解析,结合数据增强和模型训练,验证了其在语义分割任务中的卓越表现。实验表明,U-Net凭借其特有的跳跃连接机制,能够在上下文信息捕获与精确定位之间取得良好的平衡。此外,U-Net在Stable Diffusion中的应用进一步彰显了其在生成模型领域的重要性,为未来的研究和优化提供了方向。


http://www.kler.cn/a/509947.html

相关文章:

  • MySQL程序之:连接到服务器的命令选项
  • 嵌入式杂谈——什么是DMA?有什么用?
  • 26个开源Agent开发框架调研总结(一)
  • ZooKeeper 核心概念与机制深度解析
  • VS Code--常用的插件
  • 数据库的DML
  • 【视觉惯性SLAM:十六、 ORB-SLAM3 中的多地图系统】
  • 深入探索Go语言中的临时对象池:sync.Pool
  • Vue2.0的安装
  • K210视觉识别模块
  • 向harbor中上传镜像(向harbor上传image)
  • 模块化架构与微服务架构,哪种更适合桌面软件开发?
  • 【Unity】使用UniRx来快速完成Unity中的信号层开发工作。
  • Navicat Premium 数据可视化
  • 基于SSM汽车美容管家【提供源码+答辩PPT+文档+项目部署】(高质量源码,可定制,提供文档,免费部署到本地)
  • 【JSqlParser】Java使用JSqlParser解析SQL语句总结
  • 事件委托,其他事件,电梯导航,固定导航
  • Linux 音视频入门到实战专栏(视频篇)视频编解码 MPP
  • Apache SeaTunnel 荣登 2024 年度中间件开源项目 Top 50 榜单
  • Kali环境变量技巧(The Environment Variable Technique Used by Kali
  • 鸿蒙(HarmonyOS)的开发
  • TypeScript 使用 VSCode 简介
  • 算法4(力扣206)-反转链表
  • Hack The Box-Starting Point系列Oopsie
  • TCP Window Full是怎么来的
  • 游戏画质升级史的思考