Skip to content
Go back

MIT 6.S184 学习笔记:Diffusion 与 Flow Matching

Edit page
94 min read

MIT 6.S184 学习笔记:Diffusion 与 Flow Matching

0. A Reminder on Probability Theory

后面所有 flow matching 和 diffusion 的公式,本质上都在操作概率分布。因此先把 Appendix A 里的概率论记号放在最前面。这里不是为了完整复习概率论,而是为了固定这门课会反复使用的几个对象。

0.1 Random vectors and PDFs

课程里主要处理连续向量:

x=(x1,,xd)Rd.x=(x^1,\dots,x^d)\in\mathbb R^d.

一个随机向量记作

XRd.X\in\mathbb R^d.

如果 XX 有连续概率密度函数,也就是 PDF:

pX:RdR0,p_X:\mathbb R^d\to\mathbb R_{\ge 0},

那么事件 XAX\in A 的概率为

P(XA)=ApX(x)dx.\mathbb P(X\in A) = \int_A p_X(x)\,dx.

概率密度必须归一化:

pX(x)dx=1.\int p_X(x)\,dx=1.

如果积分区域没有特别写,默认是在整个空间 Rd\mathbb R^d 上积分。

随机变量 XtX_t 的 density 经常简写成

pt.p_t.

也就是说,当我们写

XtptX_t\sim p_t

时,意思是 XtX_t 的分布密度是 ptp_t

0.2 Gaussian distribution

生成模型里最常见的简单分布是 isotropic Gaussian:

N(x;μ,σ2I)=(2πσ2)d2exp(xμ22σ2).\mathcal N(x;\mu,\sigma^2I) = (2\pi\sigma^2)^{-\frac d2} \exp\left( - \frac{\|x-\mu\|^2}{2\sigma^2} \right).

其中:

后面 Gaussian probability path 会反复使用这个形式:

pt(xz)=N(x;αtz,βt2Id).p_t(x\mid z) = \mathcal N(x;\alpha_tz,\beta_t^2I_d).

它的含义就是:在给定 clean data zz 的情况下,xx 是一个以 αtz\alpha_tz 为均值、噪声尺度为 βt\beta_t 的高斯随机变量。

0.3 Expectation and LOTUS

随机向量的期望是

E[X]=xpX(x)dx.\mathbb E[X] = \int x\,p_X(x)\,dx.

它也可以被理解成在 least-squares sense 下最接近随机变量 XX 的常数向量:

E[X]=argminzRdxz2pX(x)dx.\mathbb E[X] = \arg\min_{z\in\mathbb R^d} \int \|x-z\|^2p_X(x)\,dx.

如果要计算随机变量函数的期望,可以直接用 density 积分:

E[f(X)]=f(x)pX(x)dx.\mathbb E[f(X)] = \int f(x)p_X(x)\,dx.

这叫 law of the unconscious statistician,简称 LOTUS。后面各种 loss 都是这种形式,例如

Et,z,ϵ[utθ(αtz+βtϵ)(α˙tz+β˙tϵ)2].\mathbb E_{t,z,\epsilon} \left[ \left\| u_t^\theta(\alpha_tz+\beta_t\epsilon) - (\dot\alpha_tz+\dot\beta_t\epsilon) \right\|^2 \right].

它的意思就是:按指定方式采样 t,z,ϵt,z,\epsilon,再对括号里的函数取平均。

0.4 Joint density and marginals

如果有两个随机变量 X,YX,Y,它们的 joint PDF 记作

pX,Y(x,y).p_{X,Y}(x,y).

joint density 描述的是 (X,Y)(X,Y) 同时取某组值的密度。只关心其中一个变量时,需要把另一个变量积分掉:

pX(x)=pX,Y(x,y)dy,p_X(x) = \int p_{X,Y}(x,y)\,dy, pY(y)=pX,Y(x,y)dx.p_Y(y) = \int p_{X,Y}(x,y)\,dx.

这两个分布叫 marginals。

Joint PDF and marginals

图 21 中,橙色阴影是 joint PDF:

pX,Y(x,y).p_{X,Y}(x,y).

上方黑线是把 yy 积掉后得到的 pX(x)p_X(x),右侧黑线是把 xx 积掉后得到的 pY(y)p_Y(y)

这和 flow matching 里的 marginal path 是同一个思想。我们有联合采样:

zpdata,xpt(z),z\sim p_{\text{data}}, \qquad x\sim p_t(\cdot\mid z),

它对应联合密度

pt(xz)pdata(z).p_t(x\mid z)p_{\text{data}}(z).

zz 积掉,就得到 xx 的 marginal:

pt(x)=pt(xz)pdata(z)dz.p_t(x) = \int p_t(x\mid z)p_{\text{data}}(z)\,dz.

0.5 Conditional density and Bayes’ rule

conditional density 定义为

pXY(xy)=pX,Y(x,y)pY(y),p_{X\mid Y}(x\mid y) = \frac{p_{X,Y}(x,y)}{p_Y(y)},

其中要求 pY(y)>0p_Y(y)>0

它表示:在已经知道 Y=yY=y 的情况下,XX 的分布。

Bayes’ rule 把反方向的条件分布写出来:

pYX(yx)=pXY(xy)pY(y)pX(x).p_{Y\mid X}(y\mid x) = \frac{p_{X\mid Y}(x\mid y)p_Y(y)}{p_X(x)}.

这个公式在 flow matching 和 score matching 中非常重要。比如给定 noisy sample xx 后,背后的 clean data zz 的 posterior 是

pt(zx)=pt(xz)pdata(z)pt(x).p_t(z\mid x) = \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)}.

这个权重反复出现在:

uttarget(x)=uttarget(xz)pt(xz)pdata(z)pt(x)dz,u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz,

以及

logpt(x)=logpt(xz)pt(xz)pdata(z)pt(x)dz.\nabla\log p_t(x) = \int \nabla\log p_t(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

也就是说,marginal velocity 和 marginal score 都是 conditional 对象按 posterior 权重做平均。

0.6 Conditional expectation

conditional expectation 可以先看成一个函数:

E[XY=y]=xpXY(xy)dx.\mathbb E[X\mid Y=y] = \int x\,p_{X\mid Y}(x\mid y)\,dx.

它表示:已知 Y=yY=y 后,对 XX 的平均估计。

如果把 yy 换成随机变量 YY 本身,就得到另一个随机变量:

E[XY].\mathbb E[X\mid Y].

这两个记号容易混:

conditional expectation 还有一个 least-squares 解释:

E[XY]\mathbb E[X\mid Y]

是所有只依赖 YY 的函数里,对 XX 做均方误差预测的最优函数。这一点正是 Theorem 12 背后的统计直觉:网络只看到 noisy sample xx,而 target 还依赖隐藏的 zz,MSE 下最优预测就是对 zxz\mid x 的条件平均。

0.7 Tower property

tower property 是

E[E[XY]]=E[X].\mathbb E[\mathbb E[X\mid Y]] = \mathbb E[X].

意思是:先在给定 YY 的情况下对 XX 取平均,再对 YY 的随机性取平均,结果等于直接对 XX 取平均。

用 density 写就是

E[E[XY]]=(xpXY(xy)dx)pY(y)dy=xpX(x)dx.\mathbb E[\mathbb E[X\mid Y]] = \int \left( \int x\,p_{X\mid Y}(x\mid y)\,dx \right) p_Y(y)\,dy = \int x\,p_X(x)\,dx.

后面只要看到“先条件平均,再整体平均”,本质上都在使用这个思想。

最后,对任意函数 f(X,Y)f(X,Y),有

E[f(X,Y)Y=y]=f(x,y)pXY(xy)dx.\mathbb E[f(X,Y)\mid Y=y] = \int f(x,y)p_{X\mid Y}(x\mid y)\,dx.

这会帮助理解很多带条件的 loss 和 posterior average。

这一节最重要的连接是:

这门课里的 marginalization trick、score posterior average、denoiser、Theorem 12,本质上都在反复使用 joint density、marginal density、Bayes rule 和 conditional expectation。


1. Generative Modeling as Sampling

生成模型首先要解决的并不是“如何画出一张图片”,而是更抽象的问题:如何从一个分布中采样。图像、视频、分子结构看起来是不同的数据类型,但都可以统一看成某个高维空间中的向量:

zRd.z \in \mathbb R^d.

例如,一张 RGB 图片可以看成 H×W×3H \times W \times 3 个数,一个视频可以看成一串图片,一个分子结构也可以用原子坐标组成的向量表示。这样一来,“生成对象”就变成了“生成一个向量”。

接下来我们需要描述什么叫“生成得好”。以生成一张狗的图片为例,并不存在唯一正确的狗图。合理的狗图有很多种,因此我们用一个概率分布来描述这些可能对象的集合。这个分布记作

pdata.p_{\text{data}}.

于是生成任务就被形式化为:

zpdata.z \sim p_{\text{data}}.

这句话是整个课程的起点。我们并不知道 pdatap_{\text{data}} 的显式表达式;训练时真正拥有的是有限个样本:

z1,,zNpdata.z_1,\dots,z_N \sim p_{\text{data}}.

也就是说,数据集只是这个未知分布的有限观察。生成模型的目标,是利用这些样本构造一个算法,使它训练后能够产生新的、近似来自 pdatap_{\text{data}} 的样本。

如果加入条件,例如文本 prompt yy,问题就变成条件采样:

zpdata(y).z \sim p_{\text{data}}(\cdot \mid y).

这就是 guided generation 或 conditional generation。前面先讨论 unconditional case,因为条件生成的很多方法是在无条件生成的基础上扩展出来的。

Summary 2

Summary 2 的中文整理

本节的结论可以概括为四点。第一,本课程主要研究可以表示成向量的对象,例如图像、视频和分子结构:

zRd.z \in \mathbb R^d.

第二,生成就是从数据分布中采样:

zpdata.z \sim p_{\text{data}}.

第三,如果有条件变量 yy,则目标变成:

zpdata(y).z \sim p_{\text{data}}(\cdot \mid y).

第四,我们的目标是构造一个 generative model,使它训练完成后能够返回来自 pdatap_{\text{data}} 的样本,或者至少返回近似来自这个分布的样本。

到这里,生成模型的问题已经从直观任务变成了一个数学任务:给定样本,学习如何从未知数据分布中采样。


2. Flow and Diffusion Models

2.1 Flow Models: 用 ODE 搬运分布

采样算法可以通过动力系统来构造:从一个简单分布开始,让样本沿着某种运动规则演化,最后希望它们落在数据分布上。

我们先从确定性的动力系统开始,也就是 ODE。

一个轨迹是时间到空间位置的函数:

X:[0,1]Rd,tXt.X : [0,1] \to \mathbb R^d, \qquad t \mapsto X_t.

这里 XtX_t 表示粒子在时间 tt 的位置。要规定这个粒子如何运动,我们引入向量场:

u:Rd×[0,1]Rd,(x,t)ut(x).u : \mathbb R^d \times [0,1] \to \mathbb R^d, \qquad (x,t) \mapsto u_t(x).

向量场的含义是:在时间 tt、位置 xx,粒子的瞬时速度应该是 ut(x)u_t(x)。于是轨迹满足的 ODE 写成

ddtXt=ut(Xt),X0=x0.\frac{d}{dt}X_t = u_t(X_t), \qquad X_0 = x_0.

这条公式的意思很直接:轨迹在每个时刻都沿着向量场给出的方向前进。

与 ODE 对应的 flow map 记作

ψt(x0).\psi_t(x_0).

它回答的问题是:如果粒子从 x0x_0 出发,按照这个 ODE 走到时间 tt,它会在哪里?因此

Xt=ψt(X0).X_t = \psi_t(X_0).

所以,vector field、ODE 和 flow 是同一个动力系统的三种描述方式:向量场规定局部速度,ODE 规定轨迹如何跟随这个速度,flow 则描述初始点被整体搬运到哪里。

Flow vector field

图中蓝色箭头是向量场,红色网格表示空间被 flow 扭曲后的样子。这个图很重要,因为它提醒我们:flow model 不是在单独移动一个点,而是在整体搬运空间里的概率质量。

一个简单例子是线性向量场:

ut(x)=θx,θ>0.u_t(x) = -\theta x, \qquad \theta > 0.

对应的 ODE 为

ddtXt=θXt,\frac{d}{dt}X_t = -\theta X_t,

它的 flow 是

ψt(x0)=exp(θt)x0.\psi_t(x_0) = \exp(-\theta t)x_0.

这个例子可以这样理解:点离原点越远,被拉回原点的速度越大;随着时间增加,所有点都会指数式靠近 0。后面 OU process 正是在这个稳定 drift 上加入 Brownian noise。

一般情况下,我们很难显式写出 ψt\psi_t,因此需要数值模拟。最简单的是 Euler method:

Xt+h=Xt+hut(Xt).X_{t+h} = X_t + h u_t(X_t).

也就是每次沿当前速度走一小步。Heun’s method 则先用 Euler 猜一步,再用当前点和预测点的速度平均来修正:

Xt+h=Xt+hut(Xt),X'_{t+h} = X_t + h u_t(X_t), Xt+h=Xt+h2(ut(Xt)+ut+h(Xt+h)).X_{t+h} = X_t + \frac h2\left(u_t(X_t)+u_{t+h}(X'_{t+h})\right).

现在可以定义 flow model。ODE 本身是确定性的,如果初始点固定,终点也固定。但生成模型需要随机性,所以随机性放在初始条件中:

X0pinit.X_0 \sim p_{\text{init}}.

通常取

pinit=N(0,Id).p_{\text{init}} = \mathcal N(0,I_d).

然后用神经网络参数化向量场:

ddtXt=utθ(Xt).\frac{d}{dt}X_t = u_t^\theta(X_t).

我们的目标是让终点分布满足

X1pdata.X_1 \sim p_{\text{data}}.

这里要特别注意:神经网络学习的是 utθu_t^\theta,也就是速度场;flow ψtθ\psi_t^\theta 不是直接给出的,而是通过模拟 ODE 得到的。


2.2 Diffusion Models: 从 ODE 到 SDE

Flow model 使用确定性 ODE。Diffusion model 则把确定性轨迹推广成随机轨迹,也就是 SDE。

一个随机过程写作

(Xt)0t1.(X_t)_{0\le t\le 1}.

这意味着每个 XtX_t 都是随机变量,而且同一个过程模拟两次,可能得到两条不同的轨迹。SDE 的随机性来自 Brownian motion,记作 WtW_t。可以把 Brownian motion 理解成连续时间随机游走。

Brownian motion

Brownian motion 有两个性质非常关键。第一是 normal increments:

WtWsN(0,(ts)Id),0s<t.W_t - W_s \sim \mathcal N(0,(t-s)I_d), \qquad 0 \le s < t.

这说明它的增量均值为 0,方差随时间长度线性增长。因此在长度为 hh 的小时间步内,随机扰动的尺度是 h\sqrt h,这就是后面 Euler-Maruyama 公式里 h\sqrt h 的来源。

第二是 independent increments:不重叠时间区间上的增量相互独立。直观地说,过去的随机波动不会预示未来的随机波动。

因此 Brownian motion 可以近似模拟为

Wt+h=Wt+hϵt,ϵtN(0,Id).W_{t+h} = W_t + \sqrt h\,\epsilon_t, \qquad \epsilon_t \sim \mathcal N(0,I_d).

为了把随机扰动加入动力系统,我们先把 ODE 改写成小步更新形式。这样做的原因是 Brownian motion 的轨迹虽然连续,但非常粗糙,不能像普通光滑函数那样直接求导。

ODE derivative to update

ODE

ddtXt=ut(Xt)\frac{d}{dt}X_t = u_t(X_t)

可以等价地理解为小步更新:

Xt+h=Xt+hut(Xt)+hRt(h),X_{t+h}=X_t+h u_t(X_t)+hR_t(h),

其中当 h0h\to 0 时,Rt(h)R_t(h) 可以忽略。SDE 在这个小步更新中加入 Brownian motion 的随机增量:

SDE update notation

Xt+h=Xt+hut(Xt)+σt(Wt+hWt)+hRt(h).X_{t+h} = X_t + h u_t(X_t) + \sigma_t(W_{t+h}-W_t) + hR_t(h).

符号上通常写成

dXt=ut(Xt)dt+σtdWt.dX_t = u_t(X_t)\,dt + \sigma_t\,dW_t.

这里 ut(Xt)dtu_t(X_t)\,dt 是 deterministic drift,σtdWt\sigma_t\,dW_t 是 stochastic diffusion,σt\sigma_t 是 diffusion coefficient。

Ornstein-Uhlenbeck process 是理解这类 SDE 的一个标准例子:

OU process

Example 6

OU process 的 SDE 是

dXt=θXtdt+σdWt.dX_t = -\theta X_t\,dt + \sigma\,dW_t.

其中 θXtdt-\theta X_t\,dt 把系统拉回 0,σdWt\sigma\,dW_t 持续注入随机性。若 σ=0\sigma=0,它就退化为前面学过的线性 ODE flow;若 σ>0\sigma>0,路径会随机抖动,但整体仍被 drift 拉回中心,并在长时间后趋向稳定高斯分布:

N(0,σ22θ).\mathcal N\left(0,\frac{\sigma^2}{2\theta}\right).

所以 OU process 可以看成两个机制的平衡:drift 提供稳定性,noise 提供随机性。

为了模拟 SDE,我们使用 Euler-Maruyama method。由于

Wt+hWtN(0,hId),W_{t+h}-W_t \sim \mathcal N(0,hI_d),

可以写成

Wt+hWt=hϵt,ϵtN(0,Id).W_{t+h}-W_t = \sqrt h\,\epsilon_t, \qquad \epsilon_t \sim \mathcal N(0,I_d).

代入 SDE 的小步形式,得到

Xt+h=Xt+hut(Xt)+σthϵt.X_{t+h} = X_t + h u_t(X_t) + \sigma_t \sqrt h\,\epsilon_t.

这就是 Euler-Maruyama method。

Algorithm 2

现在可以定义 diffusion model。它由一个神经网络 drift 和一个固定 diffusion coefficient 组成:

X0pinit,X_0 \sim p_{\text{init}}, dXt=utθ(Xt)dt+σtdWt.dX_t = u_t^\theta(X_t)\,dt + \sigma_t\,dW_t.

其中 utθu_t^\theta 是要学习的向量场,σt\sigma_t 是预先指定的噪声强度。采样时从 pinitp_{\text{init}} 出发,模拟这个 SDE 到 t=1t=1,希望得到

