Stable Diffusion中U-Net的前世今生与核心知识
🌺系列文章推荐🌺
扩散模型系列文章正在持续的更新,更新节奏如下,先更新SD模型讲解,再更新相关的微调方法文章,敬请期待!!!(本文及其之前的文章均已更新)
SD模型原理:
- Stable Diffusion概要讲解
- Stable diffusion详细讲解
- Stable Diffusion的加噪和去噪详解
- Diffusion Model
- Stable Diffusion核心网络结构——VAE
- Stable Diffusion核心网络结构——CLIP Text Encoder
- Stable Diffusion核心网络结构——U-Net
- Stable Diffusion中U-Net的前世今生与核心知识
- SD模型性能测评
- Stable Diffusion经典应用场景
- SDXL的优化工作
微调方法原理:
- DreamBooth
- LoRA
- LORA及其变种介绍
- ControlNet
- ControlNet文章解读
- Textual Inversion 和 Embedding fine-tuning
Stable Diffusion中U-Net的前世今生与核心知识
摘录于:https://zhuanlan.zhihu.com/p/642354007
目录
目录
Stable Diffusion中U-Net的前世今生与核心知识
1. 传统深度学习时代的U-Net
1.1 U-Net的“AI江湖”印象
1.2 U-Net的核心结构与细节
1.3 是什么让U-Net通向AIGC
2. Stable Diffusion中的U-Net
2.1 U-Net在Stable Diffusion中扮演的角色
2.2 Stable Diffusion中U-Net的完整核心结构
2.3 U-Net在AIGC时代中的核心结构与细节
2.4 GroupNorm
选择建议
3. U-Net在Stable Diffusion中的训练和推理
3.1 U-Net在Stable Diffusion中的训练过程
3.2 U-Net在Stable Diffusion中的推理过程
1. 传统深度学习时代的U-Net
1.1 U-Net的“AI江湖”印象
在2015年,传统深度学习时代的早期,U-Net: Convolutional Networks for Biomedical Image Segmentation(U-Net)正式发表,图像分割领域迎来了它的“ResNet”。其具有简洁,高效,稳定的特性。
1.2 U-Net的核心结构与细节
(1) Encoder-Decoder结构
U-Net最经典的特征是其Encoder-Decoder的结构,这样的结构简洁且高效,并且具备对称的“艺术”美感,也让U-Net具备了极强的生命力与适应性。
其中左半部分的Encoder模块负责进行特征的提取与学习,Encoder模块可以由ResNet、VGG、EfficientNet等一流特征提取模型担任,所以Encoder模块具备较强的工程潜力与科研势能。与此同时Encoder模块可以增加对扰动噪声的鲁棒性,减少过拟合的风险,降低运算量以及增加感受野的大小等作用。
而右半部分的Decoder模块则负责将特征图恢复到原始分辨率,并使用skip-connection这个关键一招融合了浅层的位置信息与深层的语义信息。与此同时,Decoder模块和Encoder模块一样可以由ResNet、VGG、EfficientNet等一流模型担任,从而使得U-Net的变体非常繁荣,增加了工程“魔改”的可玩性。
(2) U-Net结构细节挖掘
讲完Encoder-Decoder结构的整体框架,向大家介绍一下Encoder-Decoder结构中的一些能够成为通用范式和经典Tricks的细节操作。
从上图的Encoder-Decoder结构中可以看到,U-Net是一个全卷积神经网络,网络最后一层使用了浅蓝色箭头,表示1*1卷积,其完全取代了全连接层,使得模型的输入尺寸不再受限制,极大增强了U-Net在各种应用场景的兼容性。
上图中的蓝色和白色框表示feature map,深蓝色箭头表示 3x3 卷积,padding=0 ,stride=1其用于特征提取。由于padding=0,所以每次经过卷积运算,feature map将有一定程度的下采样。深红色箭头表示max pooling,stride=2,用于降低维度。将卷积和max pooling两者结合,能够对feature map进行特征提取的同时从容进行下采样。
上图中的绿色箭头表示Upsample操作,对feature map进行上采样恢复维度。
Upsampling常用的方式有两种:转置卷积和插值。而在U-Net中,使用了bilinear双线性插值。
在Encoder和Decoder两个模块之间,使用skip-connection作为桥梁,用于特征融合,将浅层的位置信息与深层的语义信息进行concat操作。图中用灰色箭头表示skip-connection,其中“copy”就是concat操作,而“crop”则通过裁剪使得两个特征图尺寸一致。
Skip-connection(跳跃连接) 是深度学习中一种重要的技术,特别是在像U-Net这样的架构中,它发挥了至关重要的作用。下面我来详细解释它的原理和在U-Net中的作用。
1. Skip-connection的基本概念
Skip-connection指的是在神经网络的某些层之间跳过中间层,直接将较早层的输出连接到较后层的输入。这意味着网络中后面的层不仅依赖于前一层的输出,还直接获取了来自更早层的信息。这种连接方式最早在ResNet(Residual Networks)中被广泛应用,用来缓解深层网络中的梯度消失问题。
在U-Net中,skip-connection连接了编码器部分(左半部分)的特征图与解码器部分(右半部分)相对应的层。这种连接保留了编码阶段中较浅层的空间信息,并在解码阶段将这些信息重新加入到重建过程中。
2. 在U-Net中的作用
U-Net是一种常用于图像分割的神经网络,它的结构具有对称性,包括左半部分的编码器和右半部分的解码器。编码器逐步降低特征图的空间分辨率并提取高级语义特征,而解码器则负责恢复空间分辨率,最终输出与原始输入分辨率相同的图像。在这个过程中,skip-connection起到了以下几个重要作用:
2.1 融合浅层的位置信息和深层的语义信息
浅层的特征图:编码器的浅层主要保留了图像的位置信息和细节特征(如边缘、轮廓等),但语义信息较少。
深层的特征图:编码器的深层特征图包含丰富的语义信息(比如物体的类别、全局上下文等),但位置信息损失较多,特征图的分辨率也很低。
通过skip-connection,解码器在每一步恢复特征图分辨率时,能够直接获得来自编码器浅层的高分辨率位置信息,并将其与当前层的深层语义信息进行融合。这种结合方式既保留了图像的细节,又包含了深层的语义信息,从而使得模型能够更加精准地进行像素级的图像分割。
2.2 增强梯度流动,帮助训练深层网络
skip-connection也有助于梯度的反向传播,使梯度能够从解码器的输出层更顺利地传播到编码器的浅层,从而帮助训练更加稳定。类似于ResNet的残差连接,skip-connection通过增加信息流动的路径,减轻了深层网络中的梯度消失问题。
3. 具体实现
在U-Net中,skip-connection通常是在解码器的每个上采样(Up-sampling)步骤中,将编码器相应层的特征图与解码器当前层的特征图进行拼接(concatenate),即沿着通道维度将它们合并在一起。这样,解码器每一层既可以利用当前的高层语义特征,又可以利用来自编码器浅层的位置信息。
4. 总结
在U-Net中,skip-connection极大地提升了网络的性能,特别是在需要高分辨率输出的任务中(如图像分割)。它通过将编码器中的浅层信息传递到解码器部分,帮助模型在恢复高分辨率特征图时,同时利用深层的语义信息和浅层的位置信息。这种融合策略既保留了细节又增强了语义理解,使得U-Net在医学影像分割、遥感图像处理等任务中表现优异。
1.3 是什么让U-Net通向AIGC
讲完U-Net在传统深度学习时代的核心知识点与价值,为何在AIGC时代,U-Net成为了Stable Diffusion这个划时代模型的关键结构。
U-Net主要有以下四个特质:
- U-Net中Encoder模块的压缩特质。作为Encoder模块最初的应用,输入的图像经过下采样,抽取出比原图小得多的高维特征,相当于进行了压缩操作。这和Stable diffusion的latent逻辑不谋而合,随即在AIGC时代“文艺复兴”。
- U-Net中Decoder模块的去噪特质,作为Decoder模块最初的应用,在AIGC时代“文艺复兴”。
- U-Net整体结构上的简洁、稳定和高效,使得其在Stable Diffusion中能够从容的迭代去噪声,能够撑起Stable Diffusion的整个图像生成逻辑。
- Encoder-Decoder结构的强兼容性,让U-Net不管是在分割领域,还是在生成领域,都能和Transformer等新生代模型的从容融合。
正是这些特质让U-Net顺应了时代的潮流,AIGC时代里,依旧爆发出了鲜活的生命力与价值。
2. Stable Diffusion中的U-Net
2.1 U-Net在Stable Diffusion中扮演的角色
Stable Diffusion中的U-Net包含约860M的参数,在float32的精度下,约占3.4G的存储空间。
在上图中可以看到,U-Net是Stable Diffusion中的核心模块。U-Net主要在“扩散”循环中对高斯噪声矩阵进行迭代降噪【预测噪声】,并且每次预测的噪声都由文本和timesteps进行引导,将预测的噪声在随机高斯噪声矩阵上去除,最终将随机高斯噪声矩阵转换成图片的隐特征。
在U-Net执行“扩散”循环的过程中,Content Embedding始终保持不变,而Time Embedding每次都会发生变化。每次U-Net预测的噪声都在Latent特征中减去,并且将迭代后的Latent作为U-Net的新输入。
总的来说,如果说Stable Diffusion是“优化噪声的艺术”,那么U-Net将是这个“艺术”的核心主导者。
2.2 Stable Diffusion中U-Net的完整核心结构
在讲解Stable Diffusion中U-Net的各个核心模块之前,我们先看看U-Net在Stable Diffusion中的完整结构:
2.3 U-Net在AIGC时代中的核心结构与细节
Stable Diffusion中的U-Net,在Encoder-Decoder结构的基础上,增加了Time Embedding模块,Spatial Transformer(Cross Attention)模块和self-attention模块。
(1) Time Embedding模块
首先,什么是Time Embedding呢?
Time Embedding(时间嵌入)是一种在时间序列数据中用于表示时间信息的技术。时间序列数据是指按照时间顺序排列的数据,例如股票价格、天气数据、传感器数据等。时间嵌入的目的是将时间作为一个特征进行编码,以便在深度学习模型中更好地学习时间相关性特征。
Time Embedding的基本思想是将时间信息映射到一个连续的向量空间,使得时间之间的关系可以被模型学习和利用。
Time Embedding的使用可以帮助深度学习模型更好地理解时间相关性,从而提高模型的性能。比如在Stable Diffusion中,将Time Embedding引入U-Net中,帮助其在扩散过程中从容预测噪声。
Stable Diffusion需要迭代多次对噪音进行逐步预测,使用Time Embedding就可以将time编码到网络中,从而在每一次迭代中让U-Net更加合适的噪声预测。
讲完Time Embedding的核心基础知识,我们再解析一下Stable Diffusion中U-Net的Time Embeddings模块是如何构造的:
可以看到,Time Embeddings模块 + Encoder模块中原本的卷积层,组成了一个Residual Block结构。它包含两个卷积层,一个Time Embedding和一个skip Connection。而这里的全连接层将Time Embedding变换为和Latent Feature一样的维度。最后通过两者的加和完成time的编码。
(2) Spatial Transformer(Cross Attention)模块
在Stable Diffusion中,使用了Spatial Transformer来表示类Cross Attention模块。
按照惯例,我们先理解一下什么是Cross Attention?
Cross Attention是一种多头注意力机制,它可以在两个不同的输入序列之间建立关联,并且可以将其中一个输入序列的信息传递给另一个输入序列。
在计算机视觉中,Cross Attention可以用于将图像与文本之间的关联建立。例如,在图像字幕生成任务中,Cross Attention可以将图像中的区域与生成的文字之间建立关联,以便生成更准确的描述。
Stable Diffusion中使用Cross Attention模块控制文本信息和图像信息的融合交互,通俗来说,控制U-Net把噪声矩阵的某一块与文本里的特定信息相对应。
讲完Cross Attention的核心基础知识,我们再解析一下Stable Diffusion中U-Net的Cross Attention模块是如何构造的:
可以看到,Latent Feature和Context Embedding作为输入,将两者进行Cross Attenetion操作,将图像信息和文本信息进行了融合,整体上是一个经典的Transformer流程。
2.4 GroupNorm
Stable Diffusion中U-Net的一个细节Trick,那就是U-Net中全部采用GroupNorm进行归一化。
GroupNorm有如下的优点:
- 独立于Batch:GroupNorm不依赖于Batch大小,这使得它在处理小Batch数据或者Batch大小变化较大的情况时仍能保持稳定性。这在生成任务中尤其重要,因为它允许使用更小的Batch而不会牺牲性能。
- 提升训练稳定性:在GANs等生成任务中,模型训练可能非常不稳定。GroupNorm可以帮助增强模型的训练稳定性,从而产生更高质量的生成结果。
- 减少内存消耗:由于GroupNorm允许使用更小的Batch而不影响性能,因此可以减少训练期间的内存消耗,这对于资源限制较大的环境特别重要。
总结对比于其他归一化方法的差别
特性 | Batch Normalization (BN) | Layer Normalization (LN) | Group Normalization (GN) |
---|---|---|---|
归一化维度 | mini-batch 维度 | 单个样本内所有特征维度 | 单个样本内按通道划分的组 |
适合场景 | 大规模 batch,卷积神经网络(CNN) | 小 batch,序列模型(RNN、Transformer) | 任意 batch size,卷积网络(CNN),【SD模型中更好地保持图像的空间结构,在视觉任务中表现优于 LN】 |
不适合场景 | 小 batch,序列模型 | 大 batch,卷积网络 | 不适合需要动态调整组数的场景 |
正则化效果 | 有一定正则化效果 | 正则化效果较弱 | 正则化效果较好 |
推理阶段的复杂性 | 需要存储和使用全局统计量 | 无需全局统计量 | 无需全局统计量 |
计算效率 | 在大 batch 下效率高 | 适合小 batch,计算简单 | 灵活,但略微增加计算复杂度 |
选择建议
- 如果使用大 batch 且为 CNN,推荐使用 Batch Normalization (BN)。
- 如果 batch size 较小或 batch size = 1 或为序列模型(如 Transformer、RNN),推荐使用 Layer Normalization (LN)。
- 如果希望归一化效果较为灵活,且适用于小 batch 和卷积【视觉】模型,推荐使用 Group Normalization (GN)。【GN 更适合卷积神经网络,因为它能更好地保持图像的空间结构,在视觉任务中表现优于 LN。】
3. U-Net在Stable Diffusion中的训练和推理
3.1 U-Net在Stable Diffusion中的训练过程
在Stable Diffusion中,U-Net在不断的训练过程中主要学会了一件事,那就是去噪!去噪!还是tmd去噪!
想要让U-Net能够高效去噪,并获得图像的隐特征,我们就要让U-Net知道什么是噪声数据。
于是我们在训练的预处理过程中,向训练集有策略地加入噪声。
这个加噪策略主要包括设定不同级别的噪声,比如说0-100共101个强度的噪声,在每个Batch中,随机加入1-n个101强度序列中的噪声,生成噪声图片。
加噪+噪声强度+加噪次数+原数据集,构成了Stable Diffusion中U-Net训练数据的基石。
有了数据预处理的大逻辑,在训练过程中,U-Net需要在已知噪声强度的条件下,不断学习提升从噪声图片中计算出噪声【预测噪声】的能力。
需要注意的是,Stable Diffusion中的U-Net并不直接输出无噪声的原数据,而是去预测原数据上所加过的噪声。
如上图所示,Stable Diffusion中U-Net的训练一共分四步:
- 从训练集中选取一张加噪过的图片和噪声强度,比如上图的加噪街道图和噪声强度3。
- 将数据输入U-Net,并且预测噪声矩阵。
- 将预测的噪声矩阵和实际噪声矩阵(Label)进行误差的计算。
- 通过反向传播更新U-Net的参数。
论文中的训练过程如下:
训练过程的目标是训练去噪模型。通过训练Unet模型,该模型输入为和,输出为时刻预测的高斯噪声。即利用 和预测这一时刻的高斯噪声,通过最小化真实加入的噪声和预测的噪声之间的差距(目标是让模型尽可能准确地预测添加到 中的真实噪声 )。
解释:
- 数据集中采样一张图片
- 采样一个时间步
- 从正态分布中采样一个噪音
- 前面我们定义了如何在时间步时刻,在无需迭代的情况下前向加噪 ,然后用梯度下降(反向预测噪声)去训练网络模型预测噪声 的能力。
- 然后重复整个过程,直到模型最终收敛。
3.2 U-Net在Stable Diffusion中的推理过程
在推理阶段中,我们将U-Net预测的噪声不断在噪声图片中减去就能恢复出图片的隐特征了。
当我们完成了U-Net在Stable Diffusion中的训练,如果我们再将噪声强度和噪声图输入U-Net,那么U-Net就能较准确地预测出有加在原素材上的噪声:
有了U-Net对噪声的强预测能力,在Stable Diffusion的推理过程中,我们就可以使用U-Net循环预测噪声,并在噪声图上逐步减去这些被预测出来的噪声,从而得到一个我们想要的高质量的图像隐特征,去噪流程如下图所示:
论文中的推理过程如下:
推理/采样过程的目标是从纯噪声生成高质量图像。输入是随机高斯噪声或者加噪后的图像。输出是符合prompts的图像。
1. 初始化纯噪声(第1步):从标准正态分布 中采样一个随机噪声 ,作为起点。
2. 逐步去噪(第2-5步):从时间步 开始,逐步将噪声去除,恢复原始数据 。公式为:
- 是当前时间步的带噪图像。
- 是模型预测的噪声。
- 是随机噪声,用于在去噪过程中保持生成的随机性。
- 是一个与噪声相关的可调参数。
3. 输出生成的图像(第6步):经过次迭代后,最终得到去噪后的图像 。