当前位置: 首页 > article >正文

Gumbel Softmax重参数和SF估计(Score Function Estimator,VAE/GAN/Policy Gradient中的重参数)

Gumbel Softmax

We derive the probability density function of the Gumbel-Softmax distribution with probabilities π 1 , … , π k \pi_1, \ldots, \pi_k π1,,πk and temperature τ \tau τ. We first define the logits x i = log ⁡ π i x_i = \log \pi_i xi=logπi, and Gumbel samples g 1 , … , g k g_1, \ldots, g_k g1,,gk, where g i ∼ Gumbel ( 0 , 1 ) g_i \sim \text{Gumbel}(0, 1) giGumbel(0,1) Gumbel ( 0 , 1 ) \text{Gumbel}(0, 1) Gumbel(0,1) stands for sampling from uniform distribution U ( 0 , 1 ) \text{U}(0,1) U(0,1) to get u i u_i ui first, then g i = − l o g ( − l o g ( u i ) ) g_i=-log(-log(u_i)) gi=log(log(ui))】. A sample from the Gumbel-Softmax can then be computed as:

y i = exp ⁡ ( ( x i + g i ) / τ ) ∑ j = 1 k exp ⁡ ( ( x j + g j ) / τ ) for  i = 1 , … , k \begin{equation} y_i = \frac{\exp((x_i + g_i)/\tau)}{\sum_{j=1}^{k} \exp((x_j + g_j)/\tau)} \quad \text{for } i = 1, \ldots, k \end{equation} yi=j=1kexp((xj+gj)/τ)exp((xi+gi)/τ)for i=1,,k

Gumbel Code

import numpy as np

def sample_gumbel(shape, eps=1e-20):
    """
    从 Gumbel(0,1) 分布中采样
    :param shape: 采样的形状
    :param eps: 防止 log(0) 的小值
    :return: 从 Gumbel(0,1) 分布中采样的值
    """
    U = np.random.uniform(0, 1, shape)
    return -np.log(-np.log(U + eps) + eps)

# 示例:采样一个标量
sample = sample_gumbel(())
print("Sample from Gumbel(0,1):", sample)

# 示例:采样一个形状为 (3, 3) 的数组
samples = sample_gumbel((3, 3))
print("Samples from Gumbel(0,1):\n", samples)

Gumbel Softmax的意义

简单来说,Gumbel Max就是发现:
P [ arg ⁡ max ⁡ ( x + ϵ ) = i ] = softmax ( x ) i , ϵ ∼ Gumbel Noise \begin{equation} P[\arg\max(x + \epsilon) = i] = \text{softmax}(x)_i, \quad \epsilon \sim \text{Gumbel Noise} \end{equation} P[argmax(x+ϵ)=i]=softmax(x)i,ϵGumbel Noise

怎么理解这个结果呢?首先,这里的 ϵ ∼ Gumbel Noise \epsilon \sim \text{Gumbel Noise} ϵGumbel Noise 是指 ϵ \epsilon ϵ 的每个分量都是从 G u m b e l 分布 \textcolor{green}{Gumbel分布} Gumbel分布 独立重复采样出来的;接着,我们知道给定向量 x x x,本来 arg ⁡ max ⁡ ( x ) \arg\max(x) argmax(x) 是确定的结果,但加了随机噪声 ϵ \epsilon ϵ 之后, arg ⁡ max ⁡ ( x + ϵ ) \arg\max(x + \epsilon) argmax(x+ϵ) 的结果也带有随机性了,于是每个 i i i 都有自己的概率;最后,Gumbel Max告诉我们,如果加的是Gumbel噪声,那么 i i i 的出现概率正好是 softmax ( x ) i \text{softmax}(x)_i softmax(x)i