X1pdata.X_1 \sim p_{\text{data}}.

Summary 7

Summary 7 的中文整理

在这里,一个 diffusion model 由两部分构成。神经网络

uθ:Rd×[0,1]Rd,(x,t)utθ(x),u^\theta : \mathbb R^d \times [0,1] \to \mathbb R^d, \qquad (x,t)\mapsto u_t^\theta(x),

参数化 drift / vector field;固定函数

σt:[0,1][0,)\sigma_t : [0,1]\to [0,\infty)

控制随机噪声强度。生成时先采样

X0pinit,X_0 \sim p_{\text{init}},

再模拟

dXt=utθ(Xt)dt+σtdWt,dX_t = u_t^\theta(X_t)\,dt+\sigma_t\,dW_t,

目标是让

X1pdata.X_1 \sim p_{\text{data}}.

σt=0\sigma_t=0 时,随机项消失,diffusion model 就变成 flow model。因此 flow model 是 diffusion model 的一个特殊情况。


3. Flow Matching:从分布路径到可训练的速度场

到这里为止,我们已经知道如何用 ODE 或 SDE 定义生成模型。两者的共同目标都是

X0pinit,X1pdata.X_0 \sim p_{\text{init}}, \qquad X_1 \sim p_{\text{data}}.

区别在于 flow model 使用确定性 ODE:

dXt=utθ(Xt)dt,dX_t = u_t^\theta(X_t)\,dt,

而 diffusion model 使用带随机噪声的 SDE:

dXt=utθ(Xt)dt+σtdWt.dX_t = u_t^\theta(X_t)\,dt+\sigma_t\,dW_t.

但这时还有一个核心问题没有回答:神经网络 utθu_t^\theta 应该如何训练?

Flow Matching 先把问题限制在 flow model 中,也就是先不考虑随机噪声:

X0pinit,dXt=utθ(Xt)dt.X_0 \sim p_{\text{init}}, \qquad dX_t = u_t^\theta(X_t)\,dt.

我们希望训练后

X1pdata.X_1 \sim p_{\text{data}}.

3.1 Conditional and Marginal Probability Path

现在先回到 flow model。模型从简单初始分布采样:

X0pinit,X_0 \sim p_{\text{init}},

然后沿着神经网络给出的速度场演化:

dXt=utθ(Xt)dt.dX_t = u_t^\theta(X_t)\,dt.

最后用 t=1t=1 的位置作为生成样本。我们希望这个终点满足

X1pdata.X_1 \sim p_{\text{data}}.

训练问题因此可以写得非常直接:怎样优化参数 θ\theta,使得模拟这个 ODE 之后,终点分布就是数据分布?Flow Matching 的回答并不是直接处理终点,而是先安排整个中间过程。也就是说,我们不只关心

t=0andt=1,t=0 \quad\text{and}\quad t=1,

还要指定每个中间时间

0<t<10<t<1

时,样本应该服从什么分布。

这个随时间变化的一族分布称为 probability path。它可以被理解成“分布空间中的一条轨迹”:起点是噪声分布,终点是数据分布。

Probability path images

图中每一列都可以理解成某个时间 tt 下的样本状态。左边接近纯噪声,越往右数字结构越明显,最后得到清晰的数据样本。这里展示的是一条从噪声到数据的路径,而 Flow Matching 接下来要做的,就是学习一个速度场,让 ODE 的样本也沿着这样的路径移动。

为了更精确地定义这条路径,先固定一个数据点

zRd.z \in \mathbb R^d.

δz\delta_z 表示 Dirac delta distribution。它是最简单的一种“分布”:从 δz\delta_z 采样,永远只会得到 zz。因此可以把它理解成“所有概率质量都集中在数据点 zz 上”。

现在定义 conditional probability path:

pt(xz),p_t(x\mid z),

它是一族关于 xx 的分布,并且满足两个边界条件:

p0(z)=pinit,p_0(\cdot \mid z) = p_{\text{init}}, p1(z)=δz.p_1(\cdot \mid z) = \delta_z.

这两个条件的意思是:

所以 conditional probability path 描述的是:如何把一个简单初始分布逐渐变成某一个固定数据点。

这里的 “conditional” 很重要。它不是直接从噪声分布变成整个 pdatap_{\text{data}},而是先问:如果目标是某个具体样本 zz,那么从噪声到这个 zz 的中间分布可以怎样安排?

这个构造看起来像是在“记住训练样本”,但它只是一个中间工具。下一步会把所有 zpdataz \sim p_{\text{data}} 的 conditional paths 汇总起来,得到真正从噪声分布到数据分布的 marginal probability path。

这一页可以先记住一句话:

Flow Matching 先人为指定一条从噪声到数据的 probability path;训练速度场之前,先规定每个时间点的分布应该是什么。

Conditional path、marginal path 与 Gaussian path

固定一个数据点 zz 时,conditional probability path

pt(xz)p_t(x\mid z)

描述的是从噪声分布到单个点 zz 的路径:

p0(z)=pinit,p1(z)=δz.p_0(\cdot\mid z)=p_{\text{init}}, \qquad p_1(\cdot\mid z)=\delta_z.

但生成模型真正要学习的不是“到某一个固定点 zz 的路径”,而是“到整个数据分布 pdatap_{\text{data}} 的路径”。这一步通过把 zz 也看成随机变量完成:

zpdata,xpt(z).z \sim p_{\text{data}}, \qquad x \sim p_t(\cdot\mid z).

由这个两步采样过程得到的 xx 的分布,记作 ptp_t

zpdata,xpt(z)xpt.z \sim p_{\text{data}},\quad x \sim p_t(\cdot\mid z) \quad \Longrightarrow \quad x \sim p_t.

如果写成密度,就是

pt(x)=pt(xz)pdata(z)dz.p_t(x)=\int p_t(x\mid z)p_{\text{data}}(z)\,dz.

这里有一个细节很重要:我们可以从 ptp_t 中采样,但通常不能计算 pt(x)p_t(x) 的具体数值。原因是上面的积分需要对所有可能的数据点 zz 做积分,而真实的 pdatap_{\text{data}} 本来就是未知的。

因此:

sampling from ptp_t is easy; evaluating pt(x)p_t(x) is hard.

采样很简单,因为只需要:

  1. 从数据集中取一个样本 zz
  2. 从 conditional path pt(z)p_t(\cdot\mid z) 中采一个 xx

但计算密度值

pt(x)p_t(x)

需要完成不可行的积分。

Conditional and marginal path

图 5 正好展示了这个区别。上排是 conditional path:目标数据点 zz 固定,所以分布从高斯噪声逐渐收缩到一个点。下排是 marginal path:每次先随机选一个数据点 zpdataz\sim p_{\text{data}},再沿着对应的 conditional path 采样;所有这些 conditional paths 混合起来,就形成了从噪声分布到整个数据分布的路径。

由于每个 conditional path 都满足

p0(z)=pinit,p1(z)=δz,p_0(\cdot\mid z)=p_{\text{init}}, \qquad p_1(\cdot\mid z)=\delta_z,

因此 marginal path 自动满足

p0=pinit,p1=pdata.p_0=p_{\text{init}}, \qquad p_1=p_{\text{data}}.

这正是生成模型需要的分布级别路径。

Gaussian conditional probability path

最重要的一类 probability path 是 Gaussian path。它也是许多大规模生成模型中使用的基本形式。

先选择两个随时间变化的 scheduler:

αt, βt.\alpha_t,\ \beta_t.

它们满足边界条件:

α0=0,α1=1,\alpha_0=0,\qquad \alpha_1=1, β0=1,β1=0.\beta_0=1,\qquad \beta_1=0.

也就是说,αt\alpha_t 从 0 增加到 1,βt\beta_t 从 1 减小到 0。然后定义

pt(z)=N(αtz,βt2Id).p_t(\cdot\mid z) = \mathcal N(\alpha_t z,\beta_t^2 I_d).

这个公式的含义很清楚:在时间 tt,样本分布是一个高斯分布,它的均值是 αtz\alpha_t z,方差尺度是 βt2\beta_t^2

t=0t=0 时,

p0(z)=N(α0z,β02Id)=N(0,Id)=pinit.p_0(\cdot\mid z) = \mathcal N(\alpha_0z,\beta_0^2I_d) = \mathcal N(0,I_d) =p_{\text{init}}.

t=1t=1 时,

p1(z)=N(α1z,β12Id)=N(z,0)=δz.p_1(\cdot\mid z) = \mathcal N(\alpha_1z,\beta_1^2I_d) = \mathcal N(z,0) = \delta_z.

所以 Gaussian path 正好满足 conditional probability path 的要求。

从这个分布采样也非常简单。令

ϵN(0,Id),\epsilon \sim \mathcal N(0,I_d),

xt=αtz+βtϵx_t=\alpha_t z+\beta_t\epsilon

满足

xtpt(z).x_t \sim p_t(\cdot\mid z).

这个公式以后会反复出现。它是 diffusion 和 flow matching 训练中最基本的“加噪”形式:用 αt\alpha_t 保留数据,用 βt\beta_t 注入噪声。

一个最简单的选择是

αt=t,βt=1t.\alpha_t=t, \qquad \beta_t=1-t.

于是

xt=tz+(1t)ϵ.x_t=t z+(1-t)\epsilon.

这时 xtx_t 就是数据 zz 和噪声 ϵ\epsilon 的线性插值:t=0t=0 时是纯噪声,t=1t=1 时是纯数据。

这一页的重点可以概括为:

Conditional path 连接噪声和单个数据点;marginal path 把所有 conditional paths 按 zpdataz\sim p_{\text{data}} 混合起来,因此连接噪声分布和整个数据分布。

3.2 Conditional and Marginal Vector Fields

Gaussian path 给出了一个非常方便的采样公式:

zpdata,ϵN(0,Id),z \sim p_{\text{data}}, \qquad \epsilon \sim \mathcal N(0,I_d), xt=αtz+βtϵ.x_t=\alpha_t z+\beta_t\epsilon.

于是

xtpt.x_t \sim p_t.

这里 ptp_t 是 marginal probability path。直观上,tt 越小,βt\beta_t 越大,样本里噪声成分越多;tt 越接近 1,αt\alpha_t 越大,样本越接近真实数据。

到这里,probability path 只规定了一个愿望:

Xtpt.X_t \sim p_t.

也就是说,我们希望 ODE 的粒子在每个时间 tt 的分布等于 ptp_t。但是仅仅指定分布路径还不够,因为生成时真正要模拟的是 ODE:

ddtXt=ut(Xt).\frac{d}{dt}X_t = u_t(X_t).

因此新的问题变成:

能不能找到一个速度场 utu_t,使得沿着这个 ODE 运动的样本,在每个时间 tt 都服从我们指定的 ptp_t

Flow Matching 的核心就是构造这样的速度场。

先固定一个数据点 zz。如果存在一个 conditional vector field

uttarget(xz),u_t^{\text{target}}(x\mid z),

使得

X0pinit,ddtXt=uttarget(Xtz)X_0\sim p_{\text{init}}, \qquad \frac{d}{dt}X_t = u_t^{\text{target}}(X_t\mid z)

能够推出

Xtpt(z),0t1,X_t \sim p_t(\cdot\mid z), \qquad 0\le t\le 1,

那么这个 uttarget(xz)u_t^{\text{target}}(x\mid z) 就是 conditional vector field。它的作用是把初始噪声分布按照指定的 conditional path 推到某个固定的数据点 zz

乍看起来,这个对象似乎没有什么生成意义。因为如果 zz 固定,那么所有终点都会塌缩到

X1=z.X_1=z.

这只是“生成已知样本 zz”,不是从 pdatap_{\text{data}} 中生成新样本。但它的价值在于:conditional vector field 是构造 marginal vector field 的 building block。

Theorem 9: Marginalization Trick

如果已经有 conditional vector field

uttarget(xz),u_t^{\text{target}}(x\mid z),

那么可以定义 marginal vector field:

uttarget(x)=uttarget(xz)pt(xz)pdata(z)pt(x)dz.u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

这个公式看起来复杂,但它其实就是一个条件期望:

uttarget(x)=E[uttarget(xz)xt=x].u_t^{\text{target}}(x) = \mathbb E\left[ u_t^{\text{target}}(x\mid z) \mid x_t=x \right].

其中

pt(xz)pdata(z)pt(x)\frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)}

就是在给定当前 noisy sample xx 的情况下,数据点 zz 的 posterior 权重。换句话说:如果当前看到的是 xx,那么不同的 zz 都可能是它背后的 clean data,公式会按照这些可能性给对应的 conditional velocity 加权平均。

定理说,用这个 marginal vector field 模拟 ODE:

X0pinit,ddtXt=uttarget(Xt),X_0\sim p_{\text{init}}, \qquad \frac{d}{dt}X_t = u_t^{\text{target}}(X_t),

就会得到

Xtpt,0t1.X_t\sim p_t, \qquad 0\le t\le 1.

特别地,

X1pdata.X_1\sim p_{\text{data}}.

所以 uttarget(x)u_t^{\text{target}}(x) 才是真正意义上的生成速度场:它把噪声分布推到整个数据分布。

Vector fields path

图 6 展示了这个定理的含义。上排固定一个数据点 zz,conditional vector field 让样本沿着 conditional path 走,最后收缩到这个点。下排把所有 zz 的影响混合起来,marginal vector field 让样本分布沿着 marginal path 演化,最后形成整个数据分布。

这一页非常关键,因为它解释了 Flow Matching 的基本策略:

  1. 先构造容易写出来的 conditional vector field;
  2. 理论上存在一个 posterior average 后的 marginal vector field;
  3. 真正生成时需要的是 marginal vector field;
  4. 后面训练神经网络时,会用 conditional vector field 作为可计算 target,却学到 marginal vector field。

最后一句是之后 Theorem 12 的核心,也是 Flow Matching 为什么能训练的根本原因。

Gaussian path 的 conditional velocity

现在要把 Gaussian conditional path 的速度场算出来。我们已经定义了:

pt(z)=N(αtz,βt2Id).p_t(\cdot\mid z) = \mathcal N(\alpha_t z,\beta_t^2 I_d).

等价地,从这个 path 采样可以写成

Xt=αtz+βtX0,X0N(0,Id).X_t = \alpha_t z+\beta_t X_0, \qquad X_0\sim \mathcal N(0,I_d).

这已经暗示了一个自然的 flow:

ψttarget(xz)=αtz+βtx.\psi_t^{\text{target}}(x\mid z) = \alpha_t z+\beta_t x.

如果初始点是 xx,那么时间 tt 的位置就是 αtz+βtx\alpha_t z+\beta_t x。当 xx 本身来自标准高斯时,

Xt=ψttarget(X0z)=αtz+βtX0N(αtz,βt2Id).X_t = \psi_t^{\text{target}}(X_0\mid z) = \alpha_t z+\beta_t X_0 \sim \mathcal N(\alpha_t z,\beta_t^2I_d).

这就说明这个 flow 的样本分布正好等于我们指定的 conditional probability path。

接下来从 flow 中提取 vector field。按定义,flow 和 vector field 满足:

ddtψttarget(xz)=uttarget(ψttarget(xz)z).\frac{d}{dt}\psi_t^{\text{target}}(x\mid z) = u_t^{\text{target}}(\psi_t^{\text{target}}(x\mid z)\mid z).

左边直接对时间求导:

ddt(αtz+βtx)=α˙tz+β˙tx.\frac{d}{dt} \left(\alpha_t z+\beta_t x\right) = \dot\alpha_t z+\dot\beta_t x.

因此

uttarget(αtz+βtxz)=α˙tz+β˙tx.u_t^{\text{target}}(\alpha_t z+\beta_t x\mid z) = \dot\alpha_t z+\dot\beta_t x.

现在要把右边的初始噪声变量 xx 换成当前位置。设当前位置为

y=αtz+βtx.y=\alpha_t z+\beta_t x.

x=yαtzβt.x=\frac{y-\alpha_t z}{\beta_t}.

代回去:

uttarget(yz)=α˙tz+β˙tyαtzβt.u_t^{\text{target}}(y\mid z) = \dot\alpha_t z +\dot\beta_t \frac{y-\alpha_t z}{\beta_t}.

整理得到

uttarget(yz)=(α˙tβ˙tβtαt)z+β˙tβty.u_t^{\text{target}}(y\mid z) = \left( \dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t \right)z + \frac{\dot\beta_t}{\beta_t}y.

把当前位置重新记成 xx,就得到:

uttarget(xz)=(α˙tβ˙tβtαt)z+β˙tβtx.u_t^{\text{target}}(x\mid z) = \left( \dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t \right)z + \frac{\dot\beta_t}{\beta_t}x.

这个公式可以这样理解:速度由两部分组成。一部分依赖目标数据点 zz,负责把样本往目标方向拉;另一部分依赖当前位置 xx,负责处理当前噪声尺度随时间缩小的影响。

如果使用最简单的 CondOT path:

αt=t,βt=1t,\alpha_t=t, \qquad \beta_t=1-t,

那么

xt=tz+(1t)ϵ.x_t=tz+(1-t)\epsilon.

把这个选择代入一般公式,先得到的是“当前位置 xx 处的速度场”:

α˙t=1,β˙t=1.\dot\alpha_t=1, \qquad \dot\beta_t=-1.

因此

uttarget(xz)=(111tt)z11tx=zx1t.u_t^{\text{target}}(x\mid z) = \left( 1-\frac{-1}{1-t}t \right)z -\frac{1}{1-t}x = \frac{z-x}{1-t}.

这个式子里的 xx 是任意当前位置。若当前位置正好位于 Gaussian path 上,

x=xt=tz+(1t)ϵ,x=x_t=tz+(1-t)\epsilon,

uttarget(xtz)=z(tz+(1t)ϵ)1t=zϵ.u_t^{\text{target}}(x_t\mid z) = \frac{z-\left(tz+(1-t)\epsilon\right)}{1-t} = z-\epsilon.

这和直接对路径

xt=tz+(1t)ϵx_t=tz+(1-t)\epsilon

求导得到的结果一致:

ddtxt=zϵ.\frac{d}{dt}x_t=z-\epsilon.

所以这个特殊情况下,conditional vector field 在训练样本 xtx_t 处的 target 会变得非常简单:

uttarget(xtz)=zϵ.u_t^{\text{target}}(x_t\mid z)=z-\epsilon.

这也是后面 Algorithm 3 里 Flow Matching loss 的核心 target。

