Flow Matching for Generative Modeling
基本信息
标题:
作者:
机构:
时间:
预印时间: 20??.??.?? ArXiv v1
更新笔记: 20??.??.??
发表:
链接:
标签:
页数: ?
引用: ?
被引: ?
数据:
对比:
复现:
设 $\mathbb{R}^d$ 为数据空间, 数据点 $x=(x^1,\cdots,x^d)\in\mathbb{R}^d$ .
通篇需要两个重要的概念:
概率密度路径 (Probability Density Path) $p=p_t(x): [0,1]\times \mathbb{R}^d\to \mathbb{R}_{>0}$ , 是依赖时间的概率密度函数, 即 $\int p_t(x)\text{d}x = 1$ ;
依赖时间的矢量场 (Time-Dependent Vector Field) $v=v_t: [0,1]\times \mathbb{R}^d\to \mathbb{R}^d$ .
矢量场 $v_t$ 可以用于构造依赖时间的微分同胚映射 (Diffeomprphic Map) , 称为流 (Flow) $\phi: [0,1]\times \mathbb{R}^d\to \mathbb{R}^d$ , 通过常微分方程 ODE 定义如下:
$$
\dfrac{\text{d}}{\text{d} t}\phi_t(x) = v_t(\phi_t(x)),\quad \phi_0(x)=x.
$$
在 NeurIPS 2018 发表的论文 "Neural Ordinary Differential Equations" 中, 使用了神经网络来建模矢量场 $v_t(x;\theta)\approx v_t$ , 其中 $\theta\in \mathbb{R}^p$ 是可学习的参数, 这反而得到了一种流 $\phi_t$ 的深度参数模型, 称为连续标准化流 (Continuous Normalizing Flow, CNF) . CNF 用于将一个简单的先验分布密度 $p_0$ (如纯噪声) 通过前推方程 (Push-Forward Equation) 重塑为一个更复杂的分布密度 $p_1$ .
$$
p_t = [\phi_t]_{*}p_0
$$
其中前推或变量替换算子 $*$ 定义为
$$
[\phi_t]_{*}p_0(x) = p_0(\phi^{-1}_t(x))\det \left[ \dfrac{\partial \phi_t^{-1}}{\partial x}(x)\right].
$$
当流 $\phi_t$ 满足前推方程 时, 则称一个矢量场 $v_t$ 生成一个概率密度路径 $p_t$ .
一个测试矢量场是否生成概率路径的实用方法是使用连续性方程 (Continuity Equation) , 这也是证明过程中的一个关键部分.
连续性方程 是一个偏微分方程, 是矢量场 $v_t$ 生成概率路径 $p_t$ 的充分必要条件,
$$
\dfrac{\text{d}}{\text{d} t}p_t(x)+\text{div}(p_t(x)v_t(x)) = 0.
$$
此外, 对 CNF 的一些信息翻新, 特别是如何计算在任一点 $x\in\mathbb{R}^d$ 处的概率 $p_1(x)$ .
随机变量 $x_1$ 服从某些未知数据分布 $q(x_1)$ . 我们只能使用从 $q(x_1)$ 中采样得到的数据样本, 而不能直接使用密度函数本身.
我们设 $p_t$ 为概率路径, 其初值 $p_0=p$ 是一个简单分布, 例如标准正态分布 $p(x)=\mathcal{N}(x|0,I)$ , 而 $p_1$ 近似等价于分布 $q$ .
后续将讨论如何构造这样的路径.
Flow Matching 的目标是匹配这样的目标概率路径, 使得我们能够从 $p_0$ 流到 $p_1$ .
给定一个目标概率密度路径 $p_t(x)$ 和生成该路径 $p_t(x)$ 的相应矢量场 $u_t(x)$ .
我们定义 Flow Matching 目标函数为:
$$
\mathcal{L}{FM}(\theta) = \mathbb{E} {t,p_t(x)}| v_t(x)-u_t(x) |^2.
$$
其中 $\theta$ 是 CNF 矢量场 $v_t$ 的可学习参数, $t\sim \mathcal{U}[0,1]$ , $x\sim p_t(x)$ .
该损失函数使用神经网络 $v_t$ 对矢量场 $u_t$ 进行回归. 当损失值达到零, 学习到的 CNF 模型将能够生成 $p_t(x)$ .
虽然 Flow Matching 是一种简单且有吸引力的目标函数, 但是实际上难以处理. 这是因为我们缺乏 $p_t$ 和 $u_t$ 的先验知识, 且可能存在很多种可能的概率路径满足 $p_1(x)\approx q(x)$ . 更重要的是, 我们通常无法获得能够生成我们想要的 $p_t$ 的 $u_t$ 的具体形式.
解决方案是我们可以使用仅在每个样本上定义的概率路径和矢量场, 再结合合适的聚合方法来构造我们想要的 $p_t$ 和 $u_t$ .
除此之外, 这一构造允许我们构建出更好处理的 Flow Matching 目标函数.
构造目标概率路径的一种简单方式是通过更简单的概率路径的混合:
给定一个特定点 $x_1$ , 我们定义条件概率路径 (Conditional Probability Path) $p_t(x|x_1)$ 使得 $t=0$ 时 $p_0(x|x_1)=p(x)$ , 在 $t=1$ 时 $p_1(x|x_1)$ 为以 $x=x_1$ 为中心的分布, 如正态分布 $p_1(x|x_1)=\mathcal{N}(x|x_1,\sigma^2I)$ .
在 $q(x_1)$ 上对条件概率路径边际化得到边际概率路径 (Marginal Probability Path) $p_t(x)$ , 即
$$
p_t(x) = \int p_t(x|x_1)q(x_1)\text{d}x_1.
$$
当 $t=1$ 时, 边际分布 $p_1$ 是个近似数据分布 $q$ 的混合分布,
$$
p_1(x) = \int p_1(x|x_1)q(x_1)\text{d}x_1 \approx q(x)
$$
类似地, 可以定义边际矢量场 (Marginal Vector Field) , 通过在如下意义下对条件矢量场进行边际化:
$$
u_t(x) = \int u_t(x|x_1) \dfrac{p_t(x|x_1)q(x_1)}{p_t(x)}\text{d}x_1.
$$
其中 $u_t(\cdot|x_1): \mathbb{R}^d\to \mathbb{R}^d$ 是条件矢量场, 能生成条件概率路径 $p_t(\cdot|x_1)$ .
这可能不明显, 但这种聚合条件矢量场的方式切实地生成了正确的矢量场, 用于建模边际概率路径.
一个关键的观察是: 边际矢量场生成边际概率路径 .
这一观察表明了条件矢量场 (可生成条件概率路径) 和边际矢量场 (可生成边际概率路径) 之间存在的联系.
这一连续允许我们将未知且难以处理的边际矢量场拆分为更简单的条件矢量场, 这些条件矢量场仅依赖于单个数据样本进行定义.
下面给出更正式的定理:
定理 1. 给定能生成条件概率路径 $p_t(x|x_1)$ 的矢量场 $u_t(x|x_1)$ , 对于任意分布 $q(x_1)$ , 边际矢量场 $u_t$ 可以生成边际概率路径 $p_t$ , 即 $u_t$ 和 $p_t$ 满足连续性方程.
这一定理还可以由 Peluchetti (2021) 中的扩散混合表示定理 (Diffusion Mixture Representation Theorem) 得到, 提供了在扩散随机微分方程中的边际漂移系数和扩散系数.
证明 : 下面验证 $u_t$ 和 $p_t$ 满足连续性方程.
$$
\begin{aligned}
\dfrac{\text{d}}{\text{d} t} p_t(x)
&= \dfrac{\text{d}}{\text{d} t} \int p_t(x|x_1)q(x_1)\text{d}x_1\\
&= \int \dfrac{\text{d} p_t(x|x_1)}{\text{d} t}q(x_1)\text{d}x_1\\
&= \int -\text{div}(p_t(x|x_1)u_t(x|x_1)) q(x_1)\text{d}x_1\\
&= -\text{div}(\int p_t(x|x_1)u_t(x|x_1)q(x_1)\text{d}x_1)\\
&= -\text{div}(\int u_t(x|x_1)\dfrac{p_t(x|x_1) q(x_1)}{p_t(x)} p_t(x)\text{d}x_1)\\
&= -\text{div}(u_t(x)p_t(x))\\
\end{aligned}
$$
中间使用了条件矢量场 $u_t(\cdot|x_1)$ 生成条件概率路径 $p_t(\cdot|x_1)$ , 即它们满足连续性方程.
最后代入 $u_t(x)$ 的定义, 得到结论.
注: 求导和积分可交换 (假设被积函数满足 Leibniz Rule 的正则化条件).
不幸的是, 由于边际概率路径和矢量场的定义中存在难以处理的积分, 所以仍然难以计算 $u_t$ , 从而无法直接计算原始 Flow Matching 目标函数的无偏估计.
因此, 我们使用更简单的目标函数, 可以获得和原始目标函数相同的最优条件.
具体地, 我们考虑 条件流匹配 (Conditional Flow Matching, CFM) 目标函数, 它定义为:
$$
\mathcal{L}{CFM}(\theta) = \mathbb{E} {t,q(x_1),p_t(x|x_1)} | v_t(x) - u_t(x|x_1)|^2.
$$
其中 $t\sim \mathcal{U}[0,1]$ , $x_1\sim q(x_1)$ , $x\sim p_t(x|x_1)$ .
和 FM 目标函数不同, CFM 目标函数允许我们简单地采样无偏估计, 这只需要我们能够从 $p_t(x|x_1)$ 中高效采样, 并计算 $u_t(x|x_1)$ . 这两点都很容易做到, 因为它们定义在单个样本的基上.
由此我们得到第二个关键的观察: 流匹配目标函数和条件流匹配目标函数具有相同的参数梯度 .
即优化 CFM 目标函数在期望上等价于优化 FM 目标函数.
因此, 这允许我们训练一个 CNF 用于生成边际概率路径 $p_t$ , 在 $t=1$ 时近似未知数据分布 $q$ , 且无需边际概率路径或边际矢量场.
我们只需要设计合适的条件概率路径和矢量场.
下面给出更正式的定理.
定理 2. 假设对于 $x\in \mathbb{R}^d$ 和 $t\in [0,1]$ 有 $p_t(x)>0$ , $\mathcal{L}{CFM}$ 和 $\mathcal{L} {FM}$ 最多相差一个与 $\theta$ 无关的常数, 即 $\nabla_\theta \mathcal{L}{CFM}(\theta) = \nabla \theta \mathcal{L}_{FM}(\theta)$.
证明 : 为了确保所有积分的存在性, 且允许积分顺序交换 (Fubini's Theorem), 需要假设 $q(x)$ 和 $p_t(x|x_1)$ 随着 $|x| \to\infty$ 以足够的速度下降到零, 且 $u_t$ , $v_t$ 和 $\nabla_\theta v_t$ 都是有界的.
首先使用二范数的标准双线性, 我们有
$$
| v_t(x) - u_t(x) |^2 = |v_t(x)|^2 - 2\langle v_t(x), u_t(x) \rangle + |u_t(x)|^2.
$$
$$
| v_t(x) - u_t(x|x_1) |^2 = |v_t(x)|^2 - 2\langle v_t(x), u_t(x|x_1) \rangle + |u_t(x|x_1)|^2.
$$
$u_t$ 和 $\theta$ 无关, 且注意到
计算第一项:
$$
\begin{aligned}
\mathbb{E}{p_t(x)}| v_t(x) |^2
&= \int |v_t(x)|^2 p_t(x)\text{d}x \
&= \int |v_t(x)|^2 \int p_t(x|x_1)q(x_1)\text{d}x_1 \text{d}x\
&= \mathbb{E} {q(x_1), p_t(x|x_1)} |v_t(x)|^2
\end{aligned}
$$
然后计算中间项:
$$
\begin{aligned}
\mathbb{E}{p_t(x)}\langle v_t(x), u_t(x)\rangle
&= \int \langle v_t(x), u_t(x) \rangle p_t(x)\text{d}x \
&= \int \langle v_t(x), \int u_t(x|x_1) \dfrac{p_t(x|x_1)q(x_1)}{p_t(x)}\text{d}x_1 \rangle p_t(x)\text{d}x \
&= \int \langle v_t(x), \int u_t(x|x_1) p_t(x|x_1)q(x_1) \text{d}x_1 \rangle\text{d}x \
&= \int\int \langle v_t(x), u_t(x|x_1)\rangle p_t(x|x_1)q(x_1) \text{d}x_1 \text{d}x \
&= \mathbb{E} {q(x_1), p_t(x|x_1)}\langle v_t(x), u_t(x|x_1)\rangle
\end{aligned}
$$
可以得到两个损失函数只差了一个和 $\theta$ 无关的常数, 所以梯度相同.
CFM 目标损失可以应用于任意的条件概率路径和条件矢量场.
这部分讨论基于高斯条件概率路径构造相应的 $p_t(x|x_1)$ 和 $u_t(x|x_1)$ .
具体地, 我们考虑如下形式的条件概率路径:
$$
p_t(x|x_1) = \mathcal{N}(x|\mu_t(x_1),\sigma_t(x_1)^2 I)
$$
其中 $\mu_t$ 是高斯分布的依赖于时间的均值, $\sigma_t$ 是高斯分布的依赖于时间的标量标准差.
当 $t=0$ 时均值为 0, 标准差为 1, 因此所有条件概率路径将在 $t=0$ 时收敛到相同的标准高斯噪声分布 $p(x)=\mathcal{N}(x|0,I)$ .
当 $t=1$ 时均值为 $x_1$ , 标准差为足够小的 $\sigma_{\min}$ , 使得 $p_1(x|x_1)$ 是中心在 $x_1$ 的中心化高斯分布.
存在无限个矢量场能够生成任意特定的概率路径 (例如通过向连续性方程添加散度无关的项), 然而这些向量场的绝大多数是由于存在使得底层分布不变的分量存在. 例如当分布是旋转不变的旋转分量, 这导致了不必要的额外计算.
因此我们决定使用最简单的矢量场, 对应高斯分布的典型变换.
具体地, 考虑以 $x_1$ 为条件的流:
$$
\psi_t(x) = \sigma_t(x_1) x +\mu_t(x_1)
$$
其中 $x$ 服从标准高斯分布, $\psi_t(x)$ 是仿射变换, 映射到具有均值 $\mu_t(x_1)$ 和标准差 $\sigma_t(x_1)$ 的高斯分布. 也就是说, 根据前推方程, $\psi_t$ 将噪声分布 $p_0(x|x_1)=p(x)$ 推向 $p_t(x|x_1)$ , 即
$$
[\psi_t]_{*}p(x) = p_t(x|x_1)
$$
这个流提供了一个矢量场用于生成条件概率路径:
$$
\dfrac{\text{d}}{\text{d}t}\psi_t (x) = u_t(\psi_t(x)|x_1)
$$
将 $p_t(x|x_1)$ 重参数化, 然后将上式插入到 CFM 损失, 得到
$$
\mathcal{L}{CFM}(\theta) =\mathbb{E} {t,q(x_1),p(x_0)} | v_t(\psi_t(x_0)) -\dfrac{\text{d}}{\text{d} t}\psi_t(x_0)|^2.
$$
因为 $\psi_t$ 是简单的可逆仿射映射, 所以可以使用上述的微分方程用于求解 $u_t$ 的解析形式.
用 $f'$ 表示时间依赖函数 $f$ 关于时间的导数, 即 $f'=\dfrac{\text{d} f}{\text{d}t}$ .
定理 3. $p_t(x|x_1)$ 为高斯概率路径 $\mathcal{N}(x|\mu_t(x_1),\sigma_t(x_1)^2 I)$ , 对应的流映射为 $\psi_t=\sigma_t(x_1) x +\mu_t(x_1)$ . 存在唯一的矢量场能够定义这个流 $\psi_t$ , 其形式为:
$$
u_t(x|x_1) = \dfrac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x-\mu_t(x_1)) + \mu'_t(x_1)
$$
$u_t(x|x_1)$ 生成高斯概率路径 $p_t(x|x_1)$ .
我们的形式化对于任意的函数 $\mu_t(x_1)$ 和 $\sigma_t(x_1)$ 都适用, 我们可以将它们设置为任何满足期望边界条件的可微函数. 首先讨论一些特殊情形.
扩散模型从数据点开始逐渐添加噪声直到它近似纯噪声. 这可以形式化为随机过程, 它有着严格的要求以获得在任意时间的封闭形式表示, 得到具有特定均值和标准差的高斯条件概率路径 $p_t(x|x_1)$ .
例如反向 (噪声到数据) Variance Exploding, VE 路径具有如下形式
$$
p_t(x|x_1) = \mathcal{N}(x|x_1, \sigma_{1-t}^2 I)
$$
其中 $\sigma_t$ 是一个递增函数, 初值为 0, 终值远大于 1.
那么上式提供了均值为 $\mu_t(x_1)=x_1$ 和 $\sigma_t(x_1)=\sigma_{1-t}$ 的选择.
将这两个值代入到定理 3 的 $u_t$ 中, 得到
$$
u_t(x|x_1) = -\dfrac{\sigma'{1-t}}{\sigma {1-t}} (x-x_1)
$$
反向 (噪声到数据) Variance Preserving, VP 扩散路径具有如下形式
$$
p_t(x|x_1) = \mathcal{N}(x|\alpha_{1-t} x_1, (1-\alpha_{1-t}^2)I)
$$
其中 $\alpha_t = \exp(-\dfrac{1}{2}T(t))$ , $T(t)=\int_0^t \beta(s)\text{d}s$ , $\beta$ 是噪声缩放函数.
那么上式提供了均值为 $\mu_t(x_1)=\alpha_{1-t}x_1$ 和 $\sigma_t(x_1)=\sqrt(1-\alpha_{1-t}^2)$ .
将这两个值代入到定理 3 的 $u_t$ 中, 得到
$$
u_t(x|x_1) = -\dfrac{\alpha'{1-t}}{1-\alpha {1-t}^2} (\alpha_{1-t}x-x_1) = \dfrac{-T(1-t)}{2}\left[\dfrac{\exp(-T(1-t))x-\exp(-\dfrac{1}{2}T(1-t))x_1}{1-\exp(-T(1-t))}\right]
$$
我们构造的条件矢量场 $u_t(x|x_1)$ 和之前在确定性概率流中使用的矢量场一致.
但是将扩散条件矢量场和 Flow Matching 目标函数结合能提供更好的训练替代, 相比现有的分数匹配方法, 在实验中更稳定和稳健.