Gumbel Max最直接的作用,就是提供了一种从 softmax ( x ) \text{softmax}(x) softmax(x) 中采样的方式,当然如果单纯采样还有更简单的方法,没必要“杀鸡用牛刀”。Gumbel Max最大的价值是“重参数化(Reparameterization)”,它将问题的随机性从带参数 α \alpha α 的离散分布转移到了不带参数的 ϵ \epsilon ϵ 上,再结合 softmax \text{softmax} softmax arg ⁡ max ⁡ \arg\max argmax 的光滑近似,我们得到 softmax ( x + ϵ ) \text{softmax}(x + \epsilon) softmax(x+ϵ) 是Gumbel Max的光滑近似,这便是Gumbel Softmax,是训练“离散采样模块中带有可学参数”的模型的常用技巧。

Score Function Estimator(SF估计)

现在我们得到了梯度的一个估计式,称为“SF估计”,全称是Score Function Estimator,这是对原来损失函数的最朴素的估计,在强化学习中 z z z代表着策略梯度,所以有时候也直接称上述估计为REINFORCE。要注意,对离散情形的损失函数数重新推导一遍,结果也是一样的,也就是说,上述结果是通用的,不区分 z z z是连续变量还是离散变量。现在我们可以直接从 p θ ( z ) p_\theta(z) pθ(z)中采样若干点来估计下面公式的值,不用担心会不会没梯度,因为下面公式本身就是梯度了。这刚好是Policy Gradient中的公式

∂ ∂ θ ∫ p θ ( z ) f ( z ) d z = ∫ ∂ ∂ θ p θ ( z ) f ( z ) d z = ∫ p θ ( z ) f ( z ) p θ ( z ) ∂ ∂ θ p θ ( z ) d z = E p θ ( z ) [ f ( z ) ∂ ∂ θ log ⁡ p θ ( z ) ] = E p θ ( z ) [ f ( z ) ∂ ∂ θ log ⁡ p θ ( z ) ] \begin{equation} \begin{aligned} \frac{\partial}{\partial \theta} \int p_\theta(z) f(z) dz &= \int \frac{\partial}{\partial \theta} p_\theta(z) f(z) dz \\ &= \int p_\theta(z) \frac{f(z)}{p_\theta(z)} \frac{\partial}{\partial \theta} p_\theta(z) dz \\ &= \mathbb{E}_{p_\theta(z)} \left[ f(z) \frac{\partial}{\partial \theta} \log p_\theta(z) \right] \\ &= \mathbb{E}_{p_\theta(z)} \left[ f(z) \frac{\partial}{\partial \theta} \log p_\theta(z) \right] \end{aligned} \end{equation} θpθ(z)f(z)dz=θpθ(z)f(z)dz=pθ(z)pθ(z)f(z)θpθ(z)dz=Epθ(z)[f(z)θlogpθ(z)]=Epθ(z)[f(z)θlogpθ(z)]

漫谈重参数:从正态分布到Gumbel Softmax

来自 https://spaces.ac.cn/archives/6705
在这里插入图片描述

通向概率分布之路:盘点Softmax及其替代品

来自https://spaces.ac.cn/archives/10145/comment-page-1
在这里插入图片描述
在这里插入图片描述


http://www.kler.cn/a/557808.html

相关文章:

  • vue中json-server及mockjs后端接口模拟
  • 算法-栈和队列篇04-滑动窗口最大值
  • 深入理解 lua_KFunction 和 lua_CFunction
  • cocos2dx Win10环境搭建(VS2019)
  • 2.1作业
  • 25轻化工程研究生复试面试问题汇总 轻化工程专业知识问题很全! 轻化工程复试全流程攻略 轻化工程考研复试真题汇总
  • linux常用基础命令_最新版
  • Embedding模型
  • excel中VBA宏的使用方法?
  • nginx 反向代理 配置请求路由
  • uniapp封装请求
  • 在线办公小程序(springboot论文源码调试讲解)
  • 伦敦金库彻底断供的连锁反应推演(截至2025年02月22日)
  • BFS算法解决最短路径问题(典型算法思想)—— OJ例题算法解析思路
  • 深入理解设计模式之策略模式
  • JDBC连接保姆级教程
  • Redis数据结构总结-quickList
  • 漏扫问题-服务器中间件版本信息泄露(消除/隐藏Nginx版本号)
  • 一文说清楚Java中的volatile修饰符
  • 图解JVM-1. JVM与Java体系结构