当前位置: 首页 > article >正文

CFG 蒸馏:On Distillation of Guided Diffusion Models

CFG 蒸馏:On Distillation of Guided Diffusion Models

TL; DR:本文提出一种两阶段的扩散模型蒸馏方法,第一阶段将 CFG 蒸馏到模型内部,避免在 CFG 文生图时执行两次模型推理;第二阶段使用类似 Progressive Distillation 的方法进行蒸馏,降低推理所需步数,从两方面提高文本引导的扩散模型的生图速度。

方法

本文提出的蒸馏方法分为两个阶段,第一阶段是对 CFG 进行蒸馏,将 CFG 的能力蒸馏到模型内部,避免在推理生图时执行两次模型前向,第二阶段是对模型生图的去噪步数进行蒸馏,降低推理所需步数。

阶段一:CFG 蒸馏

Classifier-Free Guidance(CFG)是目前大多数主流文生图模型使用的条件生图方法。该方法在训练时按照一定比例同时训练模型的条件生图能力和无条件生图能力。在推理时,指定一个文本提示词和一个引导权重参数 w w w,然后分别推理条件生图结果 x ^ c , θ \hat{\mathbf{x}}_{c,\theta} x^c,θ 和无条件生图结果 x ^ θ \hat{\mathbf{x}}_\theta x^θ,取二者的加权和为最终的生图结果:
x ^ θ w = ( 1 + w ) x ^ c , θ − w ( x ^ θ ) \hat{\mathbf{x}}_\theta^w=(1+w)\hat{\mathbf{x}}_{c,\theta}-w(\hat{\mathbf{x}}_\theta) x^θw=(1+w)x^c,θw(x^θ)
引入 CFG 的好处是我们可以在推理生图时,通过调整引导权重 w w w 来在生成质量和多样性之间进行 trade-off。但坏处是我们需要在推理时分别对条件生成和无条件生成模型都执行一次推理,大大增加了生图时间。

本文所提方法,就是要将 CFG 的这种生成质量和多样性间进行权衡的能力蒸馏到模型内部,从而在生图时既能调整权重 w w w,又无需执行两次模型推理。

具体来说,在 w w w 的一个取值区间 [ w min , w max ] [w_\text{min},w_\text{max}] [wmin,wmax] 内,给定参数为 θ \theta θ 的教师模型,使用如下目标函数进行蒸馏训练学生模型参数 η 1 \eta_1 η1
E w ∼ U [ w min , w max ] , t ∼ U [ 0 , 1 ] , x ∼ p data ( x ) [ ω ( λ t ) ∣ ∣ x ^ η 1 ( z t , w ) − x ^ θ w ( z t ) ∣ ∣ 2 2 ] \mathbb{E}_{w\sim \mathcal{U}[w_\text{min},w_\text{max}],t\sim\mathcal{U}[0,1],\mathbf{x}\sim p_\text{data}(\mathbf{x})}[\omega(\lambda_t)||\hat{\mathbf{x}}_{\eta_1}(\mathbf{z}_t,w)-\hat{\mathbf{x}}_\theta^w(\mathbf{z}_t)||_2^2] EwU[wmin,wmax],tU[0,1],xpdata(x)[ω(λt)∣∣x^η1(zt,w)x^θw(zt)22]
其中 x ^ θ w ( z t ) \hat{\mathbf{x}}_\theta^w(\mathbf{z}_t) x^θw(zt) 是教师模型使用 CFG 进行蒸馏的输出结果。公式中的文本条件 c c c 都省略了。 ω ( λ t ) \omega(\lambda_t) ω(λt) 是扩散模型训练形式化的一个权重参数,详情可参考 VDM。

可以看到,我们的学生模型将引导参数 w w w 作为了一个直接的输入参数,对应的,模型结构也有一些改变来适配。为了更好地捕获特征,本文提取 w w w 傅里叶 embedding,随后使用类似 timestep 引入扩散模型的方式,将 w w w 也引入。除此之外,学生模型的模型结构均与教师模型相同,并使用其参数作为初始化参数。

阶段二:Timestep 蒸馏

第二阶段,是对时间步进行整理,使用的方法与 Progressive Distillation 非常类似。简单来说,就是学生模型学习一步去噪得到教师模型两步去噪的结果,从而将采样生图所需的步数减半。完成一次整理后将学生模型作为下一轮蒸馏的教师模型,如此循环往复,最终能得到 1-4 步生图的模型。该阶段训练完成后的模型参数记为 η 2 \eta_2 η2

采样生成

在两阶段蒸馏训练完成后,我们可以高效地进行采样生成。本文介绍了确定采样和随机采样两种方式,其中确定采样就是使用 DDIM 进行采样。随机采样则是对确定采样的步长加倍,并在每步之后加入一定的随机噪声。

在这里插入图片描述

总结

CFG 能够很好地在生图质量和多样性之间进行权衡,但是两次模型推理的开销确实太大,将这个调节参数蒸馏到模型内部是一个很好的想法,能够有效提高条件图的速度。最近的 Flux dev 也是对 Flux pro 进行了 CFG 蒸馏,不知是不是就是使用的本文方法。


http://www.kler.cn/news/360629.html

相关文章:

  • 【景观生态学实验】实验四 景观指数计算
  • multi-Head Attention
  • Vue--》掌握自定义依赖引入的最佳实践
  • blender 理解 积木组合 动画制作 学习笔记
  • C# 数据集
  • sql server xml
  • Egg.js使用ejs快速自动生成resetful风格的CRUD接口
  • 数据门户:企业数字化转型的关键作用
  • Oracle里面,with ... as 用法介绍
  • 软考系统分析师知识点十三:软件需求工程
  • 【论文笔记】Adversarial Diffusion Distillation
  • Flux.using 使用说明书
  • LeetCode第101题. 对称二叉树
  • c语言操作符xiangjie
  • 10 django管理系统 - 管理员管理 - 新建管理员(通过模态框和ajax实现)
  • 快乐数--双指针
  • MSE Loss、BCE Loss
  • 电商大数据获取渠道分享--官方接口、爬虫、第三方如何选择?
  • 【FAQ】HarmonyOS SDK 闭源开放能力 —Map Kit(3)
  • Taro构建的H5页面路由切换返回上一页存在白屏页面过渡