这一段的意义在于:对于 Gaussian path,我们不仅能指定中间分布,还能显式写出让 ODE 跟随这条 path 的 conditional velocity。后面训练神经网络时,就可以把这个可计算的 velocity 当作监督信号。

Continuity equation 与 marginalization trick 的证明

前面已经写出了 marginalization trick:

uttarget(x)=uttarget(xz)pt(xz)pdata(z)pt(x)dz.u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

现在需要解释为什么这个速度场真的能让 ODE 的分布沿着 marginal path ptp_t 走。这里用到的工具是 continuity equation。

考虑一个 flow model:

X0pinit=p0,ddtXt=uttarget(Xt).X_0\sim p_{\text{init}}=p_0, \qquad \frac{d}{dt}X_t=u_t^{\text{target}}(X_t).

如果这个 ODE 在每个时间 tt 的样本分布都是 ptp_t,也就是

Xtpt,X_t\sim p_t,

那么 ptp_t 和速度场之间必须满足

tpt(x)=div(ptuttarget)(x).\partial_t p_t(x) = -\operatorname{div}\left(p_tu_t^{\text{target}}\right)(x).

这就是 continuity equation。

这个公式可以从概率质量守恒来理解。左边

tpt(x)\partial_t p_t(x)

表示位置 xx 处的概率密度随时间怎么变化。右边的 divergence 描述概率质量的流出程度:

div(ptut)(x)\operatorname{div}(p_tu_t)(x)

可以理解成在位置 xx 附近,概率质量沿着速度场向外流走的净量。因此前面加负号:

div(ptut)(x)-\operatorname{div}(p_tu_t)(x)

表示净流入。某处密度增加,是因为流入大于流出;某处密度减少,是因为流出大于流入。

所以 continuity equation 说的就是:

概率密度的时间变化 = 概率质量的净流入。

这和流体力学里的质量守恒是同一个思想,只是这里流动的是概率质量。

有了这个方程,就可以证明 Theorem 9。marginal path 的密度是

pt(x)=pt(xz)pdata(z)dz.p_t(x) = \int p_t(x\mid z)p_{\text{data}}(z)\,dz.

对时间求导:

tpt(x)=tpt(xz)pdata(z)dz=tpt(xz)pdata(z)dz.\partial_t p_t(x) = \partial_t \int p_t(x\mid z)p_{\text{data}}(z)\,dz = \int \partial_t p_t(x\mid z)p_{\text{data}}(z)\,dz.

因为 conditional vector field 能实现 conditional path,所以每个 pt(xz)p_t(x\mid z) 都满足自己的 continuity equation:

tpt(xz)=div(pt(z)uttarget(z))(x).\partial_t p_t(x\mid z) = -\operatorname{div} \left( p_t(\cdot\mid z) u_t^{\text{target}}(\cdot\mid z) \right)(x).

代入上式:

tpt(x)=div(pt(z)uttarget(z))(x)pdata(z)dz.\partial_t p_t(x) = \int -\operatorname{div} \left( p_t(\cdot\mid z) u_t^{\text{target}}(\cdot\mid z) \right)(x) p_{\text{data}}(z)\,dz.

把 divergence 和积分交换,就得到

tpt(x)=div(pt(xz)uttarget(xz)pdata(z)dz).\partial_t p_t(x) = -\operatorname{div} \left( \int p_t(x\mid z) u_t^{\text{target}}(x\mid z) p_{\text{data}}(z) \,dz \right).

而 marginal vector field 的定义正是

uttarget(x)=uttarget(xz)pt(xz)pdata(z)pt(x)dz.u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

两边乘上 pt(x)p_t(x)

pt(x)uttarget(x)=pt(xz)uttarget(xz)pdata(z)dz.p_t(x)u_t^{\text{target}}(x) = \int p_t(x\mid z) u_t^{\text{target}}(x\mid z) p_{\text{data}}(z) \,dz.

因此

tpt(x)=div(ptuttarget)(x).\partial_t p_t(x) = -\operatorname{div} \left( p_tu_t^{\text{target}} \right)(x).

这说明 marginal vector field 满足 continuity equation,所以它确实会让 ODE 的样本分布沿着 marginal path ptp_t 演化。

这个证明的核心并不是复杂计算,而是一个很干净的结构:

  1. 每个 conditional path 都有自己的速度场;
  2. 每个 conditional path 都满足 continuity equation;
  3. marginal path 是 conditional paths 的混合;
  4. conditional velocities 按 posterior 权重平均后,得到的 marginal velocity 正好让混合分布也满足 continuity equation。

到这里,Flow Matching 的理论目标已经清楚了:如果能学到

uttarget(x),u_t^{\text{target}}(x),

那么从噪声出发模拟 ODE,就能得到数据分布。下一步的问题是训练:这个 marginal vector field 里面有不可计算的积分,神经网络怎样学到它?

3.3 Learning the Marginal Vector Field

现在目标已经明确:训练一个神经网络

utθ(x)u_t^\theta(x)

去逼近真正的 marginal vector field

uttarget(x).u_t^{\text{target}}(x).

如果能够做到

utθ(x)uttarget(x),u_t^\theta(x)\approx u_t^{\text{target}}(x),

那么模拟

dXt=utθ(Xt)dt,X0pinitdX_t=u_t^\theta(X_t)\,dt, \qquad X_0\sim p_{\text{init}}

就会得到近似满足

X1pdataX_1\sim p_{\text{data}}

的生成模型。

最直接的想法是用均方误差:

LFM(θ)=EtUnif, xpt[utθ(x)uttarget(x)2].L_{\text{FM}}(\theta) = \mathbb E_{t\sim \operatorname{Unif},\ x\sim p_t} \left[ \left\| u_t^\theta(x)-u_t^{\text{target}}(x) \right\|^2 \right].

这个 loss 的含义很自然:随机取一个时间 tt,再从对应的 marginal path ptp_t 里取一个点 xx,让网络输出的速度接近真正的 marginal velocity。

由于从 ptp_t 采样可以通过 conditional path 完成,

zpdata,xpt(z),z\sim p_{\text{data}}, \qquad x\sim p_t(\cdot\mid z),

所以同一个 loss 也可以写成

LFM(θ)=EtUnif, zpdata, xpt(z)[utθ(x)uttarget(x)2].L_{\text{FM}}(\theta) = \mathbb E_{t\sim \operatorname{Unif},\ z\sim p_{\text{data}},\ x\sim p_t(\cdot\mid z)} \left[ \left\| u_t^\theta(x)-u_t^{\text{target}}(x) \right\|^2 \right].

问题在于,这个 loss 仍然不可计算。困难不在采样 xx,而在 target:

uttarget(x)=uttarget(xz)pt(xz)pdata(z)pt(x)dz.u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

这个 target 是对所有可能数据点 zz 的 posterior average,里面有不可计算的积分,也包含未知的 pt(x)p_t(x)

Flow Matching 的关键转折是:不用直接回归 marginal target,而是回归 conditional target:

LCFM(θ)=EtUnif, zpdata, xpt(z)[utθ(x)uttarget(xz)2].L_{\text{CFM}}(\theta) = \mathbb E_{t\sim \operatorname{Unif},\ z\sim p_{\text{data}},\ x\sim p_t(\cdot\mid z)} \left[ \left\| u_t^\theta(x)-u_t^{\text{target}}(x\mid z) \right\|^2 \right].

这个 loss 是可计算的,因为 uttarget(xz)u_t^{\text{target}}(x\mid z) 对我们选择的 probability path 通常有解析公式。比如 Gaussian path 中,上一节已经算出:

uttarget(xz)=(α˙tβ˙tβtαt)z+β˙tβtx.u_t^{\text{target}}(x\mid z) = \left( \dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t \right)z + \frac{\dot\beta_t}{\beta_t}x.

现在看起来有一个矛盾:我们真正想学的是 marginal vector field

uttarget(x),u_t^{\text{target}}(x),

但训练时却让网络拟合 conditional vector field

uttarget(xz).u_t^{\text{target}}(x\mid z).

Theorem 12 正是 Flow Matching 的核心。它说:

LFM(θ)=LCFM(θ)+C,L_{\text{FM}}(\theta) = L_{\text{CFM}}(\theta)+C,

其中 CCθ\theta 无关。因此

θLFM(θ)=θLCFM(θ).\nabla_\theta L_{\text{FM}}(\theta) = \nabla_\theta L_{\text{CFM}}(\theta).

也就是说,优化可计算的 LCFML_{\text{CFM}},等价于优化不可计算的 LFML_{\text{FM}}。如果模型表达能力足够强,最优的网络会学到 marginal vector field:

utθ(x)=uttarget(x).u_t^\theta(x)=u_t^{\text{target}}(x).

这个结论的直觉其实就是均方误差的条件期望性质。对于固定的 noisy point xx,conditional target

uttarget(xz)u_t^{\text{target}}(x\mid z)

会随着背后的 zz 改变。均方误差下,最好的预测不是某一个具体 zz 对应的速度,而是这些可能速度的条件平均:

E[uttarget(xz)x].\mathbb E[ u_t^{\text{target}}(x\mid z) \mid x ].

而这个条件平均正是 marginal vector field:

uttarget(x).u_t^{\text{target}}(x).

Theorem 12 的证明思路

证明不需要神秘技巧,只是展开平方项。先从 marginal loss 开始:

LFM(θ)=Et,x[utθ(x)uttarget(x)2].L_{\text{FM}}(\theta) = \mathbb E_{t,x} \left[ \left\| u_t^\theta(x)-u_t^{\text{target}}(x) \right\|^2 \right].

ab2=a22ab+b2\|a-b\|^2=\|a\|^2-2a^\top b+\|b\|^2

展开:

LFM(θ)=Et,x[utθ(x)2]2Et,x[utθ(x)uttarget(x)]+C1.L_{\text{FM}}(\theta) = \mathbb E_{t,x} \left[ \|u_t^\theta(x)\|^2 \right] -2 \mathbb E_{t,x} \left[ u_t^\theta(x)^\top u_t^{\text{target}}(x) \right] + C_1.

其中

C1=Et,x[uttarget(x)2]C_1= \mathbb E_{t,x} \left[ \|u_t^{\text{target}}(x)\|^2 \right]

不依赖 θ\theta,所以对训练梯度来说只是常数。

第一项

Et,xpt[utθ(x)2]\mathbb E_{t,x\sim p_t} \left[ \|u_t^\theta(x)\|^2 \right]

可以直接改写成 conditional sampling:

Et,z,xpt(z)[utθ(x)2].\mathbb E_{t,z,x\sim p_t(\cdot\mid z)} \left[ \|u_t^\theta(x)\|^2 \right].

关键是第二个内积项:

Et,xpt[utθ(x)uttarget(x)].\mathbb E_{t,x\sim p_t} \left[ u_t^\theta(x)^\top u_t^{\text{target}}(x) \right].

把 marginal vector field 的定义代入:

uttarget(x)=uttarget(xz)pt(xz)pdata(z)pt(x)dz.u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

于是

Et,xpt[utθ(x)uttarget(x)]=Et,z,xpt(z)[utθ(x)uttarget(xz)].\mathbb E_{t,x\sim p_t} \left[ u_t^\theta(x)^\top u_t^{\text{target}}(x) \right] = \mathbb E_{t,z,x\sim p_t(\cdot\mid z)} \left[ u_t^\theta(x)^\top u_t^{\text{target}}(x\mid z) \right].

这一步是证明里最重要的一步:marginal target 出现在内积里时,可以换成 conditional target。

代回展开式后,再补上并减去

uttarget(xz)2,\left\|u_t^{\text{target}}(x\mid z)\right\|^2,

就得到

LFM(θ)=LCFM(θ)+C,L_{\text{FM}}(\theta) = L_{\text{CFM}}(\theta)+C,

其中 CC 不依赖 θ\theta

这说明 Flow Matching 的训练可以完全避免显式计算 marginal vector field。训练时只需要:

  1. 采样时间 tt
  2. 采样数据 zz
  3. pt(z)p_t(\cdot\mid z) 采样 noisy point xx
  4. 让网络预测 conditional velocity uttarget(xz)u_t^{\text{target}}(x\mid z)

整个训练过程不需要模拟 ODE。也就是说,训练是 simulation-free 的。ODE 只在训练完成之后、生成样本时才需要模拟。

这就是 Flow Matching 简洁且适合大规模训练的原因:它把生成模型的训练变成了一个普通的监督回归问题。

Theorem 12 的更细证明

先固定一个时间 tt。为了不让符号太重,暂时把 tt 省略掉,最后再把对 tt 的期望加回来。

训练时的采样过程是:

zpdata,xpt(z).z\sim p_{\text{data}}, \qquad x\sim p_t(\cdot\mid z).

这实际上定义了一个关于 (x,z)(x,z) 的联合分布:

qt(x,z)=pt(xz)pdata(z).q_t(x,z)=p_t(x\mid z)p_{\text{data}}(z).

它的 xx-marginal 是:

pt(x)=pt(xz)pdata(z)dz.p_t(x)=\int p_t(x\mid z)p_{\text{data}}(z)\,dz.

因此 posterior 是:

qt(zx)=pt(xz)pdata(z)pt(x).q_t(z\mid x) = \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)}.

现在把 conditional target 记成

vt(x,z)=uttarget(xz).v_t(x,z)=u_t^{\text{target}}(x\mid z).

marginal target 就是它在给定 xx 后的条件平均:

vˉt(x)=uttarget(x)=E[vt(x,z)x].\bar v_t(x) = u_t^{\text{target}}(x) = \mathbb E[v_t(x,z)\mid x].

也就是

vˉt(x)=vt(x,z)qt(zx)dz.\bar v_t(x) = \int v_t(x,z)q_t(z\mid x)\,dz.

神经网络输出记成

Ut(x)=utθ(x).U_t(x)=u_t^\theta(x).

于是两个 loss 可以写成:

LFM=Et,x[Ut(x)vˉt(x)2],L_{\text{FM}} = \mathbb E_{t,x} \left[ \|U_t(x)-\bar v_t(x)\|^2 \right], LCFM=Et,x,z[Ut(x)vt(x,z)2].L_{\text{CFM}} = \mathbb E_{t,x,z} \left[ \|U_t(x)-v_t(x,z)\|^2 \right].

注意这两个期望里的 xx 的 marginal 都是 pt(x)p_t(x)。区别只在 target:一个用条件平均 vˉt(x)\bar v_t(x),另一个用具体的 vt(x,z)v_t(x,z)

现在展开 LFML_{\text{FM}}

LFM=Et,x[Ut(x)22Ut(x)vˉt(x)+vˉt(x)2].L_{\text{FM}} = \mathbb E_{t,x} \left[ \|U_t(x)\|^2 -2U_t(x)^\top \bar v_t(x) +\|\bar v_t(x)\|^2 \right].

也就是

LFM=Et,x[Ut(x)2]2Et,x[Ut(x)vˉt(x)]+Et,x[vˉt(x)2].L_{\text{FM}} = \mathbb E_{t,x}[\|U_t(x)\|^2] -2\mathbb E_{t,x}[U_t(x)^\top \bar v_t(x)] +\mathbb E_{t,x}[\|\bar v_t(x)\|^2].

再展开 LCFML_{\text{CFM}}

LCFM=Et,x,z[Ut(x)22Ut(x)vt(x,z)+vt(x,z)2].L_{\text{CFM}} = \mathbb E_{t,x,z} \left[ \|U_t(x)\|^2 -2U_t(x)^\top v_t(x,z) +\|v_t(x,z)\|^2 \right].

也就是

LCFM=Et,x,z[Ut(x)2]2Et,x,z[Ut(x)vt(x,z)]+Et,x,z[vt(x,z)2].L_{\text{CFM}} = \mathbb E_{t,x,z}[\|U_t(x)\|^2] -2\mathbb E_{t,x,z}[U_t(x)^\top v_t(x,z)] +\mathbb E_{t,x,z}[\|v_t(x,z)\|^2].

现在逐项比较。

第一项相同,因为 Ut(x)U_t(x) 只依赖 xx,而两边的 xx-marginal 都是 pt(x)p_t(x)

Et,x,z[Ut(x)2]=Et,x[Ut(x)2].\mathbb E_{t,x,z}[\|U_t(x)\|^2] = \mathbb E_{t,x}[\|U_t(x)\|^2].

第二项也相同。先看 LCFML_{\text{CFM}} 里的交叉项:

Et,x,z[Ut(x)vt(x,z)].\mathbb E_{t,x,z}[U_t(x)^\top v_t(x,z)].

对给定的 t,xt,xUt(x)U_t(x) 已经固定,不再依赖 zz,所以可以先对 zxz\mid x 取条件期望:

Ezx[Ut(x)vt(x,z)]=Ut(x)Ezx[vt(x,z)].\mathbb E_{z\mid x} \left[ U_t(x)^\top v_t(x,z) \right] = U_t(x)^\top \mathbb E_{z\mid x}[v_t(x,z)].

但根据 marginal target 的定义:

Ezx[vt(x,z)]=vˉt(x).\mathbb E_{z\mid x}[v_t(x,z)] = \bar v_t(x).

所以

Et,x,z[Ut(x)vt(x,z)]=Et,x[Ut(x)vˉt(x)].\mathbb E_{t,x,z}[U_t(x)^\top v_t(x,z)] = \mathbb E_{t,x}[U_t(x)^\top \bar v_t(x)].

这就是证明里最关键的一步:conditional target 在和网络输出做内积后,平均起来等于 marginal target。

第三项不相同:

Et,x[vˉt(x)2]\mathbb E_{t,x}[\|\bar v_t(x)\|^2]

Et,x,z[vt(x,z)2]\mathbb E_{t,x,z}[\|v_t(x,z)\|^2]

一般不是同一个数。但是它们都不含 θ\theta,因为它们只由我们构造的 target 决定,与神经网络参数无关。

因此两个 loss 的差只来自这两个不依赖 θ\theta 的项:

LFM(θ)LCFM(θ)=Et,x[vˉt(x)2]Et,x,z[vt(x,z)2].L_{\text{FM}}(\theta)-L_{\text{CFM}}(\theta) = \mathbb E_{t,x}[\|\bar v_t(x)\|^2] - \mathbb E_{t,x,z}[\|v_t(x,z)\|^2].

右边是一个常数,记成 CC

LFM(θ)=LCFM(θ)+C.L_{\text{FM}}(\theta) = L_{\text{CFM}}(\theta)+C.

所以

θLFM(θ)=θLCFM(θ).\nabla_\theta L_{\text{FM}}(\theta) = \nabla_\theta L_{\text{CFM}}(\theta).

