当前位置: 首页 > article >正文

吉布斯采样方法

吉布斯采样方法

对于多元分布, P ( X ) , X = [ x 1 x 2 ] P(X), X=\left[\begin{array}{l} x_1 \\ x_2 \end{array}\right] P(X),X=[x1x2]吉布斯抽样执行如下。假设很难从联合分布中抽样 P ( x 1 , x 2 ) P\left(x_1, x_2\right) P(x1,x2)但是从条件分布 P ( x 1 ∣ x 2 ) P(x_1|x_2) P(x1x2) P ( x 2 ∣ x 1 ) P(x_2|x_1) P(x2x1)中抽样是可能的。

  • X t = [ x 1 0 x 2 0 ] X^t=\left[\begin{array}{l} x_1^0 \\ x_2^0 \end{array}\right] Xt=[x10x20]开始
  • 采样 x 1 t + 1 ∼ P ( x 1 ∣ x 2 t ) x_1^{t+1} \sim P\left(x_1 \mid x_2^t\right) x1t+1P(x1x2t)
  • 采样 x 2 t + 1 ∼ P ( x 2 ∣ x 1 t + 1 ) x_2^{t+1} \sim P\left(x_2 \mid x_1^{t+1}\right) x2t+1P(x2x1t+1)
  • X t + 1 = [ x 1 t + 1 x 2 t + 1 ] X^{t+1}=\left[\begin{array}{l} x_1^{t+1} \\ x_2^{t+1} \end{array}\right] Xt+1=[x1t+1x2t+1]

删除前几个样本作为老化值。

P ( X ) = P ( x 1 , x 2 ) = 1 ∣ 2 π Σ ∣ e − 1 2 ( X − μ ) T Σ − 1 ( X − μ ) P(X)=P\left(x_1, x_2\right)=\frac{1}{\sqrt{|2 \pi \Sigma|}} e^{-\frac{1}{2}(X-\mu)^T \Sigma^{-1}(X-\mu)} P(X)=P(x1,x2)=∣2πΣ∣ 1e21(Xμ)TΣ1(Xμ)
其中, μ = [ 0 0 ] \mu=\left[\begin{array}{l} 0 \\ 0 \end{array}\right] μ=[00] Σ = [ 1 b b 1 ] \Sigma=\left[\begin{array}{ll} 1 & b \\ b & 1 \end{array}\right] Σ=[1bb1] X = [ x 1 x 2 ] X=\left[\begin{array}{l} x_1 \\ x_2 \end{array}\right] X=[x1x2] b = 0.8 b=0.8 b=0.8
条件概率由
P ( x 1 ∣ x 2 ) = N ( b x 2 , 1 − b 2 ) P ( x 2 ∣ x 1 ) = N ( b x 1 , 1 − b 2 ) \begin{aligned} & P\left(x_1 \mid x_2\right)=\mathcal{N}\left(b x_2, 1-b^2\right) \\ & P\left(x_2 \mid x_1\right)=\mathcal{N}\left(b x_1, 1-b^2\right) \end{aligned} P(x1x2)=N(bx2,1b2)P(x2x1)=N(bx1,1b2)

代码

import numpy.linalg as LA
import numpy as np
import matplotlib.pyplot as plt

def multivariate_normal(X, mu=np.array([[0, 0]]), sig=np.array([[1, 0.8], [0.8, 1]])):
    sqrt_det_2pi_sig = np.sqrt(2 * np.pi * LA.det(sig))
    sig_inv = LA.inv(sig)
    X = X[:, None, :] - mu[None, :, :]
    return np.exp(-np.matmul(np.matmul(X, np.expand_dims(sig_inv, 0)), (X.transpose(0, 2, 1)))/2)/sqrt_det_2pi_sig

x = np.linspace(-3, 3, 1000)
X = np.array(np.meshgrid(x, x)).transpose(1, 2, 0)
X = np.reshape(X, [X.shape[0] * X.shape[1], -1])
z = multivariate_normal(X)

plt.imshow(z.squeeze().reshape([x.shape[0], -1]), extent=[-10, 10, -10, 10], cmap='hot', origin='lower')
plt.contour(x, x, z.squeeze().reshape([x.shape[0], -1]), cmap='cool')
plt.title('True Bivariate Distribution')
plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.show()

x0 = [0, 0]
xt = x0
b = 0.8
samples = []
for i in range(100000):
    x1_t = np.random.normal(b*xt[1], 1-b*b)
    x2_t = np.random.normal(b*x1_t, 1-b*b)
    xt = [x1_t, x2_t]
    samples.append(xt)
burn_in = 1000
samples = np.array(samples[burn_in:])

im, x_, y_ = np.histogram2d(samples[:, 0], samples[:, 1], bins=100, normed=True)
plt.imshow(im, extent=[-10, 10, -10, 10], cmap='hot', origin='lower', interpolation='nearest')
plt.title('Empirical Bivariate Distribution')
plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.show()


http://www.kler.cn/a/16481.html

相关文章:

  • Linux源码阅读笔记-V4L2框架基础介绍
  • docker:docker: Get https://registry-1.docker.io/v2/: net/http: request canceled
  • 搭建Python2和Python3虚拟环境
  • 丹摩征文活动|丹摩智算平台使用指南
  • 2019年下半年试题二:论软件系统架构评估及其应用
  • 【Java SE】接口类型
  • 设计模式-单例模式
  • 一文搞懂PMP挣值管理那些让你头疼的公式
  • mockjs学习笔记
  • maven中的 type ,scope的作用
  • 2335. 装满杯子需要的最短总时长
  • 83. map函数()-通过函数实现对可迭代对象的操作(适合零基础)
  • 「SQL面试题库」 No_55 销售分析 I
  • ramfs, rootfsinitramfs
  • HTML(四) -- 多媒体设计
  • CCD视觉检测设备如何选择光源
  • 【面试长文】HashMap的数据结构和底层原理以及在JDK1.6、1.7和JDK8中的演变差异
  • Blender启动场景的修改
  • 资讯汇总230503
  • 哈希表企业应用-DNA的字符串检测
  • CKA/CKS/CKAD认证考试攻略
  • 【五一创作】( 字符串) 409. 最长回文串 ——【Leetcode每日一题】
  • 【LeetCood206】反转链表
  • Python小姿势 - Python学习笔记——类与对象
  • ZooKeeper安装与配置集群
  • NECCS|全国大学生英语竞赛C类|词汇和语法|词汇题|21:03~21:53