《Discriminative Class Tokens for Text-to-Image Diffusion Models》ICCV2023
摘要
论文讨论了文本到图像扩散模型的最新进展,这些模型能够生成多样化和高质量的图像。然而,生成的图像常常缺乏细节,并且由于输入文本的歧义性,容易产生错误。为了解决这些问题,作者提出了一种非侵入式的微调技术,利用预训练分类器的判别信号来指导生成过程,从而在保留自由形式文本表达潜力的同时,实现高精度。
概述
拟解决的问题:文本到图像的扩散模型在处理含有词汇歧义的输入或生成细节时表现不佳。此外,使用标记数据集训练的模型,由于数据集规模较小,限制了模型的表达能力,影响了生成图像的质量和多样性。
创新之处:
- 提出了一种微调技术,通过迭代修改文本到图像扩散模型中单个输入标记的嵌入,使用分类器来引导图像生成,使其更接近给定的目标类别。
- 该方法快速且不需要类别内图像集合或重新训练耐噪声分类器。
- 能够在低资源环境下用于增强训练数据,并且能够揭示用于训练引导分类器的数据信息。
方法
- 引入了一个与外部分类器标签类对应的标记(),通过迭代生成新图像并优化标记表示,以根据预训练分类器提高类别概率。
- 使用了一种新技术——梯度跳跃,它只将梯度传播通过扩散过程的最后阶段。
- 通过生成与目标类相关的图像,同时保留预训练扩散模型的全部表达能力,避免了对标记图像的训练。
提出了一种新颖的微调技术,通过在文本到图像扩散模型中引入一个与预训练分类器标签相对应的判别性标记 ,来解决输入文本中的词汇歧义问题并增强生成图像的细节表现。该技术通过迭代优化这个标记的嵌入表示,利用分类器的反馈来引导图像生成过程,从而生成更加准确和细致的图像,而无需重新训练整个模型或依赖于特定类别的图像集合。这种方法不仅提高了生成图像的质量,还保持了模型对自由形式文本的表达能力,同时避免了对分类器进行噪声数据的再训练。
3.1 条件扩散模型
训练条件扩散模型时,目标是学习一个过程,它能够预测在每一步添加的噪声,同时考虑条件输入。这通常通过最小化一个损失函数来实现,该函数衡量模型预测的噪声与实际噪声之间的差异。
其中:
在生成图像时,我们通常希望模型能够生成特定类别的图像。例如,如果输入文本是“一只猫”,我们希望生成的图像是猫的图像,而不是其他任何物体。为了实现这一点,可以利用分类器来引导扩散过程,使其偏向于生成特定类别的图像。在条件扩散模型中,可以通过使用预训练分类器的梯度信息来指导生成过程。这里的“梯度”是指损失函数相对于模型参数的导数,它指示了如何调整参数以最小化损失函数。在这种情况下,分类器的梯度可以用来调整生成模型的参数,使其生成的图像更符合特定类别的特征。
缺点:
- 当使用分类器指导扩散过程时,分类器需要在整个生成过程中对每一步产生的部分去噪图像进行评估。这意味着分类器必须能够准确地处理和理解在不同去噪阶段的图像,包括那些仍然包含噪声的图像。
- 在生成图像的每一步中,都需要利用分类器的输出来指导图像的生成方向。这就意味着分类器必须在生成过程的每个阶段都被调用,这会增加整体的计算负担和延迟。
为了解决这个问题,提出了一种无分类器的方法。这种方法不依赖于图像分类器的梯度,而是通过对条件 和无条件 、去噪模块之间的差异进行建模来近似隐式分类器的梯度。条件模块和无条件模块使用相同的 参数化,条件网络通过使用空句子变为无条件的。最终的去噪网络正式表示如下:
其中 w 是决定条件引导强度的超参数。
3.2 判别令牌嵌入
判别令牌是一种特殊的标记,它被嵌入到文本到图像的扩散模型中,用于代表特定的类别信息。这些令牌与预训练的分类器相关联,目的是在生成过程中引入类别特定的指导信号。
判别令牌的嵌入向量通常初始化为与目标类别相关的已知标记的嵌入,例如,如果目标是生成特定种类的鸟类图像,判别令牌的嵌入可能会初始化为“鸟”这个词的嵌入。这种初始化有助于模型更快地学习并适应特定的类别特征。
迭代优化过程:
- 生成与优化:在生成图像的过程中,模型会使用包含判别令牌的文本提示(如“一张具有的老虎猫的照片”)。这里的“Sc”代表判别令牌,它在每次迭代中被优化以更好地代表目标类别。
- 分类器反馈:生成的图像被送入预训练的分类器,分类器的输出(如类别概率分布)被用来提供反馈,指导判别令牌的进一步优化。
- 损失函数:通常使用交叉熵损失函数来衡量分类器对生成图像的分类结果与目标类别之间的差异。通过最小化这个损失,判别令牌的嵌入被调整以提高生成图像的类别准确性。
3.3 梯度跳跃
为了提高训练效率并减少资源消耗,论文中提到了“梯度跳跃”技术。在这种技术中,只有扩散过程的最后阶段(即最后的去噪步骤)会更新判别令牌的嵌入。这种方法减少了在每一步中都需要反向传播的计算负担,同时仍然能够有效地优化判别令牌。