Flow-matching论文阅读
我首先需要仔细阅读用户指令,明确用户的需求。用户希望我逐段翻译一篇关于流匹配生成模型的文章,并且在翻译的过程中,总结文章的主要证明公式要点。这意味着我的任务不仅仅是简单的翻译,还需要理解文章的内容,并且提取出文章中重要的数学公式,进行概括和提炼。
为了完成这个任务,我计划按文章的段落顺序进行翻译。在翻译每一段落的同时,我会识别出该段落的核心思想和关键信息。对于文章中出现的数学公式,我会特别关注,因为这些公式通常是文章证明和结论的核心支撑。
在翻译过程中,我会将文章的结构大致划分为几个部分,例如引言、预备知识、方法介绍(流匹配)、实验结果和相关工作等。这样可以帮助我更好地组织翻译内容,并确保总结的要点能够覆盖文章的主要方面。
对于每个公式,我需要理解其含义和作用。例如,在连续归一化流的预备知识部分,文章给出了概率密度路径和向量场的定义,以及如何通过常微分方程构建流。我会记录下这些基本概念和公式,因为它们是后续流匹配方法的基础。
在流匹配方法部分,文章提出了流匹配目标函数,并解释了如何通过条件概率路径和向量场来构建目标函数。我会重点关注流匹配目标函数的定义,以及条件流匹配目标函数的推导过程。同时,文章还讨论了高斯条件概率路径的特殊情况,包括扩散条件向量场和最优传输条件向量场。我会详细记录这些不同类型的条件概率路径及其对应的向量场公式,因为这些是流匹配方法的重要组成部分。
在实验结果部分,文章展示了流匹配模型在不同数据集上的性能表现。我会关注实验设置、评估指标以及实验结果,并在总结中概括这些关键信息,例如模型在图像质量和计算效率方面的优势。
最后,我会将翻译完成的文章和总结的证明公式要点整合在一起,确保翻译准确流畅,总结简洁明了,并且能够全面反映文章的主要内容和贡献。我会检查总结的要点是否涵盖了文章的关键证明公式,例如流匹配目标函数、条件流匹配目标函数以及高斯条件概率路径的向量场公式等。通过这样的组织和整理,最终生成符合用户指令的答案。
翻译及总结
1. 引言
深度生成模型是一类旨在估计和采样未知数据分布的深度学习算法。近年来,生成模型的显著进展主要得益于基于扩散模型的可扩展和相对稳定的训练。然而,对简单扩散过程的依赖导致了采样概率路径的空间受限,从而导致非常长的训练时间和需要采用专门的方法来实现高效采样。
本文考虑了连续归一化流(CNFs)的一般确定性框架。CNFs能够建模任意的概率路径,并且已知可以涵盖扩散过程所建模的概率路径。然而,除了可以通过去噪分数匹配有效训练的扩散之外,目前还没有可扩展的CNF训练算法。
本文的目标是提出流匹配(FM),这是一种用于训练CNF模型的高效无模拟方法,允许采用一般的概率路径来监督CNF训练。重要的是,FM打破了可扩展CNF训练的障碍,超越了扩散,并绕过了对扩散过程的推理,直接处理概率路径。
2. 预备知识:连续归一化流
设 R d R^{d} Rd表示数据空间,其中数据点 x = ( x 1 , … , x d ) ∈ R d x=\left(x^{1},\ldots, x^{d}\right)\in R^{d} x=(x1,…,xd)∈Rd。本文使用的两个重要对象是:概率密度路径 p : [ 0 , 1 ] × R d → R > 0 p:[0,1]\times R^{d}\rightarrow R>0 p:[0,1]×Rd→R>0,这是一个时间相关的概率密度函数,即 ∫ p t ( x ) d x = 1 \int p_{t}(x) d x=1 ∫pt(x)dx=1;以及时变向量场 v : [ 0 , 1 ] × R d → R d v:[0,1]\times R^{d}\rightarrow R^{d} v:[0,1]×Rd→Rd。向量场 v t v_{t} vt可用于构建时变微分同胚映射,称为流 ϕ : [ 0 , 1 ] × R d → R d \phi:[0,1]\times R^{d}\rightarrow R^{d} ϕ:[0,1]×Rd→Rd,通过常微分方程(ODE)定义:
d d t ϕ t ( x ) = v t ( ϕ t ( x ) ) \frac{d}{d t}\phi_{t}(x)=v_{t}\left(\phi_{t}(x)\right) dtdϕt(x)=vt(ϕt(x))
ϕ 0 ( x ) = x \phi_{0}(x)=x ϕ0(x)=x
Chen等人(2018)建议使用神经网络 v t ( x ; θ ) v_{t}(x;\theta) vt(x;θ)对向量场 v t v_{t} vt进行建模,从而形成连续归一化流(CNF)。CNF用于通过前推方程将简单先验密度 p 0 p_{0} p0重塑为更复杂的密度 p 1 p_{1} p1:
p t = [ ϕ t ] ∗ p 0 p_{t}=\left[\phi_{t}\right]_{*} p_{0} pt=[ϕt]∗p0
其中前推(或变量替换)算子 ∗ * ∗定义为:
[ ϕ t ] ∗ p 0 ( x ) = p 0 ( ϕ t − 1 ( x ) ) det [ ∂ ϕ t − 1 ∂ x ( x ) ] . \left[\phi_{t}\right]_{*} p_{0}(x)=p_{0}\left(\phi_{t}^{-1}(x)\right)\operatorname{det}\left[\frac{\partial\phi_{t}^{-1}}{\partial x}(x)\right]. [ϕt]∗p0(x)=p0(ϕt−1(x))det[∂x∂ϕt−1(x)].
如果向量场 v t v_{t} vt满足方程3,则称其生成概率密度路径 p t p_{t} pt。测试向量场是否生成概率路径的一个实用方法是使用连续性方程,这是我们证明中的一个关键组成部分。
3. 流匹配
设 x 1 x_{1} x1表示根据某个未知数据分布 q ( x 1 ) q\left(x_{1}\right) q(x1)分布的随机变量。假设我们只能访问来自 q ( x 1 ) q\left(x_{1}\right) q(x1)的数据样本,但无法访问密度函数本身。此外,设 p t p_{t} pt是一个概率路径,使得 p 0 = p p_{0}=p p0=p是一个简单分布,例如标准正态分布 p ( x ) = N ( x ∣ 0 , I ) p(x)= N(x|0,I) p(x)=N(x∣0,I),并且 p 1 p_{1} p1近似等于数据分布 q q q。流匹配目标旨在匹配此目标概率路径,从而允许从 p 0 p_{0} p0流向 p 1 p_{1} p1。
给定目标概率密度路径 p t ( x ) p_{t}(x) pt(x)和相应的生成 p t ( x ) p_{t}(x) pt(x)的向量场 u t ( x ) u_{t}(x) ut(x),流匹配(FM)目标定义为:
L F M ( θ ) = E t , p t ( x ) ∥ v t ( x ) − u t ( x ) ∥ 2 , \mathcal{L}_{FM}(\theta)=E_{t, p_t(x)}\left\|v_t(x)-u_t(x)\right\|^2, LFM(θ)=Et,pt(x)∥vt(x)−ut(x)∥2,
其中 θ \theta θ表示CNF向量场 v t v_{t} vt的可学习参数, t ∼ U [ 0 , 1 ] t\sim \mathcal{U}[0,1] t∼U[0,1], x ∼ p t ( x ) x\sim p_{t}(x) x∼pt(x)。简而言之,FM损失通过神经网络 v t v_{t} vt回归向量场 u t u_{t} ut。当损失为零时,学习的CNF模型将生成 p t ( x ) p_{t}(x) pt(x)。
流匹配是一个简单且有吸引力的目标,但单独使用时在实际中是不可行的,因为我们没有任何关于适当 p t p_{t} pt和 u t u_{t} ut的先验知识。有许多选择可以满足 p 1 ( x ) ≈ q ( x ) p_{1}(x)\approx q(x) p1(x)≈q(x)的概率路径,更重要的是,我们通常无法获得生成所需 p t p_{t} pt的封闭形式 u t u_{t} ut。本文展示了如何通过仅定义在样本上的条件概率路径和向量场来构建 p t p_{t} pt和 u t u_{t} ut,并提供了一个更易处理的流匹配目标。
3.1 从条件概率路径和向量场构建 p t , u t p_{t}, u_{t} pt,ut
构建目标概率路径的一种简单方法是通过简单概率路径的混合:给定特定数据样本 x 1 x_{1} x1,我们表示 p t ( x ∣ x 1 ) p_{t}\left(x\mid x_{1}\right) pt(x∣x1)为一个条件概率路径,满足 p 0 ( x ∣ x 1 ) = p ( x ) p_{0}\left(x\mid x_{1}\right)=p(x) p0(x∣x1)=p(x),并在 t = 1 t=1 t=1时设计 p 1 ( x ∣ x 1 ) p_{1}\left(x\mid x_{1}\right) p1(x∣x1)为集中在 x = x 1 x=x_{1} x=x1附近的分布。通过对 q ( x 1 ) q\left(x_{1}\right) q(x1)边缘化条件概率路径,得到边缘概率路径:
p t ( x ) = ∫ p t ( x ∣ x 1 ) q ( x 1 ) d x 1 , p_{t}(x)=\int p_{t}\left(x\mid x_{1}\right) q\left(x_{1}\right) d x_{1}, pt(x)=∫pt(x∣x1)q(x1)dx1,
其中在 t = 1 t=1 t=1时,边缘概率 p 1 p_{1} p1是近似于数据分布 q q q的混合分布:
p 1 ( x ) = ∫ p 1 ( x ∣ x 1 ) q ( x 1 ) d x 1 ≈ q ( x ) . p_{1}(x)=\int p_{1}\left(x\mid x_{1}\right) q\left(x_{1}\right) d x_{1}\approx q(x). p1(x)=∫p1(x∣x1)q(x1)dx1≈q(x).
我们还通过以下方式“边缘化”条件向量场来定义边缘向量场:
u t ( x ) = ∫ u t ( x ∣ x 1 ) p t ( x ∣ x 1 ) q ( x 1 ) p t ( x ) d x u_{t}(x)=\int u_{t}\left(x\mid x_{1}\right)\frac{p_{t}\left(x\mid x_{1}\right) q\left(x_{1}\right)}{p_{t}(x)} d x ut(x)=∫ut(x∣x1)pt(x)pt(x∣x1)q(x1)dx
其中 u t ( ⋅ ∣ x 1 ) u_{t}\left(\cdot\mid x_{1}\right) ut(⋅∣x1)是生成 p t ( ⋅ ∣ x 1 ) p_{t}\left(\cdot\mid x_{1}\right) pt(⋅∣x1)的条件向量场。我们的第一个关键观察是:
边缘向量场(方程8)生成边缘概率路径(方程6)。
这提供了条件VF(生成条件概率路径的那些)和边缘VF(生成边缘概率路径的那些)之间的意外联系。这使得我们可以将未知且难以处理的边缘VF分解为更简单的条件VF,这些条件VF仅依赖于单个数据样本。
3.2 条件流匹配
由于边缘概率路径和VF(方程6和8)中的积分不可行,因此仍然难以计算 u t u_{t} ut,从而难以计算原始流匹配目标的非偏估计量。相反,我们提出了一个更简单的目标,即条件流匹配(CFM)目标:
L CFM ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∥ v t ( x ) − u t ( x ∣ x 1 ) ∥ 2 , \mathcal{L}_{\text{CFM}}(\theta)=E_{t, q\left(x_{1}\right), p_{t}\left(x\mid x_{1}\right)}\left\|v_{t}(x)-u_{t}\left(x\mid x_{1}\right)\right\|^{2}, LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x)−ut(x∣x1)∥2,
其中 t ∼ U [ 0 , 1 ] , x 1 ∼ q ( x 1 ) t\sim\mathcal{U}[0,1], x_{1}\sim q\left(x_{1}\right) t∼U[0,1],x1∼q(x1),现在 x ∼ p t ( x ∣ x 1 ) x\sim p_{t}\left(x\mid x_{1}\right) x∼pt(x∣x1)。与FM目标不同,只要可以从 p t ( x ∣ x 1 ) p_{t}\left(x\mid x_{1}\right) pt(x∣x1)有效采样并计算 u t ( x ∣ x 1 ) u_{t}\left(x\mid x_{1}\right) ut(x∣x1),CFM目标就可以轻松获得无偏估计量。因此,优化CFM目标在期望上等价于优化FM目标,从而允许我们训练CNF生成边缘概率路径 p t p_{t} pt,而无需访问边缘概率路径或边缘向量场。
4. 条件概率路径和向量场
条件流匹配目标可以与任何选择的条件概率路径和条件向量场一起使用。本文讨论了一般高斯条件概率路径族的构建,即考虑形式为:
p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) , p_{t}\left(x\mid x_{1}\right)=\mathcal{N}\left(x\mid\mu_{t}\left(x_{1}\right),\sigma_{t}\left(x_{1}\right)^{2} I\right), pt(x∣x1)=N(x∣μt(x1),σt(x1)2I),
其中 μ : [ 0 , 1 ] × R d → R d \mu:[0,1]\times R^{d}\rightarrow R^{d} μ:[0,1]×Rd→Rd是高斯分布的时间相关均值,而 σ : [ 0 , 1 ] × R → R > 0 \sigma:[0,1]\times R\rightarrow R_{>0} σ:[0,1]×R→R>0描述时间相关的标量标准差。我们设置 μ 0 ( x 1 ) = 0 \mu_{0}\left(x_{1}\right)=0 μ0(x1)=0和 σ 0 ( x 1 ) = 1 \sigma_{0}\left(x_{1}\right)=1 σ0(x1)=1,使得所有条件概率路径在 t = 0 t=0 t=0时收敛到相同的标准高斯噪声分布 p ( x ) = N ( x ∣ 0 , I ) p(x)=\mathcal{N}(x|0, I) p(x)=N(x∣0,I)。然后设置 μ 1 ( x 1 ) = x 1 \mu_{1}\left(x_{1}\right)=x_{1} μ1(x1)=x1和 σ 1 ( x 1 ) = σ min \sigma_{1}\left(x_{1}\right)=\sigma_{\min} σ1(x1)=σmin,使得 p 1 ( x ∣ x 1 ) p_{1}\left(x\mid x_{1}\right) p1(x∣x1)是在 x 1 x_{1} x1附近集中的高斯分布。
存在无数生成特定概率路径的向量场,但大多数向量场由于包含保持基础分布不变的成分而导致不必要的额外计算。我们决定使用对应于高斯分布的正则变换的最简单向量场。具体来说,考虑条件流(以 x 1 x_{1} x1为条件):
ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) . \psi_t(x)=\sigma_t\left(x_1\right) x+\mu_t\left(x_1\right). ψt(x)=σt(x1)x+μt(x1).
当 x x x服从标准高斯分布时, ψ t ( x ) \psi_t(x) ψt(x)是将噪声分布 p 0 ( x ∣ x 1 ) = p ( x ) p_{0}\left(x\mid x_{1}\right)=p(x) p0(x∣x1)=p(x)映射到具有均值 μ t ( x 1 ) \mu_{t}\left(x_{1}\right) μt(x1)和标准差 σ t ( x 1 ) \sigma_{t}\left(x_{1}\right) σt(x1)的高斯分布的仿射变换。即,根据方程4, ψ t \psi_{t} ψt将噪声分布推前到 p t ( x ∣ x 1 ) p_{t}\left(x\mid x_{1}\right) pt(x∣x1):
[ ψ t ] ∗ p ( x ) = p t ( x ∣ x 1 ) . \left[\psi_{t}\right]*p(x)=p_{t}\left(x\mid x_{1}\right). [ψt]∗p(x)=pt(x∣x1).
此流提供了一个生成条件概率路径的向量场:
d d t ψ t ( x ) = u t ( ψ t ( x ) ∣ x 1 ) \frac{d}{d t}\psi_{t}(x)=u_{t}\left(\psi_{t}(x)\mid x_{1}\right) dtdψt(x)=ut(ψt(x)∣x1)
重新参数化 p t ( x ∣ x 1 ) p_{t}\left(x\mid x_{1}\right) pt(x∣x1)仅涉及 x 0 x_{0} x0并将方程13代入CFM损失,我们得到:
L CFM ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∥ v t ( ψ t ( x 0 ) ) − d d t ψ t ( x 0 ) ∥ 2 . \mathcal{L}_{\text{{CFM}}}(\theta)=E_{t,q(x_{1}),p(x_{0})}\left\|v_{t}\left(\psi_{t}\left(x_{0}\right)\right)-\frac{d}{dt}\psi_{t}\left(x_{0}\right)\right\|^{2}. LCFM(θ)=Et,q(x1),p(x0) vt(ψt(x0))−dtdψt(x0) 2.
由于 ψ t \psi_{t} ψt是一个简单的可逆仿射映射,我们可以用封闭形式求解 u t u_{t} ut。设 f ′ f^{\prime} f′表示对时间的导数,即 f ′ = d d t f f^{\prime}=\frac{d}{d t} f f′=dtdf。
定理3. 设 p t ( x ∣ x 1 ) p_{t}\left(x\mid x_{1}\right) pt(x∣x1)为如方程10所述的高斯概率路径, ψ t \psi_{t} ψt为其对应的流映射。则定义 ψ t \psi_{t} ψt的唯一向量场具有形式:
u t ( x ∣ x 1 ) = σ t ′ ( x 1 ) σ t ( x 1 ) ( x − μ t ( x 1 ) ) + μ t ′ ( x 1 ) . u_{t}\left(x\mid x_{1}\right)=\frac{\sigma_{t}^{\prime}\left(x_{1}\right)}{\sigma_{t}\left(x_{1}\right)}\left(x-\mu_{t}\left(x_{1}\right)\right)+\mu_{t}^{\prime}\left(x_{1}\right). ut(x∣x1)=σt(x1)σt′(x1)(x−μt(x1))+μt′(x1).
因此, u t ( x ∣ x 1 ) u_{t}\left(x\mid x_{1}\right) ut(x∣x1)生成高斯路径 p t ( x ∣ x 1 ) p_{t}\left(x\mid x_{1}\right) pt(x∣x1)。
4.1 高斯条件概率路径的特殊实例
我们的公式对于任意满足所需边界条件的函数 μ t ( x 1 ) \mu_{t}\left(x_{1}\right) μt(x1)和 σ t ( x 1 ) \sigma_{t}\left(x_{1}\right) σt(x1)是完全通用的。首先讨论恢复先前使用的扩散过程的概率路径的特殊情况。由于我们直接处理概率路径,因此可以完全放弃对扩散过程的推理。
示例I:扩散条件VF。扩散模型从数据点开始,逐渐添加噪声直到近似于纯噪声。这些可以表示为随机过程,其在任意时间 t t t具有严格的闭合形式表示要求,导致具有特定均值 μ t ( x 1 ) \mu_{t}\left(x_{1}\right) μt(x1)和标准差 σ t ( x 1 ) \sigma_{t}\left(x_{1}\right) σt(x1)的高斯条件概率路径 p t ( x ∣ x 1 ) p_{t}\left(x\mid x_{1}\right) pt(x∣x1)。例如,反向(噪声 → \rightarrow →数据)方差爆炸(VE)路径的形式为:
p t ( x ) = N ( x ∣ x 1 , σ 1 − t 2 I ) , p_{t}(x)=\mathcal{N}\left(x\mid x_{1},\sigma_{1-t}^{2} I\right), pt(x)=N(x∣x1,σ1−t2I),
其中 σ t \sigma_{t} σt是递增函数, σ 0 = 0 \sigma_{0}=0 σ0=0,且 σ 1 ≫ 1 \sigma_{1}\gg 1 σ1≫1。将这些选择代入定理3的方程15,我们得到:
u t ( x ∣ x 1 ) = − σ 1 − t ′ σ 1 − t ( x − x 1 ) . u_{t}\left(x\mid x_{1}\right)=-\frac{\sigma_{1-t}^{\prime}}{\sigma_{1-t}}\left(x-x_{1}\right). ut(x∣x1)=−σ1−tσ1−t′(x−x1).
反向(噪声 → \rightarrow →数据)方差保持(VP)扩散路径的形式为:
p t ( x ∣ x 1 ) = N ( x ∣ α 1 − t x 1 , ( 1 − α 1 − t 2 ) I ) , where α t = e − 1 2 T ( t ) , T ( t ) = ∫ 0 t β ( s ) d s , p_{t}\left(x\mid x_{1}\right)=\mathcal{N}\left(x\mid\alpha_{1-t} x_{1},\left(1-\alpha_{1-t}^{2}\right) I\right),\text{ where}\alpha_{t}=e^{-\frac{1}{2} T(t)}, T(t)=\int_{0}^{t}\beta(s) d s, pt(x∣x1)=N(x∣α1−tx1,(1−α1−t2)I), whereαt=e−21T(t),T(t)=∫0tβ(s)ds,
以及 β \beta β是噪声尺度函数。将这些选择代入定理3的方程15,我们得到:
u t ( x ∣ x 1 ) = α 1 − t ′ 1 − α 1 − t 2 ( α 1 − t x − x 1 ) = − T ′ ( 1 − t ) 2 [ e − T ( 1 − t ) x − e − 1 2 T ( 1 − t ) x 1 1 − e − T ( 1 − t ) ] . u_{t}\left(x\mid x_{1}\right)=\frac{\alpha_{1-t}^{\prime}}{1-\alpha_{1-t}^{2}}\left(\alpha_{1-t} x-x_{1}\right)=-\frac{T^{\prime}(1-t)}{2}\left[\frac{e^{-T(1-t)} x-e^{-\frac{1}{2} T(1-t)} x_{1}}{1-e^{-T(1-t)}}\right]. ut(x∣x1)=1−α1−t2α1−t′(α1−tx−x1)=−2T′(1−t)[1−e−T(1−t)e−T(1−t)x−e−21T(1−t)x1].
我们的条件VF u t ( x ∣ x 1 ) u_{t}\left(x\mid x_{1}\right) ut(x∣x1)的构建实际上与确定性概率流中先前使用的向量场一致。
示例II:最优传输条件VF。一种更自然的条件概率路径选择是将均值和标准差简单地线性变化,即:
μ t ( x ) = t x 1 , and σ t ( x ) = 1 − ( 1 − σ min ) t . \mu_{t}(x)=t x_{1},\text{ and}\sigma_{t}(x)=1-\left(1-\sigma_{\min}\right) t. μt(x)=tx1, andσt(x)=1−(1−σmin)t.
根据定理3,此路径由VF生成:
u t ( x ∣ x 1 ) = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t , u_{t}\left(x\mid x_{1}\right)=\frac{x_{1}-\left(1-\sigma_{\min}\right) x}{1-\left(1-\sigma_{\min}\right) t}, ut(x∣x1)=1−(1−σmin)tx1−(1−σmin)x,
这与扩散条件VF(方程19)相比,对于所有 t ∈ [ 0 , 1 ] t\in[0,1] t∈[0,1]都有定义。对应于 u t ( x ∣ x 1 ) u_{t}\left(x\mid x_{1}\right) ut(x∣x1)的条件流为:
ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 , \psi_{t}(x)=\left(1-\left(1-\sigma_{\min}\right) t\right) x+t x_{1}, ψt(x)=(1−(1−σmin)t)x+tx1,
在这种情况下,CFM损失(见方程9、14)的形式为:
L CFM ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∥ v t ( ψ t ( x 0 ) ) − ( x 1 − ( 1 − σ min ) x 0 ) ∥ 2 . \mathcal{L}_{\text{CFM}}(\theta)=E_{t, q\left(x_{1}\right), p\left(x_{0}\right)}\left\|v_{t}\left(\psi_{t}\left(x_{0}\right)\right)-\left(x_{1}-\left(1-\sigma_{\min}\right) x_{0}\right)\right\|^{2}. LCFM(θ)=Et,q(x1),p(x0)∥vt(ψt(x0))−(x1−(1−σmin)x0)∥2.
允许均值和标准差线性变化不仅导致简单直观的路径,而且在某种意义上是最佳的。条件流 ψ t ( x ) \psi_{t}(x) ψt(x)实际上是两个高斯分布之间的最优传输(OT)位移映射。OT插值器定义为:
p t = [ ( 1 − t ) i d + t ψ ] ⋆ p 0 p_t=[(1-t) id+t\psi]_{\star} p_0 pt=[(1−t)id+tψ]⋆p0
其中 ψ : R d → R d \psi: R^{d}\rightarrow R^{d} ψ:Rd→Rd是将 p 0 p_{0} p0推前到 p 1 p_{1} p1的OT映射,id表示恒等映射,即 i d ( x ) = x id(x)=x id(x)=x, ( 1 − t ) i d + t ψ (1-t) id+t\psi (1−t)id+tψ称为OT位移映射。示例1.7显示,在我们的情况下,两个高斯分布中的第一个是标准高斯分布,OT位移映射采用方程22的形式。
直观地说,OT位移映射下的粒子总是沿直线轨迹以恒定速度移动。图3描绘了扩散和OT条件VF的采样路径。有趣的是,我们发现扩散路径的采样轨迹可能在最后“过冲”,导致不必要的回溯,而OT路径保证保持直线。
图2比较了扩散条件得分函数(典型扩散方法中的回归目标)与OT条件VF。起始和结束的高斯分布在这两种情况下是相同的。一个有趣的观察是,OT VF在时间上具有恒定的方向,这可能使回归任务更简单。此属性也可以直接从方程21验证,因为VF可以写成 u t ( x ∣ x 1 ) = g ( t ) h ( x ∣ x 1 ) u_{t}\left(x\mid x_{1}\right)=g(t) h\left(x\mid x_{1}\right) ut(x∣x1)=g(t)h(x∣x1)的形式。
5. 实验
我们在CIFAR-10和ImageNet数据集上探索了使用流匹配的经验优势。我们还消融了流匹配中扩散路径的选择,特别是标准方差保持扩散路径和最优传输路径之间的选择。讨论了通过直接参数化生成向量场和使用流匹配目标如何改进样本生成。最后,展示了流匹配也可用于条件生成设置。
表1总结了我们在不同方法上训练的相同模型的结果,报告了负对数似然(NLL)、样本质量(FID)和评估时间(NFE)。所有模型使用相同的架构、超参数值和训练迭代次数进行训练,基线允许更多的迭代次数以实现更好的收敛。
6. 结论
本文介绍了流匹配,这是一种用于训练连续归一化流模型的新无模拟框架,依赖于条件构建以轻松扩展到非常高维度。此外,FM框架提供了对扩散模型的另一种观点,并建议放弃随机/扩散构建,转而直接指定概率路径,从而允许构造允许更快采样和/或改进生成的路径。实验表明,使用流匹配框架进行训练和采样非常容易,并且在未来,预计FM将为允许多种概率路径打开大门。