当前位置: 首页 > article >正文

deep generative model stanford lecture note3 --- latent variable

1 Introduction

自回归模型随着gpt的出现取得很大的成功,还是有很多工程上的问题并不是很适合使用自回归模型:
1)自回归需要的算力太大,满足不了实时性要求:例如在自动驾驶的轨迹预测任务中,如果要用纯自回归的世界模型,耗时太大;
2)要求数据天然有时序性:很多图像任务并没有严格的序列生成的要求;
这个部分开始用隐变量的方式来进行建模。

2 特征提取和线形回归

自动驾驶和机器人中的很多任务,是通过感知的环境输入, 然后进行特征提取,最后用线形回归来预测和生成指令。

在这里插入图片描述
但是这种方式因为采用了非常简单的单高斯分布来估计指令,这个时候有几种提高的方式:
1)提高特征的表达能力
1.1)如果特征提取的模型(一般是transformer)记忆力足够强大,哪怕后面接了单峰高斯估计也能有一个比较好的拟合效果;直觉来说,就是把所有的情况都记住了。高质量的特征能够在一定程度上“预处理”复杂性。
1.2)采用anchor based的query来生成不同的feature,降低拟合难度。
2)提高概率分布的表达能力
2.1)采用混合高斯叠加的概率分布,来生成复杂的概率分布。
因为提高特征表达能力往往是多模态相关的工作,这里我门进行跳过,更加关注通过提高概率分布的表达能力这个方面。

3 vae 模型

z是隐变量,需要用模型构建z存在的情况下 p ( x ∣ z , θ ) p(x|z,\theta) p(xz,θ)的概率。
在这里插入图片描述
按照note2的内容,loss设计的时候,满足极大似然就可以 l o g P θ ( x ) logP_{\theta}(x) logPθ(x)
现在的问题是,每种z都有一定的概率能生成x。可以采用普查的方式,或者采用抽样的方式

l o g P θ ( x ) = 1 D ( z ) ∑ z ∈ D P ( x , z ; θ ) \begin{aligned} logP_{\theta}(x)=\frac{1}{D(z)}\sum_{z \in D}P(x,z;\theta) \end{aligned} logPθ(x)=D(z)1zDP(x,z;θ)

因为z本身是连续分布,采用普查的方式来采样无穷个 z j z^j zj显然是不现实的。我们只需要将和x相关性较高的z(重要性采样)找出来就好。
l o g P θ ( x ) = l o g ∑ j = 1 k q ( z ( j ) ) q ( z ( j ) ) P ( x , z ; θ ) = l o g E x − q ( z ) P ( x , z ; θ ) q ( z ( j ) ) \begin{aligned} logP_{\theta}(x) & = log\sum_{j=1}^k \frac{q(z^{(j)})}{q(z^{(j)})} P(x,z;\theta) \\ & = logE_{x-q(z)}\frac{P(x,z;\theta)}{q(z^{(j)})} \end{aligned} logPθ(x)=logj=1kq(z(j))q(z(j))P(x,z;θ)=logExq(z)q(z(j))P(x,z;θ)

在这里插入图片描述
对于log这种凸函数,满足 l o g E [ x ] > E [ l o g ( x ) ] logE[x]>E[log(x)] logE[x]>E[log(x)],可以对上面这个式子进行变换
l o g P θ ( x ) = l o g E x − q ( z ) P ( x , z ; θ ) q ( z ( j ) ) ≥ E x − q ( z ) l o g P ( x , z ; θ ) q ( z ( j ) ) = ∑ j = 1 k q ( z ( j ) ) l o g P ( x , z ; θ ) q ( z ( j ) ) = ∑ j = 1 k ( q ( z ( j ) ) l o g P ( x , z ; θ ) − q ( z ( j ) ) l o g q ( z ( j ) ) ) = E L B O \begin{aligned} logP_{\theta}(x) & = logE_{x-q(z)}\frac{P(x,z;\theta)}{q(z^{(j)})} \\ & \ge E_{x-q(z)}log\frac{P(x,z;\theta)}{q(z^{(j)})} \\ & = \sum_{j=1}^kq(z^{(j)})log\frac{P(x,z;\theta)}{q(z^{(j)})} \\ & = \sum_{j=1}^k(q(z^{(j)})logP(x,z;\theta)-q(z^{(j)})logq(z^{(j)}))=ELBO \end{aligned} logPθ(x)=logExq(z)q(z(j))P(x,z;θ)Exq(z)logq(z(j))P(x,z;θ)=j=1kq(z(j))logq(z(j))P(x,z;θ)=j=1k(q(z(j))logP(x,z;θ)q(z(j))logq(z(j)))=ELBO

