当前位置: 首页 > article >正文

基于矩阵乘积态的生成模型:量子力学与生成任务的结合

🏡作者主页:点击! 

🤖编程探索专栏:点击!

⏰️创作时间:2024年12月23日11点02分


神秘男子影,
  秘而不宣藏。
泣意深不见,
男子自持重,
   子夜独自沉。


文章源链接(有视频):Aspiringcode - 编程抱负 即刻实现传知代码只专注开箱即用的代码icon-default.png?t=O83Ahttps://www.aspiringcode.com/content?id=17212664745068&uid=c77f3951d4d6474f97aa889c88064311

概述

生成模型,通过从数据中学习联合概率分布并据此生成样本,是机器学习和人工智能中的一个重要任务。受量子物理学中概率解释的启发,该文章提出了一种使用矩阵积状态的生成模型,这是一种最初用于描述(特别是一维)纠缠量子态的张量网络。其模型享有类似于密度矩阵重正化群方法的高效学习能力,该方法允许动态调整张量的维度,并提供了一种高效的直接采样方法用于生成任务。本文试图复现该文章的工作,利用该文章的思想,方法去实现MNIST手写数字的生成任务。

  • Han Z-Y, Wang J, Fan H, et al. Unsupervised Generative Modeling Using Matrix Product States[J]. Physical Review X, 2018, 8(3): 031012.

演示效果

方法

量子力学的概率解释自然地建议使用量子态来建模数据分布。假设我们将概率分布编码到一个量子波函数Ψ(v)Ψ(v) 中,测量会使其坍缩并生成一个结果 v=(v1,v2,…,vN)v=(v1,v2,…,vN),其概率与 ∣Ψ(v)∣2∣Ψ(v)∣2 成正比。受到量子力学生成特性的启发,通过以下方式表示模型概率分布:

其中值得说明的是Ψ(v1,v2,…,vN)Ψ(v1,v2,…,vN)和P(v1,v2,…,vN)P(v1,v2,…,vN)都是张量(tensor),因为他受制于多个不同的指标v1,v2,…,vNv1,v2,…,vN,此外ZZ作为配分函数代表的是归一化系数,即Z=∑vi∣Ψ(v1,v2,…,vN)∣2Z=∑vi∣Ψ(v1,v2,…,vN)∣2.

如何对∣Ψ(v1,v2,…,vN)∣2∣Ψ(v1,v2,…,vN)∣2进行合适的建模使得模型既不复杂,又在一定程度上能够表示更多不同种类的构型成为现在需要解决的问题。许多已经开发的表示方法和算法可以用于高效的概率建模。在这里,我们使用矩阵积状态(MPS)对波函数进行参数化:

上面的图示意思为,左边是我们需要表示的波函数,线代表它依赖的指标(或者变量),右边则是对应的MPS表示,两个方括号直接的连线代表求和,即将对应的指标(或者变量求和,类似于矩阵的乘积)进行收缩。我们可以看出我们把一个复杂的波函数变成了有限个3指标张量的收缩。

实现

导入训练集(MNIST)

1000张MNIST图像已存储为mnist784_bin_1000.npy。每张图像包含 n=28×28=784n=28×28=784 个像素,每个像素的取值为0或1。每张图像被视为维度为 2n2n 的希尔伯特空间中的一个乘积态。

n = 784 
m = 1000
data = np.load("mnist784_bin_1000.npy").astype(np.int32)
data = data[:m,:]
data = torch.LongTensor(data)\
plt.figure(figsize=(10,2))
imgs = data.cpu().reshape([-1,28,28])
_, ax = plt.subplots(2,10)
for i in range(2): 
    for j in range(10):
        index = i * 2 + j
        if(a >= imgs.shape[0]):
            break
        ax[i][j].imshow(imgs[index,:,:],cmap='bone')
        ax[i][j].set_xticks([])
        ax[i][j].set_yticks([])
plt.show()

这可以让我们观察以下MNIST数据集的样子

定义MPS

现在我们要构造一个初始的MPS, 根据上面的阐述,我们的MPS是由一系列3指标的张量的所构成的,如下所示

解释一下,χχ表示的是最大截断指标,也就是为了控制整个MPS的维度,防止维度灾难;左右两边的11可以认为是一个空指标,即没有任何含义,只是为了在运算过程中方便,实际是没有这个指标也可以;下面的22表示的是这些指标的维度是2,因为我们只是考虑的像素点只有黑白(1,0)两种状态,也因此下面的22的数目需要28∗28=78428∗28=784个。

