当前位置: 首页 > article >正文

【AIGC】CFG:基于扩散模型分类器差异引导

摘要

分类器指导是最近引入的一种方法,在训练后在条件扩散模型中权衡模式覆盖率和样本保真度,在精神上与其他类型的生成模型中的低温采样或截断相同。分类器引导将扩散模型的分数估计与图像分类器的梯度相结合,因此需要训练与扩散模型分离的图像分类器。它还提出了一个问题,即是否可以在没有分类器的情况下执行指导。我们表明,如果没有这样的分类器,引导确实可以由纯生成模型执行:在我们称之为无分类器指导的情况下,我们联合训练一个条件扩散模型和一个无条件扩散模型,我们结合得到的条件和无条件分数估计来获得样本质量和多样性之间的权衡,类似于使用分类器指导获得的。

引言

        分类器引导将扩散模型的分数估计与a的对数概率的输入梯度混合在一起的分类器。通过改变分类器梯度的强度,Dhariwal和Nichol可以权衡Inception分数(Salimans等人,2016)和FID分数(Heusel等人,2017)(或精度和召回率),其方式类似于改变BigGAN的截断参数。

 

       针对64x64 ImageNet扩散模型的malamute类的无分类器指导。从左到右:增加无分类器引导的数量,从左边的非引导样本开始。

       引导器对三个高斯混合的影响,每个混合分量代表一个类条件下的数据。最左边的图是非引导的边际密度。从左到右是随引导强度增加的归一化引导条件的混合密度

        我们感兴趣的是是否可以在没有分类器的情况下进行分类器引导。分类器引导使扩散模型训练管道变得复杂,因为它需要训练一个额外的分类器,而且这个分类器必须在有噪声的数据上训练,所以通常不可能插入一个预训练的分类器。此外,由于分类器引导在采样期间混合了分数估计和分类器梯度,分类器引导的扩散采样可以被解释为试图将图像分类器与基于梯度的对抗性攻击混淆。这就提出了一个问题:分类器指导是否能够成功地提高基于分类器的度量,比如FID和Inception分数(is),仅仅是因为它与这些分类器是对立的。在分类器梯度方向上的步进也与GAN训练有一些相似之处,特别是使用非参数生成器;这也提出了一个问题,即分类器引导的扩散模型是否在基于分类器的指标上表现良好,因为它们开始类似于gan,而gan已经在这些指标上表现良好。 

        为了解决这些问题,我们提出了无分类器的引导方法,即完全避免使用任何分类器的引导方法。与在图像分类器的梯度方向上采样不同,无分类器引导混合了条件扩散模型和联合训练的无条件扩散模型的分数估计。通过扫过混合权重,我们获得了类似于分类器引导所获得的FID/IS权衡。我们的无分类器引导结果表明,纯生成扩散模型能够与其他类型的生成模型合成极高保真度的样本。 

算法

         虽然分类器指导成功地权衡了截断或低温采样预期的IS和FID,但它仍然依赖于图像分类器的梯度,我们试图消除分类器,原因见第1节。在这里,我们描述了无分类器指导,在没有这种梯度的情况下实现了相同的效果。无分类器指导是修改\epsilon_{\theta}(z_{\lambda},c)与分类器引导具有相同的效果的替代方法,但没有分类器。算法 1 和 2 详细描述了使用无分类器指导进行训练和采样。

算法1

 训练时

算法2

 推理时

讨论

         我们的无分类器指导方法最实际的优点是它极其简单:在训练期间(随机放弃条件)和在采样期间(混合条件和无条件分数估计)只需要对代码进行一行更改。相比之下,分类器引导使训练管道变得复杂,因为它需要训练额外的分类器。这个分类器必须在有噪声的zλ上进行训练,因此不可能插入一个标准的预训练分类器。

代码片段


if args.conditioning_dropout_prob is not None:
    random_p = torch.rand(bsz, device=latents.device, generator=generator)
    # Sample masks for the edit prompts.
    prompt_mask = random_p < 2 * args.conditioning_dropout_prob
    prompt_mask = prompt_mask.reshape(bsz, 1, 1)
    # Final text conditioning.
    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)

    # Sample masks for the original images.
    image_mask_dtype = original_image_embeds.dtype
    image_mask = 1 - ((random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
                        * (random_p < * args.conditioning_dropout_prob).to(image_mask_dtype)
                    )
    image_mask = image_mask.reshape(bsz, 1, 1, 1)
    # Final image conditioning.
    original_image_embeds = image_mask * original_image_embeds

    # Concatenate the `original_image_embeds` with the `noisy_latents`.
    concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)

参考

diffusers/src/diffusers/training_utils.py at main · huggingface/diffusers · GitHub


http://www.kler.cn/a/307303.html

相关文章:

  • 网络基础Linux
  • [DEBUG] 服务器 CORS 已经允许所有源,仍然有 304 的跨域问题
  • 爬虫补环境案例---问财网(rpc,jsdom,代理,selenium)
  • JavaScript 观察者设计模式
  • 为什么hbase在大数据领域渐渐消失
  • 文件输入输出——NOI
  • JavaScript 函数 function
  • 用 nextjs 创建 Node+React Demo
  • WebGL入门(048):OES_draw_buffers_indexed 简介、使用方法、示例代码
  • Python---爬虫
  • Leetcode-轮转数组
  • 复现OpenVLA:开源的视觉-语言-动作模型及原理详解
  • 【Go开发】Go语言结构体,与java类不一样的定义方式
  • 推荐|基于springBoot智能推荐的卫生健康系统设计与实现(源码+论文+数据库)
  • 【附源码】用Python开发一个音乐下载工具,并打包EXE文件,所有音乐都能搜索下载!
  • el-table 的单元格 + 图表 + 排序
  • 动手学深度学习(pytorch土堆)-03常见的Transforms
  • 图论篇--代码随想录算法训练营第五十六天打卡| 108. 冗余连接,109. 冗余连接II
  • 【SQL】百题计划:SQL排序Order by的使用。
  • Flutter Error: Type ‘UnmodifiableUint8ListView‘ not found
  • 刷题DAY36
  • 初中生物--5.单细胞生物
  • VuePress搭建文档网站/个人博客(详细配置)主题配置-导航栏配置
  • 【开源免费】基于SpringBoot+Vue.JS企业客户管理系统(JAVA毕业设计)
  • Linux命令:文本处理工具sed详解
  • django中F()和Q()的用法