现在问题变得很简单了,我们需要搞清楚 q ( z ) q(z) q(z)的概率分布q(z)概率分布可以理解成状态变量x通过网络提取出来的特征
但是这里可能存在一个问题,我们的encoder并没有很好的把z的概率分布估计好。也就是说重要的z可能给的概率不够高,不重要的z可能给的概率太高了,所以我们还是要看一下encoder到底拟合的怎么样。显然这里就用KL散度来描述。
D K L ( q ( z ) ∣ ∣ p ( z ∣ x ; θ ) ) = ∑ z q ( z ) l o g q ( z ) p ( z ∣ x ; θ ) = ∑ z q ( z ) l o g q ( z ) p ( z , x ; θ ) / p ( x ; θ ) = ∑ z ( q ( z ) l o g q ( z ) + q ( z ) l o g p ( x ; θ ) − q ( z ) l o g p ( x , z ; θ ) ) = ∑ z q ( z ) l o g p ( x ; θ ) − ∑ z ( q ( z ) l o g q ( z ) − q ( z ) l o g p ( x , z ; θ ) ) = l o g p ( x ; θ ) − ∑ j = 1 k ( q ( z ( j ) ) l o g P ( x , z ; θ ) − q ( z ( j ) ) l o g q ( z ( j ) ) ) \begin{aligned} D_{KL}(q(z)||p(z|x;\theta))&=\sum_z q(z)log\frac{q(z)}{p(z|x;\theta)} \\ & = \sum_z q(z)log\frac{q(z)}{p(z,x;\theta)/p(x;\theta)} \\ & = \sum_z (q(z)logq(z)+q(z)logp(x;\theta)-q(z)logp(x,z;\theta)) \\ & = \sum_z q(z)logp(x;\theta) - \sum_z (q(z)logq(z)-q(z)logp(x,z;\theta)) \\ & = logp(x;\theta)- \sum_{j=1}^k(q(z^{(j)})logP(x,z;\theta)-q(z^{(j)})logq(z^{(j)})) \end{aligned} DKL(q(z)∣∣p(zx;θ))=zq(z)logp(zx;θ)q(z)=zq(z)logp(z,x;θ)/p(x;θ)q(z)=z(q(z)logq(z)+q(z)logp(x;θ)q(z)logp(x,z;θ))=zq(z)logp(x;θ)z(q(z)logq(z)q(z)logp(x,z;θ))=logp(x;θ)j=1k(q(z(j))logP(x,z;θ)q(z(j))logq(z(j)))

这个公式就是我们上面那个公式,也证明了只有我们的encoder能充分的将z的概率分布学习好的时候,才能保证最大似然估计的更好。
在这里插入图片描述
现在我们来更新一下极大似然
l o g P θ ( x ) = ∑ z ( q ( z ∣ x , ϕ ) l o g P ( x , z ; θ ) − q ( z ∣ x , ϕ ) l o g q ( z ∣ x , ϕ ) ) − D K L ( q ( z ) ∣ ∣ p ( z ∣ x ; θ ) ) \begin{aligned} logP_{\theta}(x) & = \sum_{z}(q(z|x,\phi)logP(x,z;\theta)-q(z|x,\phi)logq(z|x,\phi)) - D_{KL}(q(z)||p(z|x;\theta))\\ \end{aligned} logPθ(x)=z(q(zx,ϕ)logP(x,z;θ)q(zx,ϕ)logq(zx,ϕ))DKL(q(z)∣∣p(zx;θ))

KL散度可以积分直接得到解析解,这里直接给出公式的结果
D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) = D K L ( N ( μ , σ ) ∣ ∣ N ( 0 , 1 ) ) = 1 2 ∑ i ( σ i 2 + μ i 2 − 1 − l n σ i 2 ) \begin{aligned} D_{KL}(q_{\phi}(z|x)||p(z)) & = D_{KL}(\mathcal{N}(\mu, \sigma)||\mathcal{N}(0, 1)) \\ & = \frac{1}{2}\sum_i(\sigma_i^2+\mu_i^2-1-ln\sigma_i^2) \end{aligned} DKL(qϕ(zx)∣∣p(z))=DKL(N(μ,σ)∣∣N(0,1))=21i(σi2+μi21lnσi2)

