Deep Learning (Generative Models): Classifier Guidance Diffusion




Contents

  • Preface
  • Problem Formulation
  • Forward Process of the Conditional Diffusion Model
  • Reverse Process of the Conditional Diffusion Model
  • Training Objective of the Conditional Diffusion Model

Preface

Almost every family of generative models eventually needs a notion of "control": only controllable generators are truly useful in practice. This post summarizes Classifier Guidance Diffusion (i.e., the conditional diffusion model) proposed in "Diffusion Models Beat GANs on Image Synthesis", which introduces control into diffusion models so that DDPM and DDIM can be steered to generate images of a specified class (condition).

Problem Formulation

All notation in this section follows DDPM. Under a condition $y$, the forward and reverse processes of the diffusion model are defined as

$$\hat q(x_{t+1}\mid x_t,y),\qquad \hat q(x_t\mid x_{t+1},y)$$

Once these two probability densities are obtained, we can generate images under a given condition.

Throughout, $\hat q$ denotes the density of the conditional diffusion model and $q$ that of the ordinary diffusion model.

Forward Process of the Conditional Diffusion Model

For the forward process, the authors define

$$\begin{aligned} \hat q(x_0)&=q(x_0)\\ \hat q(x_{t+1}\mid x_t,y)&=q(x_{t+1}\mid x_t)\\ \hat q(x_{1:T}\mid x_0,y)&=\prod_{t=1}^T\hat q(x_t\mid x_{t-1},y) \end{aligned}$$

By the second line of this definition, the forward process of the conditional diffusion model is identical to that of an ordinary diffusion model, i.e. $\hat q(x_{t+1}\mid x_t)=q(x_{t+1}\mid x_t)$: the noising process is independent of the condition $y$, which is a sensible choice.
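As a concrete illustration, here is a minimal NumPy sketch of the closed-form forward sample $x_t=\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon$; the linear schedule and toy shapes are illustrative assumptions, and the point is simply that the label $y$ never enters:

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)    # a common linear DDPM noise schedule
alpha_bar = np.cumprod(1.0 - betas)   # \bar{\alpha}_t

def q_sample(x0, t):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(abar_t)*x0, (1-abar_t)*I).

    Note the condition y appears nowhere: the forward process is unconditional.
    """
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

x0 = rng.standard_normal((3, 32, 32))  # a toy "image"
xt = q_sample(x0, t=500)
```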

Reverse Process of the Conditional Diffusion Model

For the reverse process, we have

$$\begin{aligned} \hat q(x_t\mid x_{t+1},y)&=\frac{\hat q(x_t,x_{t+1},y)}{\hat q(x_{t+1},y)}\\ &=\frac{\hat q(x_t,x_{t+1},y)}{\hat q(y\mid x_{t+1})\,\hat q(x_{t+1})}\\ &=\frac{\hat q(x_t,y\mid x_{t+1})}{\hat q(y\mid x_{t+1})}\\ &=\frac{\hat q(y\mid x_t,x_{t+1})\,\hat q(x_t\mid x_{t+1})}{\hat q(y\mid x_{t+1})} \end{aligned}\tag{1.0}$$

Since the forward process of the conditional diffusion model coincides with that of the ordinary diffusion model, we have

$$\hat q(x_{1:T}\mid x_0)=q(x_{1:T}\mid x_0)$$

and therefore
$$\begin{aligned} \hat q(x_t)&=\int \hat q(x_0,\dots,x_t)\,dx_{0:t-1}\\ &=\int \hat q(x_0)\,\hat q(x_{1:t}\mid x_0)\,dx_{0:t-1}\\ &=\int q(x_0)\,q(x_{1:t}\mid x_0)\,dx_{0:t-1}\\ &=q(x_t) \end{aligned}$$

For $\hat q(x_t\mid x_{t+1})$, we then have

$$\begin{aligned} \hat q(x_t\mid x_{t+1})&=\frac{\hat q(x_t,x_{t+1})}{\hat q(x_{t+1})}\\ &=\frac{\hat q(x_{t+1}\mid x_t)\,\hat q(x_t)}{\hat q(x_{t+1})}\\ &=\frac{q(x_{t+1}\mid x_t)\,q(x_t)}{q(x_{t+1})}\\ &=q(x_t\mid x_{t+1}) \end{aligned}$$

For $\hat q(y\mid x_t,x_{t+1})$, we have

$$\begin{aligned} \hat q(y\mid x_t,x_{t+1})&=\frac{\hat q(x_{t+1}\mid x_t,y)\,\hat q(y\mid x_t)}{\hat q(x_{t+1}\mid x_t)}\\ &=\frac{\hat q(x_{t+1}\mid x_t)\,\hat q(y\mid x_t)}{\hat q(x_{t+1}\mid x_t)}\\ &=\hat q(y\mid x_t) \end{aligned}$$

Hence Eq. 1.0 becomes

$$\begin{aligned} \hat q(x_t\mid x_{t+1},y)&=\frac{\hat q(y\mid x_t,x_{t+1})\,\hat q(x_t\mid x_{t+1})}{\hat q(y\mid x_{t+1})}\\ &=\frac{\hat q(y\mid x_t)\,q(x_t\mid x_{t+1})}{\hat q(y\mid x_{t+1})} \end{aligned}$$

Since $x_{t+1}$ is known during the reverse process, $\hat q(y\mid x_{t+1})$ can be treated as a known constant. Writing its reciprocal as $Z$, we get

$$\hat q(x_t\mid x_{t+1},y)=Z\,\hat q(y\mid x_t)\,q(x_t\mid x_{t+1})$$

Taking the logarithm gives
$$\log \hat q(x_t\mid x_{t+1},y)=\log Z+\log \hat q(y\mid x_t)+\log \hat q(x_t\mid x_{t+1})\tag{1.1}$$

Let $\hat q(x_t\mid x_{t+1})=\mathcal N(\mu_t,\Sigma_t)$ with covariance $\Sigma_t$. Then

$$\log \hat q(x_t\mid x_{t+1})=-\frac{1}{2}(x_t-\mu_t)^T\Sigma_t^{-1}(x_t-\mu_t)+C\tag{1.2}$$

For $\log \hat q(y\mid x_t)$, a first-order Taylor expansion around $x_t=\mu_t$ gives

$$\begin{aligned} \log \hat q(y\mid x_t)&\approx \log \hat q(y\mid x_t)\big|_{x_t=\mu_t}+(x_t-\mu_t)^T\,\nabla_{x_t}\log\hat q(y\mid x_t)\big|_{x_t=\mu_t}\\ &=C_1+(x_t-\mu_t)^T g \end{aligned}\tag{1.3}$$

where $g=\nabla_{x_t}\log\hat q(y\mid x_t)\big|_{x_t=\mu_t}$. Combining Eqs. 1.1, 1.2 and 1.3, we obtain

$$\begin{aligned} \log \hat q(x_t\mid x_{t+1},y)&\approx C_1+(x_t-\mu_t)^T g+\log Z-\frac{1}{2}(x_t-\mu_t)^T\Sigma_t^{-1}(x_t-\mu_t)+C\\ &=(x_t-\mu_t)^T g-\frac{1}{2}(x_t-\mu_t)^T\Sigma_t^{-1}(x_t-\mu_t)+C_2\\ &=-\frac{1}{2}(x_t-\mu_t-\Sigma_t g)^T\Sigma_t^{-1}(x_t-\mu_t-\Sigma_t g)+C_3 \end{aligned}$$

Finally,

$$\hat q(x_t\mid x_{t+1},y)\approx \mathcal N\big(\mu_t+\Sigma_t\,g,\ \Sigma_t\big),\qquad g=\nabla_{x_t}\log\hat q(y\mid x_t)\big|_{x_t=\mu_t}\tag{1.4}$$

To obtain $\nabla_{x_t}\log\hat q(y\mid x_t)$, Classifier Guidance Diffusion trains an additional classification head on top of the already-trained diffusion model, fed with noised samples $x_t$.
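A rough sketch of how such a classifier might be trained (all names and the toy MLP architecture below are hypothetical assumptions, not the paper's code; `alpha_bar` is assumed to be a torch tensor holding $\bar\alpha_t$): it must see noised inputs $x_t$ together with $t$, so that $\nabla_{x_t}\log p_\phi(y\mid x_t)$ is meaningful at every noise level.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoisyClassifier(nn.Module):
    """Toy MLP for p_phi(y | x_t, t) on flattened noised images."""
    def __init__(self, in_dim=3 * 32 * 32, n_classes=10):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim + 1, 256), nn.ReLU(),
                                 nn.Linear(256, n_classes))

    def forward(self, x_t, t):
        h = torch.cat([x_t.flatten(1), t.float().unsqueeze(1)], dim=1)
        return self.net(h)  # class logits

clf = NoisyClassifier()
opt = torch.optim.Adam(clf.parameters(), lr=1e-4)

def classifier_step(x0, y, alpha_bar):
    """One training step on samples noised by the *unconditional* forward process."""
    t = torch.randint(0, len(alpha_bar), (x0.shape[0],))
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * torch.randn_like(x0)  # x_t ~ q(x_t | x_0)
    loss = F.cross_entropy(clf(x_t, t), y)  # standard classification loss
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()
```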

Assuming $x_t\approx\mu_t$, the reverse (sampling) step of Classifier Guidance Diffusion is

$$x_{t-1}\sim \mathcal N\big(\mu_t+s\,\Sigma_t\,\nabla_{x_t}\log p_\phi(y\mid x_t),\ \Sigma_t\big)$$

where $p_\phi(y\mid x_t)=\hat q(y\mid x_t)$ and $s$ is a hyperparameter (the guidance scale).
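A minimal sketch of that guided reverse step, assuming placeholder callables `mean_fn` (the diffusion model's posterior mean $\mu_t$) and `clf` (classifier logits), and a scalar variance so that $\Sigma_t=\sigma_t^2 I$; taking the gradient at $x_t$ rather than $\mu_t$ matches the $x_t\approx\mu_t$ assumption above:

```python
import torch

def guided_ddpm_step(x_t, t, y, mean_fn, sigma2_t, clf, s=1.0):
    """Sample x_{t-1} ~ N(mu_t + s * Sigma_t * g, Sigma_t), Sigma_t = sigma2_t * I."""
    mu = mean_fn(x_t, t)
    with torch.enable_grad():
        x_in = x_t.detach().requires_grad_(True)
        log_probs = torch.log_softmax(clf(x_in, t), dim=-1)
        selected = log_probs[torch.arange(len(y)), y].sum()  # sum_i log p(y_i | x_i)
        g = torch.autograd.grad(selected, x_in)[0]           # grad of log p_phi(y | x_t)
    noise = torch.randn_like(mu)
    return mu + s * sigma2_t * g + sigma2_t ** 0.5 * noise
```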

Eq. 1.4 has a problem: when the variance $\Sigma_t$ is $0$, the guidance term $\Sigma_t\nabla_{x_t}\log\hat q(y\mid x_t)$ vanishes and the condition can no longer steer generation. Eq. 1.4 therefore does not apply to deterministic samplers such as DDIM.

Before deriving the guided DDIM sampling formula, we first review parameter estimation with Tweedie's method.

Tweedie's method is mainly used for parameter estimation in exponential-family distributions, and the Gaussian belongs to this family. Given samples $z$ from $\mathcal N(z;\mu,\Sigma^2)$, the estimate of the mean $\mu$ from $z$ is

$$E[\mu\mid z]=z+\Sigma^2\,\nabla_z\log p(z)\tag{1.5}$$

where $p(z)$ is the marginal density of $z$.
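As a quick numeric sanity check of Eq. 1.5 (with assumed toy numbers): take a Gaussian prior $\mu\sim\mathcal N(m,\tau^2)$ and $z\mid\mu\sim\mathcal N(\mu,\sigma^2)$; the marginal is $z\sim\mathcal N(m,\tau^2+\sigma^2)$, whose score is $-(z-m)/(\tau^2+\sigma^2)$, and Tweedie's estimate coincides with the textbook posterior mean:

```python
import numpy as np

m, tau2 = 2.0, 4.0   # prior:      mu ~ N(m, tau2)
sigma2 = 1.0         # likelihood: z | mu ~ N(mu, sigma2)
z = 3.7              # one observed sample

# Score of the marginal p(z) = N(z; m, tau2 + sigma2).
score = -(z - m) / (tau2 + sigma2)

tweedie = z + sigma2 * score                                 # Eq. 1.5
posterior_mean = (tau2 * z + sigma2 * m) / (tau2 + sigma2)   # classical result

assert np.isclose(tweedie, posterior_mean)  # both equal 3.36
```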

The forward process of DDPM and DDIM satisfies

$$q(x_t\mid x_0)=\mathcal N\big(x_t;\sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\mathbf I\big)\tag{1.6}$$

Combining Eqs. 1.5 and 1.6 gives

$$\sqrt{\bar\alpha_t}\,x_0=x_t+(1-\bar\alpha_t)\,\nabla_{x_t}\log p(x_t)$$

and therefore

$$x_t=\sqrt{\bar\alpha_t}\,x_0-(1-\bar\alpha_t)\,\nabla_{x_t}\log p(x_t)\tag{1.7}$$
Let $\epsilon_t$ follow a standard normal distribution; then Eq. 1.6 gives

$$x_t=\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon_t\tag{1.8}$$

Combining Eqs. 1.7 and 1.8, we get

$$\nabla_{x_t}\log p(x_t)=-\frac{\epsilon_t}{\sqrt{1-\bar\alpha_t}}\tag{1.9}$$

The DDIM sampling formula is

$$x_{t-1}=\sqrt{\bar\alpha_{t-1}}\,\frac{x_t-\sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}}+\sqrt{1-\bar\alpha_{t-1}-\delta_t^2}\,\epsilon_\theta(x_t)\tag{2.0}$$

Combining Eqs. 1.9 and 2.0, the DDIM sampling formula can be rewritten as

$$x_{t-1}=\sqrt{\bar\alpha_{t-1}}\,\frac{x_t-\sqrt{1-\bar\alpha_t}\,\big(-\sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log p(x_t)\big)}{\sqrt{\bar\alpha_t}}+\sqrt{1-\bar\alpha_{t-1}-\delta_t^2}\,\big(-\sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log p(x_t)\big)\tag{2.1}$$

We only need to replace $\nabla_{x_t}\log p(x_t)$ with $\nabla_{x_t}\log p(x_t\mid y)$ to steer the DDIM generation process with the condition $y$. By Bayes' rule,

$$\begin{aligned} \log p(x_t\mid y)&=\log p(y\mid x_t)+\log p(x_t)-\log p(y)\\ \nabla_{x_t}\log p(x_t\mid y)&=\nabla_{x_t}\log p(y\mid x_t)+\nabla_{x_t}\log p(x_t)-\nabla_{x_t}\log p(y)\\ &=\nabla_{x_t}\log p(y\mid x_t)+\nabla_{x_t}\log p(x_t)\\ &=\nabla_{x_t}\log p(y\mid x_t)-\frac{\epsilon_t}{\sqrt{1-\bar\alpha_t}} \end{aligned}\tag{2.2}$$

(the term $\nabla_{x_t}\log p(y)$ vanishes because $p(y)$ does not depend on $x_t$)
so that

$$-\sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log p(x_t\mid y)=\epsilon_t-\sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log p(y\mid x_t)\tag{2.3}$$

At this point, the classifier-guided DDIM sampling procedure follows: define the guided noise estimate

$$\hat\epsilon(x_t)=\epsilon_\theta(x_t)-\sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log p_\phi(y\mid x_t)$$

and substitute $\hat\epsilon(x_t)$ for $\epsilon_\theta(x_t)$ in Eq. 2.0. Hence, for deterministic samplers such as DDIM, one likewise trains an additional classification head on top of the already-trained diffusion model to turn it into Classifier Guidance Diffusion.
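Putting it together, a sketch of one deterministic ($\delta_t=0$) classifier-guided DDIM step, with the same hypothetical `clf` interface as above, a placeholder noise predictor `eps_fn`, and `alpha_bar` a torch tensor of $\bar\alpha_t$; `s` scales the classifier gradient:

```python
import torch

def guided_ddim_step(x_t, t, y, eps_fn, clf, alpha_bar, s=1.0):
    """One deterministic DDIM step using the guided noise estimate eps_hat."""
    with torch.enable_grad():
        x_in = x_t.detach().requires_grad_(True)
        log_probs = torch.log_softmax(clf(x_in, t), dim=-1)
        selected = log_probs[torch.arange(len(y)), y].sum()
        g = torch.autograd.grad(selected, x_in)[0]        # grad of log p_phi(y | x_t)

    ab_t, ab_prev = alpha_bar[t], alpha_bar[t - 1]
    eps_hat = eps_fn(x_t, t) - s * (1 - ab_t).sqrt() * g  # guided noise (cf. Eq. 2.3)
    x0_pred = (x_t - (1 - ab_t).sqrt() * eps_hat) / ab_t.sqrt()
    return ab_prev.sqrt() * x0_pred + (1 - ab_prev).sqrt() * eps_hat  # Eq. 2.0, delta_t = 0
```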

Training Objective of the Conditional Diffusion Model

Note that $\hat q(x_t\mid x_{t+1})=q(x_t\mid x_{t+1})$, and the derivation above never changes the form of $q(x_t\mid x_{t+1})$ or $q(x_{t+1}\mid x_t)$. The training objective of Classifier Guidance Diffusion is therefore identical to that of DDPM and DDIM; the diffusion model itself fits the training data exactly as before, and only the classifier is trained in addition.