这就是 Theorem 12。

也可以用一句更短的统计语言理解:在均方误差下,如果 target 是随机变量 vt(x,z)v_t(x,z),而模型只能看到 xx,那么最优预测就是条件期望

E[vt(x,z)x],\mathbb E[v_t(x,z)\mid x],

也就是 marginal vector field。

Algorithm 3 与 Gaussian Flow Matching loss

Theorem 12 之后,训练流程就变得非常直接。以最常用的 Gaussian CondOT path 为例:

αt=t,βt=1t.\alpha_t=t, \qquad \beta_t=1-t.

这时 conditional path 是

pt(xz)=N(tz,(1t)2Id).p_t(x\mid z)=\mathcal N(tz,(1-t)^2I_d).

采样形式为

ϵN(0,Id),x=tz+(1t)ϵ.\epsilon\sim \mathcal N(0,I_d), \qquad x=tz+(1-t)\epsilon.

对应的 conditional velocity target 是

uttarget(xz)=zϵ.u_t^{\text{target}}(x\mid z)=z-\epsilon.

因此训练就是:

Algorithm 3

每个 mini-batch 中:

  1. 从数据集中采样一个真实样本
zpdata.z\sim p_{\text{data}}.
  1. 采样一个随机时间
tUnif[0,1].t\sim \operatorname{Unif}[0,1].
  1. 采样标准高斯噪声
ϵN(0,Id).\epsilon\sim \mathcal N(0,I_d).
  1. 构造中间点
x=tz+(1t)ϵ.x=tz+(1-t)\epsilon.
  1. 让网络预测这个点处的 velocity:
utθ(x).u_t^\theta(x).
  1. 用均方误差回归 target:
L(θ)=utθ(x)(zϵ)2.L(\theta) = \left\| u_t^\theta(x)-(z-\epsilon) \right\|^2.

这就是 CondOT path 下的 Flow Matching loss。

这个训练循环有一个很重要的特点:训练时完全不需要模拟 ODE。每一步的训练样本 xx 都是直接由

x=tz+(1t)ϵx=tz+(1-t)\epsilon

构造出来的。ODE solver 只在训练完成后、真正采样生成时才使用。

因此 Flow Matching 是 simulation-free training。它不是一边训练一边从 t=0t=0 rollout 到 t=1t=1,而是随机抽一个时间 tt,直接构造那个时间的 noisy sample,然后做监督回归。

Gaussian path 下的一般 Flow Matching loss

CondOT 是最简单的 path,但 Gaussian probability path 可以更一般。设

pt(xz)=N(αtz,βt2Id).p_t(x\mid z) = \mathcal N(\alpha_tz,\beta_t^2I_d).

采样形式是

ϵN(0,Id),xt=αtz+βtϵ.\epsilon\sim\mathcal N(0,I_d), \qquad x_t=\alpha_tz+\beta_t\epsilon.

前面已经推导过,这个 path 的 conditional vector field 是

uttarget(xz)=(α˙tβ˙tβtαt)z+β˙tβtx.u_t^{\text{target}}(x\mid z) = \left( \dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t \right)z + \frac{\dot\beta_t}{\beta_t}x.

如果把训练点

xt=αtz+βtϵx_t=\alpha_tz+\beta_t\epsilon

代入这个速度场,会得到一个更简单的 target:

uttarget(xtz)=α˙tz+β˙tϵ.u_t^{\text{target}}(x_t\mid z) = \dot\alpha_t z+\dot\beta_t\epsilon.

因为沿着 path 本身求导就是

ddtxt=ddt(αtz+βtϵ)=α˙tz+β˙tϵ.\frac{d}{dt}x_t = \frac{d}{dt}(\alpha_tz+\beta_t\epsilon) = \dot\alpha_t z+\dot\beta_t\epsilon.

于是一般 Gaussian path 的 conditional flow matching loss 可以写成

LCFM(θ)=Et,z,ϵ[utθ(αtz+βtϵ)(α˙tz+β˙tϵ)2],L_{\text{CFM}}(\theta) = \mathbb E_{t,z,\epsilon} \left[ \left\| u_t^\theta(\alpha_tz+\beta_t\epsilon) - (\dot\alpha_tz+\dot\beta_t\epsilon) \right\|^2 \right],

其中

tUnif[0,1],zpdata,ϵN(0,Id).t\sim\operatorname{Unif}[0,1], \qquad z\sim p_{\text{data}}, \qquad \epsilon\sim\mathcal N(0,I_d).

CondOT 只是把

αt=t,βt=1t\alpha_t=t, \qquad \beta_t=1-t

代进去。此时

α˙t=1,β˙t=1,\dot\alpha_t=1, \qquad \dot\beta_t=-1,

所以 target 变成

α˙tz+β˙tϵ=zϵ.\dot\alpha_tz+\dot\beta_t\epsilon = z-\epsilon.

因此 CondOT loss 是

LCFM(θ)=Et,z,ϵ[utθ(tz+(1t)ϵ)(zϵ)2].L_{\text{CFM}}(\theta) = \mathbb E_{t,z,\epsilon} \left[ \left\| u_t^\theta(tz+(1-t)\epsilon) - (z-\epsilon) \right\|^2 \right].

这个公式几乎就是 Flow Matching 最常见的训练形式。它只有三个随机量:数据 zz、噪声 ϵ\epsilon、时间 tt。然后构造 xtx_t,让网络预测从噪声指向数据的速度。

训练后的 ODE 与本节总结

Figure 7

Figure 7 用一个二维棋盘分布展示 Theorem 12 的效果。上排是 ground-truth marginal probability path,也就是我们希望的 ptp_t。下排是训练好的 Flow Matching model 通过模拟 ODE 得到的样本分布。

如果训练成功,那么下排应该和上排匹配。图里可以看到:

这张图说明:虽然训练时网络回归的是 conditional velocity,

uttarget(xz),u_t^{\text{target}}(x\mid z),

但最终模拟 ODE 时,网络学到的行为接近 marginal velocity,

uttarget(x).u_t^{\text{target}}(x).

这正是 Theorem 12 的实际效果。

Summary 14:Flow Matching 的完整结构

Flow Matching 到这里形成了一个完整闭环。

Summary 14

第一步,选择 conditional probability path:

pt(xz),p_t(x\mid z),

满足

p0(z)=pinit,p1(z)=δz.p_0(\cdot\mid z)=p_{\text{init}}, \qquad p_1(\cdot\mid z)=\delta_z.

它描述如何从噪声分布走到某个固定数据点 zz

第二步,找到对应的 conditional vector field:

uttarget(xz),u_t^{\text{target}}(x\mid z),

使得

X0pinit,dXt=uttarget(Xtz)dtX_0\sim p_{\text{init}}, \qquad dX_t=u_t^{\text{target}}(X_t\mid z)\,dt

推出

Xtpt(z).X_t\sim p_t(\cdot\mid z).

第三步,把 conditional vector field 做 posterior average,得到 marginal vector field:

uttarget(x)=uttarget(xz)pt(xz)pdata(z)pt(x)dz.u_t^{\text{target}}(x) = \int u_t^{\text{target}}(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

这个速度场满足

X0pinit,dXt=uttarget(Xt)dtXtpt.X_0\sim p_{\text{init}}, \qquad dX_t=u_t^{\text{target}}(X_t)\,dt \quad\Longrightarrow\quad X_t\sim p_t.

因此终点满足

X1pdata.X_1\sim p_{\text{data}}.

第四步,训练神经网络。虽然 marginal vector field 不可直接计算,但 Theorem 12 说明,只需要最小化 conditional flow matching loss:

LCFM(θ)=Et,z,xpt(z)[utθ(x)uttarget(xz)2].L_{\text{CFM}}(\theta) = \mathbb E_{t,z,x\sim p_t(\cdot\mid z)} \left[ \left\| u_t^\theta(x)-u_t^{\text{target}}(x\mid z) \right\|^2 \right].

优化这个 loss 等价于学习 marginal vector field。

Gaussian path 的最终公式

最常用的例子是 Gaussian probability path:

Gaussian Flow Matching Formulas

它的 conditional path 是

pt(xz)=N(x;αtz,βt2Id).p_t(x\mid z) = \mathcal N(x;\alpha_tz,\beta_t^2I_d).

conditional vector field 是

uttarget(xz)=(α˙tβ˙tβtαt)z+β˙tβtx.u_t^{\text{target}}(x\mid z) = \left( \dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t \right)z + \frac{\dot\beta_t}{\beta_t}x.

训练 loss 可以写成

LCFM(θ)=Et,z,ϵ[utθ(αtz+βtϵ)(α˙tz+β˙tϵ)2].L_{\text{CFM}}(\theta) = \mathbb E_{t,z,\epsilon} \left[ \left\| u_t^\theta(\alpha_tz+\beta_t\epsilon) - (\dot\alpha_tz+\dot\beta_t\epsilon) \right\|^2 \right].

其中 scheduler 满足

α0=0,α1=1,β0=1,β1=0.\alpha_0=0,\quad \alpha_1=1, \qquad \beta_0=1,\quad \beta_1=0.

CondOT 是最简单的特例:

αt=t,βt=1t.\alpha_t=t, \qquad \beta_t=1-t.

于是

xt=tz+(1t)ϵ,uttarget(xtz)=zϵ.x_t=tz+(1-t)\epsilon, \qquad u_t^{\text{target}}(x_t\mid z)=z-\epsilon.

训练目标就是

utθ(tz+(1t)ϵ)(zϵ)2.\left\| u_t^\theta(tz+(1-t)\epsilon) - (z-\epsilon) \right\|^2.

到这里,Flow Matching 这一节的主线已经完成:

选择一条从噪声到数据的 probability path,写出 conditional velocity,用它训练网络;虽然训练 target 是 conditional 的,MSE 最优解却是 marginal velocity,因此训练好的 ODE 可以从噪声生成数据。


4. Score Functions and Score Matching:从速度场到 score

Flow Matching 中的核心对象是 vector field:

ut(x).u_t(x).

它告诉我们在时间 tt、位置 xx,样本应该朝哪个方向移动。Diffusion models 常常使用另一种语言:score function。

给定一个概率密度 q(x)q(x),它的 score function 定义为

logq(x).\nabla \log q(x).

这个对象的含义很直接:它指向让 log-density 增长最快的方向。也就是说,如果当前位置是 xx,那么沿着

logq(x)\nabla \log q(x)

走一小步,会最有效地进入 qq 认为更高概率的区域。

Score function

图 8 左边是一块概率密度,右边的箭头就是 score field。箭头一般会指向高密度区域;在低密度区域,score 会告诉你“往哪里走更像这个分布”。

这和生成模型的直觉很接近。生成时,样本一开始是噪声,处在数据分布的低概率区域。score 提供了一个方向:怎样把样本往更高概率、更像数据的地方推。

4.1 Conditional and Marginal Score Functions

回到 probability path。固定一个 clean data zz 时,有 conditional density

pt(xz).p_t(x\mid z).

混合所有可能的 zpdataz\sim p_{\text{data}} 后,有 marginal density

pt(x)=pt(xz)pdata(z)dz.p_t(x)=\int p_t(x\mid z)p_{\text{data}}(z)\,dz.

因此有两种 score:

xlogpt(xz)\nabla_x\log p_t(x\mid z)

xlogpt(x).\nabla_x\log p_t(x).

前者知道当前 noisy sample xx 是围绕哪个 clean data zz 生成的;后者不知道 zz,只知道当前看到的是 xx

和 marginal vector field 一样,marginal score 也可以写成 conditional score 的 posterior average:

logpt(x)=logpt(xz)pt(xz)pdata(z)pt(x)dz.\nabla\log p_t(x) = \int \nabla\log p_t(x\mid z) \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

其中

pt(xz)pdata(z)pt(x)\frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)}

就是 Bayes posterior:

pt(zx).p_t(z\mid x).

所以这个公式也可以写成更直观的形式:

logpt(x)=Ezx[logpt(xz)].\nabla\log p_t(x) = \mathbb E_{z\mid x} \left[ \nabla\log p_t(x\mid z) \right].

意思是:当前 noisy sample 是 xx,它可能来自很多不同的 clean data zz;每个 zz 都给出一个 conditional score,marginal score 是这些 conditional score 的 posterior average。

这个公式从 score 的定义直接推出。首先

logpt(x)=pt(x)pt(x).\nabla\log p_t(x) = \frac{\nabla p_t(x)}{p_t(x)}.

代入 marginal density:

logpt(x)=pt(xz)pdata(z)dzpt(x).\nabla\log p_t(x) = \frac{ \nabla\int p_t(x\mid z)p_{\text{data}}(z)\,dz }{ p_t(x) }.

把梯度放进积分:

logpt(x)=pt(xz)pdata(z)dzpt(x).\nabla\log p_t(x) = \frac{ \int \nabla p_t(x\mid z)p_{\text{data}}(z)\,dz }{ p_t(x) }.

再使用恒等式

pt(xz)=pt(xz)logpt(xz),\nabla p_t(x\mid z) = p_t(x\mid z)\nabla\log p_t(x\mid z),

就得到 posterior average 形式。

Gaussian path 的 score

对 Gaussian probability path,

pt(xz)=N(αtz,βt2Id).p_t(x\mid z)=\mathcal N(\alpha_tz,\beta_t^2I_d).

它的均值是 αtz\alpha_tz,协方差是 βt2Id\beta_t^2I_d。log-density 中和 xx 有关的部分为

logpt(xz)=12βt2xαtz2+const.\log p_t(x\mid z) = -\frac{1}{2\beta_t^2}\|x-\alpha_tz\|^2+\text{const}.

xx 求梯度。因为

xxαtz2=2(xαtz),\nabla_x\|x-\alpha_tz\|^2=2(x-\alpha_tz),

所以

xlogpt(xz)=xαtzβt2=αtzxβt2.\nabla_x\log p_t(x\mid z) = -\frac{x-\alpha_tz}{\beta_t^2} = \frac{\alpha_tz-x}{\beta_t^2}.

这个 score 指向均值 αtz\alpha_tz。如果 xx 偏离均值,score 会把它往均值方向拉;βt\beta_t 越小,分布越尖锐,同样偏离下 score 的幅度越大。

如果使用采样形式

xt=αtz+βtϵ,ϵN(0,Id),x_t=\alpha_tz+\beta_t\epsilon, \qquad \epsilon\sim\mathcal N(0,I_d),

那么

xlogpt(xtz)=ϵβt.\nabla_x\log p_t(x_t\mid z) = -\frac{\epsilon}{\beta_t}.

这说明,在 Gaussian path 上,conditional score 和噪声 ϵ\epsilon 是同一信息的不同缩放。这也是后面 diffusion 训练可以做 noise prediction 的原因。

Gaussian path 下 score 与 velocity 的转换

Gaussian path 的 conditional velocity 已经在 Flow Matching 中算过:

uttarget(xz)=(α˙tβ˙tβtαt)z+β˙tβtx.u_t^{\text{target}}(x\mid z) = \left( \dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t \right)z + \frac{\dot\beta_t}{\beta_t}x.

而 conditional score 是

logpt(xz)=αtzxβt2.\nabla\log p_t(x\mid z) = \frac{\alpha_tz-x}{\beta_t^2}.

这两个对象都是 x,zx,z 的线性函数,所以可以互相转换。定义

at=βt2α˙tαtβ˙tβt,bt=α˙tαt.a_t = \beta_t^2\frac{\dot\alpha_t}{\alpha_t} - \dot\beta_t\beta_t, \qquad b_t = \frac{\dot\alpha_t}{\alpha_t}.

那么

uttarget(xz)=atlogpt(xz)+btx.u_t^{\text{target}}(x\mid z) = a_t\nabla\log p_t(x\mid z)+b_tx.

验证这个公式只需要代入 score:

atlogpt(xz)+btx=atαtzxβt2+btx.a_t\nabla\log p_t(x\mid z)+b_tx = a_t\frac{\alpha_tz-x}{\beta_t^2}+b_tx.

把它按 zzxx 分组:

=atαtβt2z+(btatβt2)x.= \frac{a_t\alpha_t}{\beta_t^2}z + \left( b_t-\frac{a_t}{\beta_t^2} \right)x.

at,bta_t,b_t 的定义,

atαtβt2=α˙tβ˙tβtαt,\frac{a_t\alpha_t}{\beta_t^2} = \dot\alpha_t - \frac{\dot\beta_t}{\beta_t}\alpha_t,

并且

btatβt2=β˙tβt.b_t-\frac{a_t}{\beta_t^2} = \frac{\dot\beta_t}{\beta_t}.

所以右边正好变回

uttarget(xz).u_t^{\text{target}}(x\mid z).

同样的转换对 marginal 版本也成立:

uttarget(x)=atlogpt(x)+btx.u_t^{\text{target}}(x) = a_t\nabla\log p_t(x)+b_tx.

原因是 conditional 公式两边对 zxz\mid x 做 posterior average 后,左边变成 marginal velocity,右边的 conditional score 变成 marginal score,而 btxb_tx 不依赖 zz

这一步是本节的关键桥梁:在 Gaussian probability path 下,Flow Matching 的 velocity 和 diffusion 的 score 不是两套无关对象,而是同一信息的两种线性参数化。

Denoiser 参数化

除了 velocity 和 score,还有一种常见参数化:denoiser。conditional denoiser 定义为

Dt(xz)=z.D_t(x\mid z)=z.

marginal denoiser 是 posterior mean:

Dt(x)=E[zx]=zpt(xz)pdata(z)pt(x)dz.D_t(x) = \mathbb E[z\mid x] = \int z \frac{p_t(x\mid z)p_{\text{data}}(z)}{p_t(x)} \,dz.

它表示:给定 noisy sample xx 后,对背后的 clean data zz 的平均估计。

因此,在 Gaussian path 下,velocity、score、denoiser、noise prediction 都在表达同一件事:给定 noisy sample xx,恢复或利用它背后的 clean data 信息。Diffusion models 常学习 score 或 noise;Flow Matching 常学习 velocity;denoising models 常学习 denoiser。

4.2 Sampling with SDEs

到目前为止,我们知道如何用 ODE 沿着 probability path 采样。只要有 marginal vector field

uttarget(x),u_t^{\text{target}}(x),

模拟

dXt=uttarget(Xt)dtdX_t=u_t^{\text{target}}(X_t)\,dt

就能保证

Xtpt.X_t\sim p_t.

这对应的是 flow model。Diffusion model 还需要处理 SDE,也就是在动力系统里加入随机噪声:

dXt=ut(Xt)dt+σtdWt.dX_t = u_t(X_t)\,dt+\sigma_t\,dW_t.

问题在于,直接往 ODE 里加 Brownian noise 会改变样本分布。噪声会把概率质量扩散开,原来的 probability path 不会自动保持不变。Theorem 17 给出的结论是:可以加噪声,但要同时加上一项由 score 控制的 drift,用来抵消噪声对 marginal distribution 的影响。

Figure 9

Figure 9 和前面 Figure 6 很像,但现在轨迹变成了 SDE 轨迹。右边的轨迹不再平滑,而是带有随机抖动。关键是:虽然单条轨迹变随机了,不再由初始点完全决定,但每个时间 tt 的样本分布仍然沿着同一条 probability path。

Theorem 17: SDE Extension Trick

设原来的 marginal vector field 是

uttarget(x),u_t^{\text{target}}(x),

并且它对应的 ODE 可以让样本分布满足

Xtpt.X_t\sim p_t.

那么对任意 diffusion coefficient

σt0,\sigma_t\ge 0,

可以构造下面这个 SDE:

dXt=[uttarget(Xt)+σt22logpt(Xt)]dt+σtdWt.dX_t = \left[ u_t^{\text{target}}(X_t) + \frac{\sigma_t^2}{2} \nabla\log p_t(X_t) \right]dt + \sigma_t\,dW_t.

这个 SDE 仍然满足

Xtpt,0t1.X_t\sim p_t, \qquad 0\le t\le 1.

特别地,

X1pdata.X_1\sim p_{\text{data}}.

这条公式里有两部分 drift:

uttarget(Xt)u_t^{\text{target}}(X_t)

是原来 flow 的速度场,负责沿着 probability path 往前推进;

σt22logpt(Xt)\frac{\sigma_t^2}{2}\nabla\log p_t(X_t)

是新增的 score drift。它的作用是配合 Brownian noise,使加噪后的过程仍然保持正确的 marginal distribution。

如果只加

σtdWt,\sigma_t\,dW_t,

分布会被扩散项摊开;score drift 会把样本往高密度区域拉回来。两者配合后,SDE 的 marginal 仍然是 ptp_t

因此 Theorem 17 可以这样理解:

在同一条 probability path 上,不只有一条 ODE 可以实现它;也可以有一族 SDE 实现它。σt\sigma_t 控制轨迹有多随机,score term 负责让随机性不破坏目标 marginal path。

σt=0\sigma_t=0

时,SDE 退化回原来的 ODE:

dXt=uttarget(Xt)dt.dX_t=u_t^{\text{target}}(X_t)\,dt.

所以 flow sampling 是 SDE sampling 的特殊情况。

Gaussian path 下的 SDE 写法

对 Gaussian probability path,我们已经知道 velocity 和 score 可以互相转换:

uttarget(x)=atlogpt(x)+btx.u_t^{\text{target}}(x) = a_t\nabla\log p_t(x)+b_tx.

因此 Theorem 17 的 SDE 可以写成纯 score 形式:

dXt=[(at+σt22)logpt(Xt)+btXt]dt+σtdWt.dX_t = \left[ \left(a_t+\frac{\sigma_t^2}{2}\right) \nabla\log p_t(X_t) + b_tX_t \right]dt + \sigma_t\,dW_t.

这说明一旦学到了 score,就可以用不同的 σt\sigma_t 构造不同的 SDE sampler。理论上任意 σt0\sigma_t\ge 0 都能保持同一个 marginal path;实践中因为神经网络有误差、数值模拟也有误差,σt\sigma_t 的选择会影响采样效果。

这也解释了为什么 diffusion sampling 有很多 sampler 变体:它们常常对应不同的 SDE/ODE 采样方式,但背后使用的是同一个 score 或与 score 等价的参数化。

Fokker-Planck equation

Fokker-Planck equation 是把“单个粒子的 SDE”翻译成“一群粒子的密度演化”的方程。

SDE 描述的是一条随机轨迹:

dXt=ut(Xt)dt+σtdWt.dX_t=u_t(X_t)\,dt+\sigma_t\,dW_t.

如果只看一个粒子,看到的是它在速度场和随机噪声共同作用下怎么走。但生成模型关心的是很多粒子的整体分布。假设有一大群粒子都按这个 SDE 运动,它们在时间 tt 的空间密度记作

pt(x).p_t(x).

Fokker-Planck equation 说的就是:这些粒子的密度 pt(x)p_t(x) 会怎样随时间变化。

对于 SDE

dXt=ut(Xt)dt+σtdWt,dX_t=u_t(X_t)\,dt+\sigma_t\,dW_t,

如果

Xtpt,X_t\sim p_t,

那么 ptp_t 满足

tpt(x)=div(ptut)(x)+σt22Δpt(x).\partial_t p_t(x) = -\operatorname{div}(p_tu_t)(x) + \frac{\sigma_t^2}{2}\Delta p_t(x).

右边有两种机制。

第一项

div(ptut)(x)-\operatorname{div}(p_tu_t)(x)

是 drift / transport 项。它描述速度场 utu_t 如何搬运概率质量。可以把它想成水流推动墨水:墨水本身没有凭空产生或消失,只是被流场搬到别的位置。

如果没有 Brownian noise,也就是 σt=0\sigma_t=0,Fokker-Planck equation 退化成 continuity equation:

tpt(x)=div(ptut)(x).\partial_t p_t(x) = -\operatorname{div}(p_tu_t)(x).

第二项

σt22Δpt(x).\frac{\sigma_t^2}{2}\Delta p_t(x).

是 diffusion 项。它来自 Brownian noise,描述随机碰撞如何把概率质量向周围摊开。物理上可以把它想成热扩散或墨水在静止水里慢慢散开:高密度区域会被抹平,低密度区域会被填起来。

因此 Fokker-Planck equation 的物理本质是:

单个粒子的随机运动规则,诱导出整体粒子密度的演化规律。

简洁地说:

SDE: 一个粒子怎么动
Fokker-Planck: 一群粒子的密度怎么变

Theorem 17 的证明思路

原来的 ODE 已经满足 continuity equation:

tpt(x)=div(ptuttarget)(x).\partial_t p_t(x) = -\operatorname{div}(p_tu_t^{\text{target}})(x).

现在希望构造一个 SDE drift:

u~t(x)=uttarget(x)+σt22logpt(x).\tilde u_t(x) = u_t^{\text{target}}(x) + \frac{\sigma_t^2}{2}\nabla\log p_t(x).

需要证明它满足 Fokker-Planck equation:

tpt(x)=div(ptu~t)(x)+σt22Δpt(x).\partial_t p_t(x) = -\operatorname{div}(p_t\tilde u_t)(x) + \frac{\sigma_t^2}{2}\Delta p_t(x).

从 continuity equation 出发:

tpt(x)=div(ptuttarget)(x).\partial_t p_t(x) = -\operatorname{div}(p_tu_t^{\text{target}})(x).

加上再减去同一个扩散项:

tpt(x)=div(ptuttarget)(x)σt22Δpt(x)+σt22Δpt(x).\partial_t p_t(x) = -\operatorname{div}(p_tu_t^{\text{target}})(x) - \frac{\sigma_t^2}{2}\Delta p_t(x) + \frac{\sigma_t^2}{2}\Delta p_t(x).

利用

Δpt=div(pt),\Delta p_t=\operatorname{div}(\nabla p_t),

以及

pt=ptlogpt,\nabla p_t=p_t\nabla\log p_t,

中间的负扩散项可以写成

σt22Δpt=div(ptσt22logpt).- \frac{\sigma_t^2}{2}\Delta p_t = - \operatorname{div} \left( p_t\frac{\sigma_t^2}{2} \nabla\log p_t \right).

于是

tpt(x)=div(pt[uttarget+σt22logpt])(x)+σt22Δpt(x).\partial_t p_t(x) = -\operatorname{div} \left( p_t \left[ u_t^{\text{target}} + \frac{\sigma_t^2}{2}\nabla\log p_t \right] \right)(x) + \frac{\sigma_t^2}{2}\Delta p_t(x).

这正是 SDE drift

uttarget+σt22logptu_t^{\text{target}} + \frac{\sigma_t^2}{2}\nabla\log p_t

对应的 Fokker-Planck equation。因此这个 SDE 的 marginal distribution 仍然是 ptp_t

Langevin dynamics

Theorem 17 有一个重要特例。如果 probability path 不随时间变化,也就是

pt=p,p_t=p,

并且取

uttarget=0,u_t^{\text{target}}=0,

那么 Theorem 17 给出:

dXt=σt22logp(Xt)dt+σtdWt.dX_t = \frac{\sigma_t^2}{2}\nabla\log p(X_t)\,dt + \sigma_t\,dW_t.

这就是 Langevin dynamics。

Langevin dynamics 可以看成一种物理平衡机制:粒子一边被随机噪声推动,一边被 score field 拉向目标分布的高密度区域。

如果只有噪声项

dXt=σtdWt,dX_t=\sigma_t\,dW_t,

粒子会不断扩散,分布越来越摊开。

如果只有 score drift

dXt=σt22logp(Xt)dt,dX_t=\frac{\sigma_t^2}{2}\nabla\log p(X_t)\,dt,

粒子会沿着 log-density 上升方向移动,趋向高概率区域,但缺少随机探索。

Langevin dynamics 把两者放在一起:

Brownian noise 负责探索
score drift 负责拉回高概率区域

当这两种作用达到平衡时,目标分布 pp 就成为 stationary distribution。

Figure 10 Langevin dynamics

Figure 10 中,目标分布 p(x)p(x) 是一个有 5 个 mode 的 Gaussian mixture。上排黑点是 Langevin dynamics 中的粒子,底排是这些粒子的密度估计。最开始粒子比较分散,随着时间推进,粒子逐渐集中到蓝色高密度区域附近。

这张图体现了 Langevin dynamics 的两个作用:

当这两个作用达到平衡时,pp 就成为 stationary distribution。也就是说,如果一开始

X0p,X_0\sim p,

那么 Langevin dynamics 会保持

Xtp.X_t\sim p.

如果一开始不是 pp,在一定条件下样本分布也会逐渐靠近 pp

这也是它和 diffusion models 的连接。Diffusion sampling 不是简单地“沿着 score 爬坡”,也不是简单地“加噪声乱走”,而是在随机探索和 score 引导之间保持平衡,让整体分布按照我们想要的路径演化。

这一小节可以收成一句话:

ODE 用 velocity 沿着 probability path 推动样本;SDE 可以在此基础上加入 Brownian noise,但必须同时加入 score drift,才能让每个时间的 marginal distribution 仍然等于指定的 ptp_t

4.3 Score Matching

4.2 说明了如果知道 marginal score

logpt(x),\nabla\log p_t(x),

就可以把它放进 SDE sampler:

dXt=[uttarget(Xt)+σt22logpt(Xt)]dt+σtdWt.dX_t = \left[ u_t^{\text{target}}(X_t) + \frac{\sigma_t^2}{2}\nabla\log p_t(X_t) \right]dt + \sigma_t\,dW_t.

现在剩下的问题是:这个 marginal score 怎么学?

和 Flow Matching 一样,理想目标是直接回归 marginal object。定义 score network:

stθ:Rd×[0,1]Rd.s_t^\theta:\mathbb R^d\times[0,1]\to\mathbb R^d.

它希望近似

stθ(x)logpt(x).s_t^\theta(x)\approx \nabla\log p_t(x).

最直接的 score matching loss 是

LSM(θ)=EtUnif, zpdata, xpt(z)[stθ(x)logpt(x)2].L_{\text{SM}}(\theta) = \mathbb E_{t\sim\operatorname{Unif},\ z\sim p_{\text{data}},\ x\sim p_t(\cdot\mid z)} \left[ \left\| s_t^\theta(x)-\nabla\log p_t(x) \right\|^2 \right].

这里虽然采样 xx 没有问题,但 target

logpt(x)\nabla\log p_t(x)

不可直接计算,因为它是 marginal score,需要对所有可能的 zz 做 posterior average。

于是定义 conditional score matching loss:

LCSM(θ)=Et,z,xpt(z)[stθ(x)logpt(xz)2].L_{\text{CSM}}(\theta) = \mathbb E_{t,z,x\sim p_t(\cdot\mid z)} \left[ \left\| s_t^\theta(x)-\nabla\log p_t(x\mid z) \right\|^2 \right].

这个 loss 可计算,因为 conditional score 由我们选择的 probability path 决定,通常有解析形式。

Theorem 22 说:

LSM(θ)=LCSM(θ)+C,L_{\text{SM}}(\theta) = L_{\text{CSM}}(\theta)+C,

其中 CCθ\theta 无关,因此

θLSM(θ)=θLCSM(θ).\nabla_\theta L_{\text{SM}}(\theta) = \nabla_\theta L_{\text{CSM}}(\theta).

这个结论和 Theorem 12 完全平行。原因是

logpt(x)=Ezx[logpt(xz)].\nabla\log p_t(x) = \mathbb E_{z\mid x} \left[ \nabla\log p_t(x\mid z) \right].

也就是说,conditional score 的 posterior average 就是 marginal score。用 MSE 训练时,网络只能看到 xx,所以最优预测就是对所有可能 zz 的条件平均。

因此:

回归可计算的 conditional score,等价于学习不可直接计算的 marginal score。

Gaussian path 下的 denoising score matching

对 Gaussian probability path:

pt(xz)=N(αtz,βt2Id),p_t(x\mid z) = \mathcal N(\alpha_tz,\beta_t^2I_d),

4.1 已经算出 conditional score:

logpt(xz)=xαtzβt2.\nabla\log p_t(x\mid z) = -\frac{x-\alpha_tz}{\beta_t^2}.

把它代入 conditional score matching loss:

LCSM(θ)=E[stθ(x)+xαtzβt22].L_{\text{CSM}}(\theta) = \mathbb E \left[ \left\| s_t^\theta(x) + \frac{x-\alpha_tz}{\beta_t^2} \right\|^2 \right].

现在用 Gaussian path 的采样形式:

xt=αtz+βtϵ,ϵN(0,Id).x_t=\alpha_tz+\beta_t\epsilon, \qquad \epsilon\sim\mathcal N(0,I_d).

xtαtz=βtϵ.x_t-\alpha_tz=\beta_t\epsilon.

所以 conditional score 是

logpt(xtz)=ϵβt.\nabla\log p_t(x_t\mid z) = -\frac{\epsilon}{\beta_t}.

于是 loss 变成

LCSM(θ)=Et,z,ϵ[stθ(αtz+βtϵ)+ϵβt2].L_{\text{CSM}}(\theta) = \mathbb E_{t,z,\epsilon} \left[ \left\| s_t^\theta(\alpha_tz+\beta_t\epsilon) + \frac{\epsilon}{\beta_t} \right\|^2 \right].

这就是 denoising score matching。它的名字来自这里:我们先用噪声 ϵ\epsilon 污染数据 zz,得到 xtx_t;然后训练网络从 noisy sample 中恢复与噪声相关的信息。

Noise prediction 参数化

上面的 score target 是

ϵβt.-\frac{\epsilon}{\beta_t}.

βt\beta_t 很小时,这个量会变得很大,训练可能不稳定。早期 DDPM 的做法是改成 noise prediction。定义噪声预测网络

ϵtθ(x)=βtstθ(x).\epsilon_t^\theta(x) = -\beta_t s_t^\theta(x).

如果

stθ(xt)ϵβt,s_t^\theta(x_t)\approx -\frac{\epsilon}{\beta_t},

那么

ϵtθ(xt)=βtstθ(xt)ϵ.\epsilon_t^\theta(x_t) = -\beta_t s_t^\theta(x_t) \approx \epsilon.

因此可以直接训练网络预测加进去的噪声:

LDDPM(θ)=Et,z,ϵ[ϵtθ(αtz+βtϵ)ϵ2].L_{\text{DDPM}}(\theta) = \mathbb E_{t,z,\epsilon} \left[ \left\| \epsilon_t^\theta(\alpha_tz+\beta_t\epsilon)-\epsilon \right\|^2 \right].

这个形式就是 diffusion models 中非常常见的噪声预测 loss。它和 score matching 的目标等价,只是参数化更稳定。

Algorithm 4

Algorithm 4 的训练流程可以读成:

  1. 采样真实数据
zpdata.z\sim p_{\text{data}}.
  1. 采样时间
tUnif[0,1].t\sim\operatorname{Unif}[0,1].
  1. 采样噪声
ϵN(0,Id).\epsilon\sim\mathcal N(0,I_d).
  1. 构造 noisy sample
xt=αtz+βtϵ.x_t=\alpha_tz+\beta_t\epsilon.
  1. 可以训练 score network:
stθ(xt)+ϵβt2.\left\| s_t^\theta(x_t)+\frac{\epsilon}{\beta_t} \right\|^2.

也可以训练 noise predictor:

ϵtθ(xt)ϵ2.\left\| \epsilon_t^\theta(x_t)-\epsilon \right\|^2.

Summary 24:Score matching 与 stochastic sampling

Summary 24

Summary 24 continued

这一节可以收成三个结论。

第一,conditional score 和 marginal score 分别是

logpt(xz),logpt(x).\nabla\log p_t(x\mid z), \qquad \nabla\log p_t(x).

第二,如果知道 marginal vector field 和 marginal score,那么对任意 σt0\sigma_t\ge 0,SDE

dXt=[uttarget(Xt)+σt22logpt(Xt)]dt+σtdWtdX_t = \left[ u_t^{\text{target}}(X_t) + \frac{\sigma_t^2}{2}\nabla\log p_t(X_t) \right]dt + \sigma_t\,dW_t

都会沿着同一条 probability path:

Xtpt.X_t\sim p_t.

第三,marginal score 可以通过 denoising score matching 学到:

LCSM(θ)=Ez,t,xpt(z)[stθ(x)logpt(xz)2].L_{\text{CSM}}(\theta) = \mathbb E_{z,t,x\sim p_t(\cdot\mid z)} \left[ \left\| s_t^\theta(x)-\nabla\log p_t(x\mid z) \right\|^2 \right].

对 Gaussian path,还可以在 score 和 velocity 之间转换:

utθ(x)=atstθ(x)+btx.u_t^\theta(x) = a_ts_t^\theta(x)+b_tx.

因此学 score、学 velocity、学 noise,在 Gaussian path 下都是同一个模型能力的不同表达。训练完成后,可以选择 ODE sampler,也可以选择 SDE sampler。


5. Guidance: How To Condition on a Prompt