对于ELBO,这里只能采用mento carlo的方式进行采样计算
z ( k ) = μ ϕ ( x ) + σ ϕ ( x ) ϵ , ϵ ∼ N ( 0 , 1 ) z^{(k)}=\mu_{\phi}(x)+\sigma_{\phi}(x)ϵ, ϵ \sim \mathcal{N}(0, 1) z(k)=μϕ(x)+σϕ(x)ϵ,ϵN(0,1)

那么极大似然可以更新成
l o g P θ ( x ) = 1 K ∑ k l o g P ( x , z ( k ) ; θ ) − l o g q ( z ( k ) ∣ x ; ϕ ) ) − D K L ( q ( z ) ∣ ∣ p ( z ∣ x ; θ ) ) \begin{aligned} logP_{\theta}(x) & =\frac{1}{K} \sum_{k}logP(x,z^{(k)};\theta)-logq(z^{(k)}|x;\phi)) - D_{KL}(q(z)||p(z|x;\theta))\\ \end{aligned} logPθ(x)=K1klogP(x,z(k);θ)logq(z(k)x;ϕ))DKL(q(z)∣∣p(zx;θ))

最后我们再来看一下这个公式,这个也解答了我们再线形回归中的问题。
1)𝑞(𝑧∣𝑥,𝜙) 的熵。直观上鼓励编码器不要过于自信,即不要把 q(z∣x,ϕ) 限制在狭窄区域内,而是保留一定的不确定性以捕捉数据的多样性和内在噪声。
2)我们从z变量中多采样几个(采样k个隐变量,类似于我们采用anchor-based 的query),越有助于我们更准确的进行参数似然估计;
3)KL散度则是encoder的正则化,确保真实后验 p ( z ∣ x ; θ ) p(z|x;\theta) p(zx;θ)和先验分布一致;在线形回归任务中,我门经常对提取的特征进行特征重建,也达到了类似的效果。

4 cvae实例分析

aloha作为机器人模仿学习的重要的一项工作[1],在他们的工作中使用了cvae,让我们来看一下它是如何设计的。

References

[1] https://tonyzhaozh.github.io/aloha/


http://www.kler.cn/a/531129.html

相关文章:

  • Kamailio、MySQL、Redis、Gin后端、Vue.js前端等基于容器化部署
  • 联想拯救者Y9000P IRX8 2023 (82WK) 原厂Win11 家庭中文版系统 带一键还原功能 安装教程
  • 【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.12 连续数组:为什么contiguous这么重要?
  • elasticsearch8.15 高可用集群搭建(含认证Kibana)
  • R 字符串:深入理解与高效应用
  • 全面认识了解DeepSeek+利用ollama在本地部署、使用和体验deepseek-r1大模型
  • 半导体器件与物理篇7 微波二极管、量子效应和热电子器件
  • SynchronousQueue 与 LinkedBlockingQueue区别及应用场景
  • DeepSeek-R1 低成本训练的根本原因是?
  • CTF-web: php-session临时文件特性
  • Spring MVC学习——发送请求(@RequestMapping注解及请求参数绑定)
  • Android学习19 -- 手搓App
  • The Simulation技术浅析(三):数值方法
  • Hive修复分区
  • 亚博microros小车-原生ubuntu支持系列:20 ROS Robot APP建图
  • 【IocDI】_存储Bean的五大类注解及getBean的使用
  • 独立开发者的技术栈
  • 使用Pygame制作“走迷宫”游戏
  • 54【ip+端口+根目录通信】
  • 【数据分析】案例04:豆瓣电影Top250的数据分析与Web网页可视化(numpy+pandas+matplotlib+flask)
  • 六百六十六,盐豆不带盐了
  • 解决SetWindowCompositionAttribute使控件文本透明的问题
  • git中文件的状态状态切换
  • 全栈开发:使用.NET Core WebAPI构建前后端分离的核心技巧(一)
  • 代码随想录算法训练营Day35
  • Docker 部署教程jenkins