chi = 30 
mydevice = 'cuda' if torch.cuda.is_available() else torch.device("cpu")
print(mydevice)
data = data.to(mydevice)
bond_dims = [chi for i in range(n-1)]+[1]
tensors= [ torch.randn(bond_dims[i-1],2,bond_dims[i],device=mydevice) for i in range(n)]

我们可以输出从而看到这些张量的输出维度

概率计算

概率计算可以遵循前面的Born公式,即

在这里,带有一个小边(常称之为脚)是一个向量,代表的是对应像素的状态,是一个二维向量,用来表示对应的像素是黑还是白

现在难以计算的是配分函数,即

这个东西,这涉及到张量网络的缩并,在张量网络这个领域中由非常多的缩并方式,一个常用的方法是正交化,即把MPS右边的那些三阶张量全部正交化使得他们收缩刚好是一个单位张量。这个过程如下

通过不断的对左边的张量作用QR分解从而使得左边张量全部正交化(黄色的)。据此我们可以计算出对应的波函数

def getPsi():
    psi = torch.ones([m, 1, 1], device=mydevice)
    for site in range(n):
        selected_tensor = tensors[site][:, data[:, site], :].permute(1, 0, 2)
        psi = torch.matmul(psi, selected_tensor)
    return psi

生成图片

生成图片的过程可以采用条件概率的方法,即先采样一个边缘概率,再从这个边缘概率对应的变量继续采样,重复这个过程即可

核心代码为

def generateSamples(batch):
    n = 784
    samples = torch.zeros([batch, n],device=mydevice)
    for site in range(n - 1):
        orthogonalize(site, True) 
    for s in range(batch):
        vec = torch.ones(1,1,device=mydevice)
        for site in range(n-1, -1, -1):
            vec = (tensors[site].view(-1, bond_dims[site]) @ vec).view(-1, 2)
            p0 = vec[:, 0].norm()**2 / (vec.norm()**2)
            x = (0 if np.random.rand() < p0 else 1)
            vec = vec[:, x]
            samples[s][site] = x
    return samples

训练

因此我们可以根据该模型去训练一个生成模型,损失函数是交叉熵损失即将模型的概率与图片本身分布的概率相拟合。交叉熵损失函数为

除以m就相当于是平均。

ψ′ψ′和Z′Z′是比较好得到,因为对整体的MPS对各个参数是分立。比如说对第100个张量的里面的参数求梯度,那么其他的张量相对于第100个来说都是常数,这跟对矩阵的乘积求梯度是一样的,并没有参数的过度耦合,因此训练过程中核心代码为

for i in ...
      ....
      gradients[:, i, :] = torch.sum(left_vec.permute(0, 2, 1) @ right_vec.permute(0, 2, 1) / psi, 0) 
gradients = 2.0 * (gradients / m - tensors[site])  
tensors[site] += learning_rate * gradients/gradients.norm()

使用方式

  • jupyter notebook 运行

安装依赖

  • Python 3.11.4
  • torch 2.0.1

http://www.kler.cn/a/449473.html

相关文章:

  • Linux网络——网络基础
  • 关于uni-forms组件的bug【提交的字段[‘*‘]在数据库中并不存在】
  • 什么是MVCC?
  • javaEE--计算机是如何工作的-1
  • 通航飞机(通用航空飞机)的软件关键技术
  • 每天40分玩转Django:Django部署
  • Transformer自注意力机制详解
  • Rust之抽空学习系列(五)—— 所有权(上)
  • 《点点之歌》“意外”诞生记
  • 【学术小白的学习之路】基于情感词典的中文句子情感分析(代码词典获取在结尾)
  • springboot+vue的高校宿舍管理系统
  • iOS - 超好用的隐私清单修复脚本(持续更新)
  • DDoS防护中的流量清洗与智能调度
  • 云原生服务网格Istio实战
  • Spring学习(一)——Sping-XML
  • Sigrity Speed2000 仿真分析教程与实例分析文件路径
  • 【漫话机器学习系列】019.布里(莱)尔分数(Birer score)
  • 前端开发 之 12个鼠标交互特效下【附完整源码】
  • Pinia与Vuex的区别
  • ARM异常处理 M33
  • 单片机:实现自动关机电路(附带源码)
  • 【自动化】深度解析仓库存储UI自动化
  • Android简洁缩放Matrix实现图像马赛克,Kotlin
  • ubuntu20.04安装imwheel实现鼠标滚轮调速
  • Kubernetes(K8s)学习笔记
  • 基于YOLOv5的智能水域监测系统:从目标检测到自动报告生成