前面讨论的生成模型都是 unguided:模型从

pdata(z)p_{\text{data}}(z)

中采样,也就是生成某个来自数据分布的对象。但实际使用生成模型时,我们通常不想生成“任意图片”,而是想生成满足某个条件的对象。例如给定文本 prompt

y=“corgi dog”,y=\text{``corgi dog''},

希望生成符合这个 prompt 的图像。

数学上,目标从无条件采样变成条件采样:

zpdata(y).z\sim p_{\text{data}}(\cdot\mid y).

这里使用 guided 这个词,而不是 conditional。原因是前面已经用 conditional probability path / conditional vector field 表示“给定数据点 zz”的对象。为了避免混淆,这一节把“给定 prompt 或 label yy”称为 guided generation。

5.1 Vanilla Guidance

最直接的 guided model 做法很简单:把 prompt 或 label yy 作为网络输入的一部分。

在 unguided diffusion model 里,神经网络是

uθ:Rd×[0,1]Rd,(x,t)utθ(x).u^\theta:\mathbb R^d\times[0,1]\to\mathbb R^d, \qquad (x,t)\mapsto u_t^\theta(x).

guided diffusion model 则变成

uθ:Rd×Y×[0,1]Rd,u^\theta:\mathbb R^d\times\mathcal Y\times[0,1]\to\mathbb R^d,

也就是

(x,y,t)utθ(xy).(x,y,t)\mapsto u_t^\theta(x\mid y).

这里 Y\mathcal Y 是条件变量所在的空间。如果 yy 是文本 prompt,那么 Y\mathcal Y 是所有文本的集合;如果 yy 是类别标签,那么 Y\mathcal Y 是离散标签集合。

采样时,给定某个 prompt yy,从简单噪声初始化:

X0pinit,X_0\sim p_{\text{init}},

然后模拟 guided SDE:

dXt=utθ(Xty)dt+σtdWt.dX_t = u_t^\theta(X_t\mid y)\,dt + \sigma_t\,dW_t.

目标是

X1pdata(y).X_1\sim p_{\text{data}}(\cdot\mid y).

如果

σt=0,\sigma_t=0,

则得到 guided flow model:

dXt=utθ(Xty)dt.dX_t=u_t^\theta(X_t\mid y)\,dt.

为了让记号更简洁,这里主要用 flow matching 来说明,但同样思想也适用于 diffusion / score matching。

Guided flow matching objective

训练 guided flow model 时,数据不再只是

zpdata,z\sim p_{\text{data}},

而是成对样本:

(z,y)pdata(z,y).(z,y)\sim p_{\text{data}}(z,y).

例如图像和文字描述,或者图像和类别标签。

如果固定某个 yy,目标分布就是

pdata(zy).p_{\text{data}}(z\mid y).

此时可以像无条件 flow matching 一样训练:

Ezpdata(y), xpt(z)[utθ(xy)uttarget(xz)2].\mathbb E_{z\sim p_{\text{data}}(\cdot\mid y),\ x\sim p_t(\cdot\mid z)} \left[ \left\| u_t^\theta(x\mid y) - u_t^{\text{target}}(x\mid z) \right\|^2 \right].

把所有可能的 yy 一起考虑,就得到 guided conditional flow matching objective:

LCFMguided(θ)=E(z,y)pdata(z,y), tUnif[0,1], xpt(z)[utθ(xy)uttarget(xz)2].L_{\text{CFM}}^{\text{guided}}(\theta) = \mathbb E_{(z,y)\sim p_{\text{data}}(z,y),\ t\sim\operatorname{Unif}[0,1],\ x\sim p_t(\cdot\mid z)} \left[ \left\| u_t^\theta(x\mid y) - u_t^{\text{target}}(x\mid z) \right\|^2 \right].

这个公式和无条件版本的区别只有一个:网络输入多了 yy,训练数据从单个 zz 变成 pair

(z,y).(z,y).

注意,conditional probability path

pt(z)p_t(\cdot\mid z)

以及 conditional vector field

uttarget(xz)u_t^{\text{target}}(x\mid z)

通常不需要依赖 yy。因为 zz 已经是具体数据样本,路径描述的是如何从噪声走到这个 zzyy 的作用主要是告诉神经网络:在看到 noisy sample xx 时,应该使用哪种条件下的 vector field。

PyTorch 实现上,这意味着 dataloader 不再只返回 images / latents:

z

而是返回 paired batch:

(z, y)

然后模型调用从

u_theta(x, t)

变成

u_theta(x, y, t)

Vanilla guidance 的局限

理论上,vanilla guidance 应该可以学到

pdata(y).p_{\text{data}}(\cdot\mid y).

但实际图像生成中,人们发现这种方式生成的样本有时并不够符合 prompt。Figure 11 展示了这个问题。

Figure 11 guidance corgi

左边是 vanilla guidance 生成的样本,prompt/class 是 “corgi dog”,但有些图像并不像 corgi。右边使用了更强的 guidance 后,样本明显更符合 prompt。

出现这个问题有很多原因:

因此,仅仅把 yy 输入网络通常还不够。下一节要解决的问题是:如何人为强化 prompt 的影响,让生成结果更贴合 yy。这就是 classifier-free guidance 的动机。

5.2 Classifier-Free Guidance

Vanilla guidance 只是把 yy 输入网络,让模型学习

uttarget(xy).u_t^{\text{target}}(x\mid y).

理论上这应该能采样

pdata(y),p_{\text{data}}(\cdot\mid y),

但实践中 prompt adherence 往往不够强。Classifier-free guidance 的目的,就是在采样时人为放大“和 prompt 有关的那部分方向”。

Figure 12 classifier-free guidance

Figure 12 画出了两种思路。上排是 classifier guidance:把 guided vector field 拆成 unguided part 和 classifier 提供的 prompt-dependent part,然后放大后者。下排是 classifier-free guidance:不训练额外 classifier,而是用 guided vector field 和 unguided vector field 的差来表示 prompt-dependent part,再把这个差放大。

Classifier guidance:先看 prompt-dependent part 从哪里来

先考虑 Gaussian probability path。上一节已经知道,guided vector field 可以用 guided score 写成

uttarget(xy)=atlogpt(xy)+btx.u_t^{\text{target}}(x\mid y) = a_t\nabla\log p_t(x\mid y)+b_tx.

这里的 pt(xy)p_t(x\mid y) 是在条件 yy 下,时间 tt 的 noisy sample 分布。

用 Bayes’ rule:

pt(xy)=pt(x)pt(yx)pt(y).p_t(x\mid y) = \frac{p_t(x)p_t(y\mid x)}{p_t(y)}.

xxlog\nabla\log

logpt(xy)=logpt(x)+logpt(yx).\nabla\log p_t(x\mid y) = \nabla\log p_t(x) + \nabla\log p_t(y\mid x).

这里没有 logpt(y)\nabla\log p_t(y),因为梯度是对 xx 求的,而 pt(y)p_t(y) 不依赖 xx

代回 guided vector field:

uttarget(xy)=btx+at(logpt(x)+logpt(yx)).u_t^{\text{target}}(x\mid y) = b_tx + a_t \left( \nabla\log p_t(x) + \nabla\log p_t(y\mid x) \right).

而 unguided vector field 是

uttarget(x)=atlogpt(x)+btx.u_t^{\text{target}}(x) = a_t\nabla\log p_t(x)+b_tx.

所以

uttarget(xy)=uttarget(x)+atlogpt(yx).u_t^{\text{target}}(x\mid y) = u_t^{\text{target}}(x) + a_t\nabla\log p_t(y\mid x).

这个公式很有解释力:

其中 pt(yx)p_t(y\mid x) 可以理解成一个 classifier:给定 noisy sample xx,判断它属于 prompt / class yy 的可能性。因此

logpt(yx)\nabla\log p_t(y\mid x)

就是“怎样移动 xx,能让它更像 label yy”的方向。

如果 vanilla guidance 对 prompt 的响应不够强,一个自然想法是把这部分放大:

u~t(xy)=uttarget(x)+watlogpt(yx),w>1.\tilde u_t(x\mid y) = u_t^{\text{target}}(x) + w a_t\nabla\log p_t(y\mid x), \qquad w>1.

这里 ww 叫 guidance scale。w=1w=1 时回到普通 guided vector field;w>1w>1 时,prompt-dependent part 被放大。

这就是 classifier guidance。问题是它需要额外训练一个 classifier 来估计

pt(yx),p_t(y\mid x),

并且如果 yy 是文本 prompt,而不是简单类别标签,这个 classifier 会很难训练。

Classifier-free guidance:不用 classifier 也能放大 prompt

Classifier-free guidance 的关键是把 classifier term 消掉。

从刚才的分解:

uttarget(xy)=uttarget(x)+atlogpt(yx).u_t^{\text{target}}(x\mid y) = u_t^{\text{target}}(x) + a_t\nabla\log p_t(y\mid x).

因此

atlogpt(yx)=uttarget(xy)uttarget(x).a_t\nabla\log p_t(y\mid x) = u_t^{\text{target}}(x\mid y) - u_t^{\text{target}}(x).

也就是说,prompt-dependent part 可以直接看成:

guided vector fieldunguided vector field.\text{guided vector field} - \text{unguided vector field}.

把这个差放大:

u~t(xy)=uttarget(x)+w(uttarget(xy)uttarget(x)).\tilde u_t(x\mid y) = u_t^{\text{target}}(x) + w \left( u_t^{\text{target}}(x\mid y) - u_t^{\text{target}}(x) \right).

整理:

u~t(xy)=(1w)uttarget(x)+wuttarget(xy).\tilde u_t(x\mid y) = (1-w)u_t^{\text{target}}(x) + w u_t^{\text{target}}(x\mid y).

这就是 classifier-free guidance 的核心公式。

它的意思是:

用 guided vector field 减去 unguided vector field 得到 prompt-dependent part,然后把这部分放大。

w=1w=1

u~t(xy)=uttarget(xy).\tilde u_t(x\mid y) = u_t^{\text{target}}(x\mid y).

w>1w>1 时,采样方向会比普通 guided model 更强调 yy。这通常会提高 prompt adherence,但也可能降低多样性,甚至让样本变得过度刻板。

需要注意:当 w1w\ne 1 时,

u~t(xy)\tilde u_t(x\mid y)

不再是严格意义上的真实 guided vector field。它是一个经验上非常有效的 heuristic。

一个模型同时学 guided 和 unguided

公式里需要两个对象:

uttarget(xy)u_t^{\text{target}}(x\mid y)

uttarget(x).u_t^{\text{target}}(x).

如果分别训练两个模型,成本会很高。CFG 的做法是引入一个特殊空条件:

.\emptyset.

把 unguided vector field 写成

uttarget(x)=uttarget(x).u_t^{\text{target}}(x) = u_t^{\text{target}}(x\mid \emptyset).

这样一个网络

utθ(xy)u_t^\theta(x\mid y)

既能在输入真实 yy 时学习 guided vector field,也能在输入 \emptyset 时学习 unguided vector field。

训练时,随机把一部分 label 丢掉:

y.y\leftarrow \emptyset.

丢 label 的概率记作

η.\eta.

于是训练目标为

LCFMCFG(θ)=E[utθ(xy)uttarget(xz)2],L_{\text{CFM}}^{\text{CFG}}(\theta) = \mathbb E \left[ \left\| u_t^\theta(x\mid y) - u_t^{\text{target}}(x\mid z) \right\|^2 \right],

其中采样过程是:

(z,y)pdata(z,y),tUnif[0,1],xpt(z),(z,y)\sim p_{\text{data}}(z,y), \qquad t\sim\operatorname{Unif}[0,1], \qquad x\sim p_t(\cdot\mid z),

并且以概率 η\etayy 替换成 \emptyset

Algorithm 5 CFG training

对 Gaussian path:

x=αtz+βtϵ,x=\alpha_tz+\beta_t\epsilon,

target 是

α˙tz+β˙tϵ.\dot\alpha_tz+\dot\beta_t\epsilon.

所以 Algorithm 5 就是:

  1. 采样 paired data (z,y)(z,y)
  2. 采样时间 tt
  3. 采样噪声 ϵ\epsilon
  4. 构造 x=αtz+βtϵx=\alpha_tz+\beta_t\epsilon
  5. 以概率 η\eta 丢掉 label,令 y=y=\emptyset
  6. 训练
utθ(xy)(α˙tz+β˙tϵ)2.\left\| u_t^\theta(x\mid y) - (\dot\alpha_tz+\dot\beta_t\epsilon) \right\|^2.

推理时怎么用 CFG

训练完成后,采样时给定真实 prompt yy。每一步都计算两个网络输出:

Summary 27 CFG

utθ(xy)u_t^\theta(x\mid y)

utθ(x).u_t^\theta(x\mid \emptyset).

然后组合成 CFG vector field:

u~tθ(xy)=(1w)utθ(x)+wutθ(xy).\tilde u_t^\theta(x\mid y) = (1-w)u_t^\theta(x\mid \emptyset) + w u_t^\theta(x\mid y).

也常写成:

u~tθ(xy)=utθ(x)+w(utθ(xy)utθ(x)).\tilde u_t^\theta(x\mid y) = u_t^\theta(x\mid \emptyset) + w \left( u_t^\theta(x\mid y) - u_t^\theta(x\mid \emptyset) \right).

第二种写法更直观:从 unguided direction 出发,加上放大后的 prompt-dependent difference。

然后用这个 u~tθ\tilde u_t^\theta 模拟 ODE:

dXt=u~tθ(Xty)dt.dX_t=\tilde u_t^\theta(X_t\mid y)\,dt.

如果是 diffusion model,也可以把原来的 vector field 替换成 u~tθ\tilde u_t^\theta,再用 SDE sampler。

Guidance scale 的效果

Figure 13 CFG MNIST

Figure 13 展示了不同 guidance scale ww 的效果。w=1w=1 时就是普通 guided model;w=2w=2w=4w=4 时,类别条件被更强地强调,生成结果更贴近目标数字类别。

ww 不是越大越好。一般来说:

现代图像/视频生成模型大量依赖 CFG。它不是严格保持

X1pdata(y)X_1\sim p_{\text{data}}(\cdot\mid y)

的精确采样方法,而是一个经验效果极好的 prompt reinforcement heuristic。

本节可以收成一句话:

Classifier-free guidance 通过同时学习 conditional 和 unconditional vector fields,在采样时放大二者的差,从而强化 prompt 对生成方向的影响。


6. Building Large-Scale Image or Video Generators

前面几节已经把生成模型的数学核心搭起来了。无论是 flow matching 还是 diffusion,训练出来的对象都可以写成一个带参数的向量场:

utθ(xy).u_t^\theta(x\mid y).

它接收三个东西:

xRd,t[0,1],yY,x\in\mathbb R^d,\qquad t\in[0,1],\qquad y\in\mathcal Y,

输出一个和 xx 同维度的向量:

utθ(xy)Rd.u_t^\theta(x\mid y)\in\mathbb R^d.

这个向量就是采样时要沿着走的方向。对低维 toy examples 来说,把 x,t,yx,t,y 拼起来喂给一个 MLP 就够了;但图像、视频、蛋白质这样的数据有很强的结构,维度也极高,直接用普通 MLP 基本不可行。

第 6 节讨论的是:在真实的大规模图像和视频生成模型里,这个 utθ(xy)u_t^\theta(x\mid y) 到底怎么实现。

它分成三件事:

  1. 怎样把原始条件输入变成模型能处理的向量表示,例如时间 tt、类别 label、文本 prompt;
  2. 怎样设计真正的 neural network architecture,例如 Diffusion Transformer 和 U-Net;
  3. 为什么现代模型几乎都不直接在像素空间生成,而是在 autoencoder 的 latent space 里生成。

这一节的重点不是改变前面学过的数学目标,而是把数学目标落到可扩展的网络结构上。

6.1 Neural Network Architectures

模型输入和输出的形状

对 guided generation,网络形式是:

utθ(xy).u_t^\theta(x\mid y).

其中:

xRdx\in\mathbb R^d

是当前带噪样本,或者当前 latent;

t[0,1]t\in[0,1]

表示现在处在 probability path 的哪一个时间点;

yYy\in\mathcal Y

是条件变量,例如类别、文本 prompt、图像条件、视频条件等。

网络输出:

utθ(xy)Rdu_t^\theta(x\mid y)\in\mathbb R^d

必须和 xx 的形状一致。因为采样时要做的是:

dXt=utθ(Xty)dt,dX_t = u_t^\theta(X_t\mid y)\,dt,

或者在 diffusion model 里:

dXt=utθ(Xty)dt+σtdWt.dX_t = u_t^\theta(X_t\mid y)\,dt+\sigma_t\,dW_t.

因此网络不是输出一个类别概率,也不是输出一句话,而是输出“每个像素/latent 维度下一步该怎么动”。

如果 xx 是图像,通常写成:

xRC×H×W.x\in\mathbb R^{C\times H\times W}.

其中 CC 是通道数,RGB 图像里通常 C=3C=3H,WH,W 是图像高和宽。网络输出也必须是:

utθ(xy)RC×H×W.u_t^\theta(x\mid y)\in\mathbb R^{C\times H\times W}.

这就是为什么图像生成模型需要 U-Net、DiT 这类能保留空间结构的 architecture,而不是普通全连接网络。

6.1.1 Embedding the Conditioning Variables

模型需要吃进去的条件变量有三类:时间、类别、文本。它们原本的形式差别很大,因此第一步是把它们都变成连续向量。

Embedding time.

时间 tt 是一个标量。但模型需要知道的不是“一个数字”,而是这个数字在 denoising / transport 过程中代表的阶段。早期噪声很大,后期细节更多,不同时间段的行为差异可以很剧烈。

这里容易困惑的是:为什么时间 embedding 会突然和“频率”联系起来?

这里的频率不是说 diffusion 过程里真的有某个物理量在振动。它只是一个表示时间的数学工具。做法是把时间 tt 放进很多个正弦、余弦函数里:

cos(2πwt),sin(2πwt).\cos(2\pi wt), \qquad \sin(2\pi wt).

其中 ww 就叫频率。令

ϕ(t)=2πwt.\phi(t)=2\pi wt.

tt00 走到 11 时,ϕ(t)\phi(t)00 走到 2πw2\pi w。因此:

ww 越大,这个特征对 tt 的变化越敏感。两个很接近的时间点,在高频特征下也可能变得容易区分。

这里采用 Fourier features:

TimeEmb(t)=2d[cos(2πw1t)cos(2πwd/2t)sin(2πw1t)sin(2πwd/2t)]Rd.\operatorname{TimeEmb}(t) = \sqrt{\frac{2}{d}} \begin{bmatrix} \cos(2\pi w_1t)\\ \vdots\\ \cos(2\pi w_{d/2}t)\\ \sin(2\pi w_1t)\\ \vdots\\ \sin(2\pi w_{d/2}t) \end{bmatrix} \in\mathbb R^d.

频率 wiw_i 取为:

wi=wmin(wmaxwmin)i1d/21,i=1,,d/2.w_i = w_{\min} \left( \frac{w_{\max}}{w_{\min}} \right)^{\frac{i-1}{d/2-1}}, \qquad i=1,\dots,d/2.

这组频率从低到高覆盖不同时间尺度。低频让模型感知粗粒度阶段,高频让模型区分很接近的时间点。可以把每一对

(cos(2πwit),sin(2πwit))\left( \cos(2\pi w_it), \sin(2\pi w_it) \right)

想成一个在圆上转动的指针;不同的 wiw_i 是不同转速的指针。模型拿到的不是单一时间坐标,而是一组不同分辨率的时间坐标。

这一步的意义是让网络更容易学习复杂的时间依赖。若直接输入 tt,网络需要自己从一个标量里构造出复杂的非线性时间函数;Fourier features 先提供一组基础波形,后面的 MLP 或 Transformer 层只需要组合它们,就能表达更丰富的时间变化。

前面的系数

2d\sqrt{\frac{2}{d}}

让 embedding 的范数保持稳定。因为对每个频率都有

cos2(2πwit)+sin2(2πwit)=1,\cos^2(2\pi w_it)+\sin^2(2\pi w_it)=1,

所以整体 embedding 不会因为维度变大而尺度失控。

直观地说,time embedding 不是为了机械地记录 tt 的数值,而是把 tt 展开成一组不同尺度的时间信号,让网络更容易表达随时间变化的向量场:

tutθ(xy).t \longmapsto u_t^\theta(x\mid y).

Embedding class labels.

如果 yrawy_{\text{raw}} 是离散类别,比如 MNIST 里的

yraw{0,1,,9},y_{\text{raw}}\in\{0,1,\dots,9\},

最自然的做法是为每个类别学习一个 embedding vector:

y=ClassEmb(yraw)Rd.y = \operatorname{ClassEmb}(y_{\text{raw}})\in\mathbb R^d.

这些 embedding 是模型参数的一部分,会和 utθu_t^\theta 一起训练。它的意义和 NLP 里的 word embedding 类似:类别编号本身没有几何意义,但训练后的向量会变成模型可使用的条件信息。

Embedding textual input.

如果 yrawy_{\text{raw}} 是文本 prompt,情况更复杂。文本不是一个固定维度的类别,而是一个 token sequence,并且语言里有语义、顺序、修饰关系。

现代图像生成模型通常不从零开始学习文本理解,而是使用 frozen pretrained text encoder。例如 CLIP 会把图像和文本放到同一个 embedding space 中训练:匹配的图像-文本对靠近,不匹配的对远离。

因此可以写成:

y=CLIP(yraw)RdCLIP.y=\operatorname{CLIP}(y_{\text{raw}})\in\mathbb R^{d_{\text{CLIP}}}.

但把整个 prompt 压成一个向量有时不够。比如 prompt 里有多个对象、属性、空间关系,一个单向量可能丢失 token-level 的细节。因此很多模型会使用 transformer text encoder,保留一串 token embeddings:

PromptEmbed(yraw)RS×k.\operatorname{PromptEmbed}(y_{\text{raw}}) \in \mathbb R^{S\times k}.

这里 SS 是文本 token 数,kk 是每个 token embedding 的维度。这样图像 token 可以通过 attention 去看 prompt 里的不同 token。

Figure 14 DiT and CLIP

Figure 14 右边展示的就是 CLIP 的思想:图像 encoder 和文本 encoder 被训练到同一个空间里,使匹配的 image-text pair 相似度更高,不匹配的 pair 相似度更低。这个 embedding 后面就可以作为生成模型的条件输入。

6.1.2 Diffusion Transformers

Diffusion Transformer,简称 DiT,可以理解为把 Vision Transformer 的思路放进 diffusion / flow model 里。

一张图像写成:

xRC×H×W.x\in\mathbb R^{C\times H\times W}.

Transformer 擅长处理 token sequence,所以第一步要把图像切成 patch。设 patch size 是 PP,则每个 patch 包含

C=CP2C'=CP^2

个数值。patch 的数量是:

N=HPWP.N=\frac{H}{P}\cdot\frac{W}{P}.

于是 patchify 操作把图像变成:

Patchify(x)RN×C.\operatorname{Patchify}(x)\in\mathbb R^{N\times C'}.

接着用一个可学习矩阵 WRC×dW\in\mathbb R^{C'\times d} 把每个 patch 投影到 transformer hidden dimension:

PatchEmb(x)=Patchify(x)WRN×d.\operatorname{PatchEmb}(x) = \operatorname{Patchify}(x)W \in \mathbb R^{N\times d}.

现在,模型的三个输入都变成了 transformer 能处理的形式:

t~=TimeEmb(t)Rd,\tilde t=\operatorname{TimeEmb}(t)\in\mathbb R^d, y~=PromptEmbed(y)RS×d,\tilde y=\operatorname{PromptEmbed}(y)\in\mathbb R^{S\times d}, x~0=PatchEmb(x)RN×d.\tilde x_0=\operatorname{PatchEmb}(x)\in\mathbb R^{N\times d}.

然后 DiT 用 LL 层 transformer block 反复更新图像 patch tokens:

x~i+1=DiTBlock(x~i,t~,y~),i=0,,L1.\tilde x_{i+1} = \operatorname{DiTBlock}(\tilde x_i,\tilde t,\tilde y), \qquad i=0,\dots,L-1.

最后再把 token sequence 转回图像形状:

u=Depatchify(x~LW~)RC×H×W,u = \operatorname{Depatchify}(\tilde x_L\tilde W) \in \mathbb R^{C\times H\times W},

其中

W~Rd×C.\tilde W\in\mathbb R^{d\times C'}.

这个 uu 就是模型预测的向量场:

u=utθ(xy).u=u_t^\theta(x\mid y).

DiT block 里有三类核心操作。

第一类是 patch self-attention。每个图像 patch 可以看其他 patch,从而建模全局结构。比如左上角的物体边缘可能和右下角的阴影相关,attention 可以直接建立远距离关系。

如果当前 patch tokens 是

xRN×d,x\in\mathbb R^{N\times d},

普通 self-attention 可以写成:

SelfAttn(x)=MultiHeadAttention(x,x).\operatorname{SelfAttn}(x) = \operatorname{MultiHeadAttention}(x,x).

这里两个 xx 的含义是:queries 来自图像 patch,keys 和 values 也来自图像 patch。因此每个 patch 都在图像内部寻找相关信息。

第二类是 cross-attention。图像 patch tokens 作为 queries,文本 tokens 作为 keys 和 values,使每个图像区域可以选择性地关注 prompt 的不同部分。

如果文本 embedding 是

yRS×d,y\in\mathbb R^{S\times d},

cross-attention 可以写成:

CrossAttn(x,y)=MultiHeadAttention(x,y).\operatorname{CrossAttn}(x,y) = \operatorname{MultiHeadAttention}(x,y).

这里 queries 来自图像 patch,keys 和 values 来自文本 tokens。直观地说,每个图像区域都可以根据自己当前的状态去 prompt 里找相关词。prompt 里的 “red car on the beach” 不会被压成一个单一标签,而是以 token sequence 的形式被图像 patches 反复读取。

第三类是 time conditioning。时间 tt 不只是额外输入,而是会调制每一层的归一化或残差强度。这一点是 DiT block 里最重要、也最容易被忽略的部分。

普通 Transformer block 通常会先做 LayerNorm:

LN(x).\operatorname{LN}(x).

LayerNorm 的作用是把每个 token 的 hidden features 归一化,让训练更稳定。可是 diffusion / flow model 有一个额外要求:同一个图像 token 在不同时间点应该被不同地处理。tt 靠近噪声端时,网络更偏向粗结构;tt 靠近数据端时,网络更偏向细节修正。

AdaLN 的想法是:先正常归一化,再让时间 embedding 决定这一层如何缩放和平移归一化后的特征。

(γ,β)=g(t~),(\gamma,\beta)=g(\tilde t), AdaNormt~(x)=(1+γ)Norm(x)+β.\operatorname{AdaNorm}_{\tilde t}(x) = (1+\gamma)\odot\operatorname{Norm}(x)+\beta.

其中 gg 是一个小 MLP,输入是 time embedding:

t~=TimeEmb(t).\tilde t=\operatorname{TimeEmb}(t).

输出的 γ,βRd\gamma,\beta\in\mathbb R^d 会广播到所有 patch tokens 上。也就是说,它们不是改变某一个 patch,而是改变这一层处理全部 patch 的方式。

所以 AdaLN 的物理直觉是:

时间 tt 不是一个被动附加的标签,而是在控制网络每一层的工作模式。

tt 不同时,同一个 block 里的 attention 和 MLP 会接收到不同调制后的输入。于是同一个 DiT block 可以在不同时间段表现出不同功能。

把这些操作放在一起,一个简化的 DiT block 可以写成:

xx+SelfAttn(AdaLNt~(x)),x \leftarrow x+ \operatorname{SelfAttn} \left( \operatorname{AdaLN}_{\tilde t}(x) \right), xx+CrossAttn(AdaLNt~(x),y),x \leftarrow x+ \operatorname{CrossAttn} \left( \operatorname{AdaLN}_{\tilde t}(x),y \right), xx+MLP(AdaLNt~(x)).x \leftarrow x+ \operatorname{MLP} \left( \operatorname{AdaLN}_{\tilde t}(x) \right).

Figure 14 里的 adaLN-Zero 还多了 gate。更接近图里的写法是:

xx+αattn(t~)SelfAttn(AdaLNt~(x)),x \leftarrow x+ \alpha_{\text{attn}}(\tilde t)\odot \operatorname{SelfAttn} \left( \operatorname{AdaLN}_{\tilde t}(x) \right), xx+αmlp(t~)MLP(AdaLNt~(x)).x \leftarrow x+ \alpha_{\text{mlp}}(\tilde t)\odot \operatorname{MLP} \left( \operatorname{AdaLN}_{\tilde t}(x) \right).

这里 αattn(t~)\alpha_{\text{attn}}(\tilde t)αmlp(t~)\alpha_{\text{mlp}}(\tilde t) 也是由 time embedding 产生的门控系数。它们决定这一层的 self-attention 和 MLP 输出到底加多少回残差流里。

“Zero” 通常指这些调制或 gate 在初始化时接近零,使 block 一开始接近恒等映射:

xx.x\leftarrow x.

这对深层 diffusion transformer 很有用。模型刚开始训练时不会让每一层都大幅扰动输入,而是逐渐学会在不同时间点打开哪些层、打开多少。

因此 DiT block 的完整直觉是:

Figure 14 左边展示了 DiT 的整体流程:noised latent 先被 patchify,加入 timestep 和 label conditioning,然后经过多层 DiT block,最后 reshape 回原来的 latent/image 形状。

6.1.3 U-Net

U-Net 是另一类非常重要的 diffusion architecture。它原本来自图像分割,后来在 diffusion models 里被广泛使用。

U-Net 的核心特点是:输入和输出都保持图像形状。对生成模型来说,这正好对应:

xtutθ(xty).x_t\longmapsto u_t^\theta(x_t\mid y).

举一个形状例子。如果输入是:

xtinputR3×256×256,x_t^{\text{input}}\in\mathbb R^{3\times256\times256},

encoder 会逐步降低空间分辨率、增加通道数:

xtlatent=E(xtinput)R512×32×32.x_t^{\text{latent}} = E(x_t^{\text{input}}) \in \mathbb R^{512\times32\times32}.

然后 midcoder 在低分辨率、高通道的表示上处理全局信息:

xtlatent=M(xtlatent)R512×32×32.x_t^{\text{latent}} = M(x_t^{\text{latent}}) \in \mathbb R^{512\times32\times32}.

最后 decoder 再把表示恢复到原始图像大小:

xtoutput=D(xtlatent)R3×256×256.x_t^{\text{output}} = D(x_t^{\text{latent}}) \in \mathbb R^{3\times256\times256}.

这个输出就是:

utθ(xty).u_t^\theta(x_t\mid y).

Figure 15 U-Net

Figure 15 里间的主干就是 U 形结构:左边 encoder 不断下采样,右边 decoder 不断上采样,中间是 midcoder。虚线 residual connections 把 encoder 的中间特征传给对应尺度的 decoder。

这些 skip connections 很关键。原因是:encoder 压缩图像时会丢掉一些局部细节,而 decoder 要恢复像素级输出。如果只靠最底部的 latent,恢复细节会很困难;skip connections 让 decoder 能直接拿到高分辨率特征。

时间 tt 和条件 yy 也会被 embedding 后注入到 U-Net 的不同模块里。图里虚线标出的 t,yt,y 表示这些条件不会只在输入层出现一次,而是会影响网络的多个层级。这样模型才能在不同空间尺度上都知道:现在是哪个 denoising 时间点,以及要生成什么条件下的样本。

从功能上看,DiT 和 U-Net 都是在实现同一个数学对象:

utθ(xy).u_t^\theta(x\mid y).

差别在于它们处理图像结构的方式不同:

第 6.1 节的要点可以收成一句话:

大规模 diffusion / flow model 的网络部分,就是把 x,t,yx,t,y 分别变成合适的表示,再用 U-Net 或 DiT 这样的结构预测与 xx 同形状的 vector field。

6.2 Working in Latent Space: (Variational) Autoencoders

6.1 讨论的是网络结构本身。现在的问题是:即使用了 U-Net 或 DiT,直接在高分辨率像素空间里建模仍然很贵。

一张 RGB 图像如果大小是 1024×10241024\times1024,那么它的维度是:

d=3×1024×10243×106.d=3\times1024\times1024\approx 3\times10^6.

如果是视频,还要再乘上帧数 TT

xRT×C×H×W.x\in\mathbb R^{T\times C\times H\times W}.

flow / diffusion model 和分类模型不同。分类模型可以把图像逐步压缩,最后输出一个很小的类别向量;但生成模型要输出和输入同形状的向量场:

utθ(x)Rd.u_t^\theta(x)\in\mathbb R^d.

所以如果直接在像素空间训练,模型每一步都要处理一个极大的对象。现代图像生成模型的关键做法是:不在原始像素空间生成,而是在一个压缩后的 latent space 里生成。

6.2.1 Standard Autoencoders

autoencoder 的想法很直接:用一个 encoder 把图像压缩成 latent,再用一个 decoder 把 latent 还原成图像。

设原始数据是:

xRd.x\in\mathbb R^d.

encoder 是:

μϕ:RdRk,\mu_\phi:\mathbb R^d\to\mathbb R^k,

decoder 是:

μθ:RkRd,\mu_\theta:\mathbb R^k\to\mathbb R^d,

其中

kd.k\ll d.

编码得到:

z=μϕ(x)Rk.z=\mu_\phi(x)\in\mathbb R^k.

再解码得到:

x^=μθ(z)=μθ(μϕ(x)).\hat x=\mu_\theta(z)=\mu_\theta(\mu_\phi(x)).

训练 autoencoder 的基本目标是 reconstruction loss:

LRecon(ϕ,θ)=Expdata[μθ(μϕ(x))x2].L_{\text{Recon}}(\phi,\theta) = \mathbb E_{x\sim p_{\text{data}}} \left[ \left\| \mu_\theta(\mu_\phi(x))-x \right\|^2 \right].

这个目标只要求“压缩后还能还原”。如果只为了压缩和重建,这已经很自然。

但我们这里不是只想做压缩。我们最终想在 latent space 里训练一个生成模型。也就是说,我们希望:

  1. 把真实图像编码成 latent:
z=μϕ(x),xpdata.z=\mu_\phi(x),\qquad x\sim p_{\text{data}}.
  1. 在 latent space 里学习 latent distribution:
platent(z).p_{\text{latent}}(z).
  1. 采样 latent,再 decode 回图像:
zplatent,x=μθ(z).z\sim p_{\text{latent}}, \qquad x=\mu_\theta(z).

问题在于,普通 autoencoder 没有约束 platent(z)p_{\text{latent}}(z) 的形状。它可能很扭曲、很碎、很难学。换句话说,autoencoder 虽然把像素空间压缩了,但可能把原来的数据分布变成了另一个很难建模的 latent distribution。

因此需要一个额外要求:

latent 不仅要能重建图像,还要形成一个适合生成模型学习的分布。

这就是 VAE 要解决的问题。

6.2.2 Variational Autoencoders

VAE 把 deterministic autoencoder 改成 probabilistic autoencoder。

普通 autoencoder 的 encoder 是一个确定函数:

z=μϕ(x).z=\mu_\phi(x).

VAE 的 encoder 变成一个条件分布:

qϕ(zx).q_\phi(z\mid x).

decoder 也变成一个条件分布:

pθ(xz).p_\theta(x\mid z).

最常见的 Gaussian 设定是:

qϕ(zx)=N(z;μϕ(x),diag(σϕ2(x))),q_\phi(z\mid x) = \mathcal N \left( z; \mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x)) \right), pθ(xz)=N(x;μθ(z),σθ2(z)Id).p_\theta(x\mid z) = \mathcal N \left( x; \mu_\theta(z), \sigma_\theta^2(z)I_d \right).

这时 encoder 不再输出一个 latent 点,而是输出一个 Gaussian distribution 的参数:

μϕ(x),σϕ2(x).\mu_\phi(x),\qquad \sigma_\phi^2(x).

编码时从这个分布里采样:

zqϕ(x).z\sim q_\phi(\cdot\mid x).

解码时也可以理解为从 decoder distribution 采样:

xpθ(z).x\sim p_\theta(\cdot\mid z).

如果方差都退化成 00,VAE 就回到了普通 autoencoder。因此 VAE 不是完全不同的东西,而是 autoencoder 的概率版本。

Reconstruction loss: 现在重建的是概率

普通 autoencoder 用 squared error:

xμθ(μϕ(x))2.\|x-\mu_\theta(\mu_\phi(x))\|^2.

VAE 里 zz 是随机采样的,所以 reconstruction loss 写成 negative log-likelihood:

LVAE-Recon(ϕ,θ)=Expdata,zqϕ(x)[logpθ(xz)].L_{\text{VAE-Recon}}(\phi,\theta) = - \mathbb E_{x\sim p_{\text{data}},\,z\sim q_\phi(\cdot\mid x)} \left[ \log p_\theta(x\mid z) \right].

这句话的意思是:从 xx 编码出一个 latent zz,再问 decoder 给原始 xx 的概率有多高。概率越高,loss 越小。

在 Gaussian decoder 下,代入正态分布密度,可以得到:

LVAE-Recon(ϕ,θ)=Ex,z[12σθ2(z)xμθ(z)2+d2logσθ2(z)]+const.L_{\text{VAE-Recon}}(\phi,\theta) = \mathbb E_{x,z} \left[ \frac{1}{2\sigma_\theta^2(z)} \|x-\mu_\theta(z)\|^2 + \frac d2\log\sigma_\theta^2(z) \right] +\text{const}.

如果把 decoder variance 固定成常数 σ2\sigma^2,这基本就退化成加权 MSE:

LVAE-Recon(ϕ,θ)=Ex,z[12σ2xμθ(z)2]+const.L_{\text{VAE-Recon}}(\phi,\theta) = \mathbb E_{x,z} \left[ \frac{1}{2\sigma^2} \|x-\mu_\theta(z)\|^2 \right] +\text{const}.

所以 VAE 的 reconstruction term 并不神秘。它仍然在要求 decode 后接近原图,只是现在要对 encoder 采样出来的所有可能 latent 做平均。

Prior loss: 让 latent distribution 变得好学

VAE 还会指定一个理想的 latent prior:

pprior(z)=N(0,Ik).p_{\text{prior}}(z)=\mathcal N(0,I_k).

这个 prior 表示:我们希望每个数据点编码出来的 latent distribution 不要乱跑,而是接近标准 Gaussian。于是加入 KL loss:

LVAE-Prior(ϕ)=Expdata[DKL(qϕ(x)pprior)].L_{\text{VAE-Prior}}(\phi) = \mathbb E_{x\sim p_{\text{data}}} \left[ D_{\text{KL}} \left( q_\phi(\cdot\mid x) \| p_{\text{prior}} \right) \right].

KL divergence 衡量两个分布有多不同:

DKL(qp)=q(z)logq(z)p(z)dz.D_{\text{KL}}(q\|p) = \int q(z)\log\frac{q(z)}{p(z)}\,dz.

它满足:

DKL(qp)0,D_{\text{KL}}(q\|p)\ge 0,

并且当且仅当 q=pq=p 时为 00

所以这个 prior loss 的直觉是:

每张图像的 encoder distribution 都应该尽量像 N(0,Ik)\mathcal N(0,I_k),这样整体 latent space 会更规整,更适合后面训练 flow / diffusion model。

最终 VAE 目标是:

LVAE(ϕ,θ)=LVAE-Recon(ϕ,θ)+βLVAE-Prior(ϕ).L_{\text{VAE}}(\phi,\theta) = L_{\text{VAE-Recon}}(\phi,\theta) + \beta L_{\text{VAE-Prior}}(\phi).

也就是:

LVAE(ϕ,θ)=Ex,z[logpθ(xz)]+βEx[DKL(qϕ(x)pprior)].L_{\text{VAE}}(\phi,\theta) = - \mathbb E_{x,z} \left[ \log p_\theta(x\mid z) \right] + \beta \mathbb E_x \left[ D_{\text{KL}} \left( q_\phi(\cdot\mid x) \| p_{\text{prior}} \right) \right].

这里 β\beta 控制两件事的权衡:

Gaussian encoder 下的 KL 项

qϕ(zx)=N(μϕ(x),diag(σϕ2(x))),q_\phi(z\mid x) = \mathcal N \left( \mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x)) \right),

并且

pprior(z)=N(0,Ik),p_{\text{prior}}(z)=\mathcal N(0,I_k),

则 KL 项有闭式解:

DKL(qϕ(x)N(0,Ik))=12j=1k(μϕ,j2(x)+σϕ,j2(x)logσϕ,j2(x)1).D_{\text{KL}} \left( q_\phi(\cdot\mid x) \| \mathcal N(0,I_k) \right) = \frac12 \sum_{j=1}^k \left( \mu_{\phi,j}^2(x) + \sigma_{\phi,j}^2(x) - \log\sigma_{\phi,j}^2(x) - 1 \right).

这四项也很直观:

因此 KL 项把 encoder distribution 往标准 Gaussian 拉。

Reparameterization trick

VAE 训练还有一个技术点:我们需要从

zqϕ(zx)z\sim q_\phi(z\mid x)

里采样,但这个采样分布依赖参数 ϕ\phi。如果直接采样,梯度不好传。

reparameterization trick 把随机性单独拿出来:

ϵN(0,Ik),\epsilon\sim\mathcal N(0,I_k), z=μϕ(x)+σϕ(x)ϵ.z=\mu_\phi(x)+\sigma_\phi(x)\odot\epsilon.

这样仍然有:

zqϕ(x),z\sim q_\phi(\cdot\mid x),

但随机性来自 ϵ\epsilon,它的分布不依赖 ϕ\phi。而 μϕ(x)\mu_\phi(x)σϕ(x)\sigma_\phi(x) 都在计算图里,梯度可以正常反传。

这一步的直觉是:

不要让“从一个依赖参数的分布里采样”挡住梯度;把采样写成“固定噪声经过可微变换”。

Algorithm 6: β-VAE training

Algorithm 6 把前面的 VAE 训练过程写成 mini-batch 算法:

Algorithm 6: beta-VAE training

每个 batch 中,对每张图像 xix_i

  1. encoder 输出:
μi=μϕ(xi),logσi2=logσϕ2(xi).\mu_i=\mu_\phi(x_i), \qquad \log\sigma_i^2=\log\sigma_\phi^2(x_i).
  1. 采样标准 Gaussian noise:
ϵiN(0,Ik).\epsilon_i\sim\mathcal N(0,I_k).
  1. 用 reparameterization trick 得到 latent:
zi=μi+σiϵi.z_i=\mu_i+\sigma_i\odot\epsilon_i.

其中:

σi=exp(12logσi2).\sigma_i=\exp\left(\frac12\log\sigma_i^2\right).
  1. decoder 输出 reconstruction mean:
x^i=μθ(zi).\hat x_i=\mu_\theta(z_i).
  1. reconstruction loss:
Lrecon=1Bi=1B12σ~2xix^i2.L_{\text{recon}} = \frac1B \sum_{i=1}^B \frac{1}{2\tilde\sigma^2} \|x_i-\hat x_i\|^2.
  1. KL loss:
LKL=1Bi=1B12j=1k(μi,j2+σi,j2logσi,j21).L_{\text{KL}} = \frac1B \sum_{i=1}^B \frac12 \sum_{j=1}^k \left( \mu_{i,j}^2 + \sigma_{i,j}^2 - \log\sigma_{i,j}^2 - 1 \right).
  1. total loss:
L=Lrecon+βLKL.L=L_{\text{recon}}+\beta L_{\text{KL}}.

然后对 encoder 参数 ϕ\phi 和 decoder 参数 θ\theta 做梯度更新。

为什么这对 diffusion / flow 有用

VAE 训练完成后,生成模型不再直接学习像素空间分布:

pdata(x).p_{\text{data}}(x).

而是学习 latent space 里的分布:

platent(z),zqϕ(zx),xpdata.p_{\text{latent}}(z), \qquad z\sim q_\phi(z\mid x), \quad x\sim p_{\text{data}}.

训练 latent flow / diffusion 时,把 latent zz 当成数据点来用:

zRk.z\in\mathbb R^k.

模型学会从简单分布生成 latent:

Z1platent.Z_1\sim p_{\text{latent}}.

最后再用 decoder 回到图像:

x=μθ(Z1).x=\mu_\theta(Z_1).

这就是 latent diffusion / latent flow 的基本结构:

noiselatent generative modelzdecoderx.\text{noise} \longrightarrow \text{latent generative model} \longrightarrow z \longrightarrow \text{decoder} \longrightarrow x.

本节的核心可以收成一句话:

VAE 先把高维图像压缩到规整的 latent space,flow / diffusion model 再在这个更小、更好学的 latent space 里做生成。

6.3 Case Study: Stable Diffusion 3 and Meta Movie Gen

6.3 不再引入新的数学框架,而是把前面几节的部件拼起来,看真实大模型大概怎样落地。

到这里,图像/视频生成模型的主要积木已经有了:

  1. 用 flow matching 或 diffusion 学一个生成过程;
  2. 用 CFG 强化 prompt condition;
  3. 用 text encoder 把 prompt 变成 embedding;
  4. 用 VAE / autoencoder 把像素压到 latent space;
  5. 用 DiT 或 U-Net 参数化
utθ(xy).u_t^\theta(x\mid y).

Stable Diffusion 3 和 Movie Gen Video 的共同点是:它们都不是只靠一个单独技巧,而是把这些部件规模化组合起来。

6.3.1 Stable Diffusion 3

Stable Diffusion 3 是一个 text-to-image model。它的目标可以写成:

xpdata(y),x\sim p_{\text{data}}(\cdot\mid y),

其中 yy 是文本 prompt,xx 是图像。

但它并不直接在像素空间里建模,而是在 autoencoder 的 latent space 里做生成。也就是说,真实图像先经过一个 pretrained autoencoder encoder:

ximagezlatent.x_{\text{image}} \longmapsto z_{\text{latent}}.

生成模型学的是 latent 上的条件分布:

platent(zy).p_{\text{latent}}(z\mid y).

采样结束后,再用 decoder 得到图像:

zximage.z \longmapsto x_{\text{image}}.

因此 Stable Diffusion 3 里真正的 vector field 更准确地说是在 latent space 上:

utθ(zty).u_t^\theta(z_t\mid y).

Stable Diffusion 3 使用的是 conditional flow matching 目标。也就是我们前面学过的逻辑:构造从简单分布到数据分布的 probability path,然后训练网络匹配对应的 velocity field。

它还使用 classifier-free guidance。训练时随机丢掉 prompt condition,使同一个模型同时学会:

utθ(zty)u_t^\theta(z_t\mid y)

utθ(zt).u_t^\theta(z_t\mid \emptyset).

采样时再组合:

u~tθ(zty)=utθ(zt)+w(utθ(zty)utθ(zt)).\tilde u_t^\theta(z_t\mid y) = u_t^\theta(z_t\mid \emptyset) + w \left( u_t^\theta(z_t\mid y) - u_t^\theta(z_t\mid \emptyset) \right).

这里 ww 是 guidance scale。SD3 采样时通常用 Euler simulation,约 50 个 steps,CFG weight 大约在 2.02.05.05.0 之间。

文本条件:为什么要多个 text encoders

Stable Diffusion 3 的一个重要设计是使用多个文本 embedding。它使用三类 text embeddings,包括 CLIP embeddings 和 Google T5-XXL encoder 的 sequential outputs。

可以这样理解:

CLIP embedding 擅长提供整体语义。比如 prompt 是:

a corgi dog wearing sunglasses on the beach

CLIP 会给出一个比较全局的 text-image semantic representation。它告诉模型:这句话整体上描述的是怎样的图像。

T5 这类 transformer text encoder 则更擅长保留 token-level 信息。它不会只给出一个整体向量,而是输出一串 token embeddings:

y=PromptEmbed(yraw)RS×d.y=\operatorname{PromptEmbed}(y_{\text{raw}}) \in \mathbb R^{S\times d}.

这对复杂 prompt 很重要。因为图像里不同区域可能需要关注不同词:

这也是为什么 MM-DiT 需要让 image tokens 和 text tokens 更细粒度地交互。

MM-DiT: Multi-Modal Diffusion Transformer

你给的 Figure 16 对应 MM-DiT 架构:

Figure 16: MM-DiT architecture

图左边是整体流程,右边是一个 MM-DiT block 的结构。

整体输入有三类:

  1. noised latent:
zt;z_t;
  1. timestep:
t;t;
  1. caption / prompt:
yraw.y_{\text{raw}}.

noised latent 先被 patching:

ztz~tRN×d.z_t \longmapsto \tilde z_t\in\mathbb R^{N\times d}.

文本 prompt 经过多个 text encoders,例如 CLIP-G/14、CLIP-L/14、T5-XXL,得到 text tokens:

y~RS×d.\tilde y\in\mathbb R^{S\times d}.

timestep 经过 sinusoidal encoding 和 MLP,得到 time embedding:

t~=TimeEmb(t).\tilde t=\operatorname{TimeEmb}(t).

然后这些信息进入一系列 MM-DiT blocks。

普通 DiT 可以理解为主要更新 image patch tokens;MM-DiT 的关键变化是:文本 token 和图像 token 都进入 transformer 交互。这样 prompt 不是作为一个外部条件轻轻加一下,而是参与每一层的多模态 attention。

一种直观写法是,把 image tokens 和 text tokens 看成两个 token streams:

ziRN×d,yiRS×d.z_i\in\mathbb R^{N\times d}, \qquad y_i\in\mathbb R^{S\times d}.

每一层做联合更新:

(zi+1,yi+1)=MM-DiTBlock(zi,yi,t~).(z_{i+1},y_{i+1}) = \operatorname{MM\text{-}DiTBlock}(z_i,y_i,\tilde t).

这样图像 tokens 可以读文本,文本 tokens 也可以被当前生成状态影响。最后只取 image stream,经过 modulation、linear、unpatching,得到和 latent 同形状的输出:

utθ(zty).u_t^\theta(z_t\mid y).

Figure 16 右边的 block 里,LayerNorm 后面有 Mod,也就是前面 6.1 讲过的 time-conditioned modulation。这里的时间 embedding 仍然在控制每一层如何处理 tokens。

因此 SD3 的结构可以概括为:

text promptCLIP/T5text tokens,\text{text prompt} \xrightarrow{\text{CLIP/T5}} \text{text tokens}, noised latentpatchingimage tokens,\text{noised latent} \xrightarrow{\text{patching}} \text{image tokens}, (text tokens,image tokens,t)MM-DiTutθ(zty).(\text{text tokens},\text{image tokens},t) \xrightarrow{\text{MM-DiT}} u_t^\theta(z_t\mid y).

Stable Diffusion 3 最大模型有约 8B parameters。这个数字本身不是核心,核心是:当模型规模上去以后,architecture 必须能同时处理高维 latent、长 prompt、多模态条件和时间调制。

6.3.2 Meta Movie Gen Video

Movie Gen Video 是 text-to-video model。它和 Stable Diffusion 3 的框架很像,但数据从图像变成视频。

图像 latent 可以写成:

zRC×H×W.z\in\mathbb R^{C\times H\times W}.

视频多了时间维度:

xRT×C×H×W.x\in\mathbb R^{T\times C\times H\times W}.

这里 TT 是帧数。这个维度让问题明显更贵,因为模型不仅要生成每一帧,还要保持帧与帧之间的时间一致性。

Movie Gen Video 同样使用 conditional flow matching,并采用 straight-line Gaussian path:

αt=t,σt=1t.\alpha_t=t, \qquad \sigma_t=1-t.

也就是我们前面熟悉的:

xt=tz+(1t)ϵ.x_t=t z+(1-t)\epsilon.

对应 target velocity 是:

uttarget(xtz)=zϵ.u_t^{\text{target}}(x_t\mid z) = z-\epsilon.

不过这里的 zz 不是单张图像,而是视频 latent。

视频为什么更需要 latent space

视频如果直接在像素空间建模,维度是:

T×C×H×W.T\times C\times H\times W.

比图像还要大很多。因此 Movie Gen Video 使用 frozen pretrained temporal autoencoder。它把原始视频:

xRT×3×H×Wx'\in\mathbb R^{T'\times 3\times H'\times W'}

压缩成 latent video:

xRT×C×H×W.x\in\mathbb R^{T\times C\times H\times W}.

压缩比例大致满足:

TT=HH=WW=8.\frac{T'}{T} = \frac{H'}{H} = \frac{W'}{W} = 8.

这表示 autoencoder 不只在空间上压缩 H,WH,W,也在时间维度上压缩 TT。这就是 temporal autoencoder 的意义:它不是逐帧压缩图片,而是把视频作为带时间结构的对象来压缩。

对于长视频,还可以使用 temporal tiling:把视频切成多个时间片段,分别 encode,再把 latent 拼起来。这样可以降低显存压力。

Movie Gen 的 DiT-like backbone

Movie Gen 的生成模型仍然是 DiT-like backbone。但这里的 patch 不只是二维图像 patch,而是同时沿时间和空间 patchify。

可以把视频 latent 看成一个三维网格:

T×H×W.T\times H\times W.

patchify 后得到一串 spatio-temporal tokens。每个 token 对应某一段时间和某一块空间区域。

Transformer 需要建模两类关系:

  1. 空间关系:同一帧内不同区域如何协调;
  2. 时间关系:不同帧之间如何保持运动连续。

文本 conditioning 也类似 SD3,通过 text embeddings 和 cross-attention 注入。Movie Gen 使用 UL2、ByT5、MetaCLIP 三类 embeddings:

最大 Movie Gen Video 模型约 30B parameters。这里的规模比 SD3 更大,并不奇怪,因为视频生成比图像生成多了时间维度,需要模型同时解决画面质量和运动一致性。

6.3 的总结构

Stable Diffusion 3 和 Movie Gen Video 都可以放回同一个框架里:

raw dataautoencoderlatent data,\text{raw data} \xrightarrow{\text{autoencoder}} \text{latent data}, prompttext encoderstext embeddings,\text{prompt} \xrightarrow{\text{text encoders}} \text{text embeddings}, (noised latent,t,text embeddings)DiT / MM-DiTutθ(y),(\text{noised latent},t,\text{text embeddings}) \xrightarrow{\text{DiT / MM-DiT}} u_t^\theta(\cdot\mid y), simulate ODE/SDEgenerated latentdecoderimage or video.\text{simulate ODE/SDE} \longrightarrow \text{generated latent} \xrightarrow{\text{decoder}} \text{image or video}.

第 6 节的总要点是:大规模生成模型并没有换掉前面学过的 flow matching / diffusion 原理;它们主要是在三个地方做工程和架构放大:

因此可以把这一节收成一句话:

Stable Diffusion 3 和 Movie Gen Video 是前面数学框架的规模化实现:在 latent space 里,用带时间调制和文本条件的 transformer 学习 conditional vector field。


Edit page
Share this post on:

Next Post
从 Transformer 到 Decision Transformer:VLA 前置知识一文梳理