当前位置: 首页 > article >正文

【深度学习基础模型】Variational Autoencoders (VAE) 详细理解并附实现代码。

【深度学习基础模型】Variational Autoencoders (VAE) 详细理解并附实现代码

【深度学习基础模型】Variational Autoencoders (VAE) 详细理解并附实现代码


文章目录

  • 【深度学习基础模型】Variational Autoencoders (VAE) 详细理解并附实现代码
  • 1.Variational Autoencoders (VAE) 的原理和应用
    • 1.1 VAE 原理
    • 1.2 VAE 的主要特征:
    • 1.3 VAE 的应用领域:
  • 2.Python 代码实现 VAE 在遥感领域的应用
    • 2.1VAE 模型的实现
    • 2.2代码解释
  • 3.总结


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://arxiv.org/pdf/1312.6114v10

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

1.Variational Autoencoders (VAE) 的原理和应用

1.1 VAE 原理

变分自编码器(Variational Autoencoder, VAE)是生成模型的一种,旨在学习输入数据的潜在概率分布VAE 与传统的自编码器(AE)相比,其核心区别在于它采用了贝叶斯方法进行推理

1.2 VAE 的主要特征:

  • 架构:与 AE 相同,VAE 也由编码器和解码器组成,但编码器的输出是潜在变量的概率分布(通常为高斯分布)
  • 重参数化技巧:为了解决标准反向传播无法有效训练模型的问题,VAE 引入了重参数化技巧。通过将潜在变量表示为固定分布(如标准正态分布)与参数化分布(均值和方差)的组合,模型能够有效学习
  • 损失函数:VAE 的损失函数由两部分组成:重构损失和 KL 散度(Kullback-Leibler Divergence)。重构损失衡量生成样本与真实样本之间的差异,而 KL 散度则确保潜在分布接近先验分布(通常是标准正态分布)。

1.3 VAE 的应用领域:

  • 图像生成:VAE 可以生成新图像,广泛应用于计算机视觉。
  • 数据插值:通过在潜在空间中进行插值,VAE 可以生成两种输入之间的过渡图像。
  • 异常检测:在学习正常数据的分布后,VAE 可以检测到异常样本。

在遥感领域,VAE 可以用于处理高维遥感数据,生成新图像,或从复杂的多光谱图像中提取潜在特征。

2.Python 代码实现 VAE 在遥感领域的应用

下面通过一个简单的 VAE 实现,演示如何在遥感图像处理中应用 VAE。

2.1VAE 模型的实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(VAE, self).__init__()
        
        # 编码器
        self.fc1 = nn.Linear(input_size, hidden_size)  # 输入到隐藏层
        self.fc21 = nn.Linear(hidden_size, hidden_size)  # 均值
        self.fc22 = nn.Linear(hidden_size, hidden_size)  # 对数方差
        
        # 解码器
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        h1 = torch.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)  # 返回均值和对数方差

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # 标准差
        eps = torch.randn_like(std)  # 随机噪声
        return mu + eps * std  # 重新参数化

    def decode(self, z):
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))  # 输出为概率值 [0, 1]

    def forward(self, x):
        mu, logvar = self.encode(x)  # 编码
        z = self.reparameterize(mu, logvar)  # 重新参数化
        return self.decode(z), mu, logvar  # 解码及返回均值和对数方差

    def loss_function(self, recon_x, x, mu, logvar):
        BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')  # 重构损失
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL 散度
        return BCE + KLD  # 总损失

# 生成模拟遥感图像数据 (64 维特征)
X = np.random.rand(1000, 64)  # 1000 个样本,每个样本有 64 维光谱特征
X = torch.tensor(X, dtype=torch.float32)

# 创建数据加载器
dataset = TensorDataset(X)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 定义模型、优化器
input_size = 64
hidden_size = 32  # 隐藏层大小
vae = VAE(input_size=input_size, hidden_size=hidden_size)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

# 训练 VAE 模型
num_epochs = 50
for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data[0])  # 前向传播
        loss = vae.loss_function(recon_batch, data[0], mu, logvar)  # 计算损失
        loss.backward()
        optimizer.step()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

# 使用训练好的模型进行数据生成
with torch.no_grad():
    sample = torch.randn(64)  # 从标准正态分布生成潜在变量
    generated_data = vae.decode(sample).numpy()  # 解码生成新样本

# 可视化原始数据与生成数据
plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plt.title('Generated Data')
plt.imshow(generated_data[:10], aspect='auto', cmap='hot')
plt.subplot(1, 2, 2)
plt.title('Original Data Sample')
plt.imshow(X.numpy()[:10], aspect='auto', cmap='hot')
plt.show()

2.2代码解释

1.模型定义:

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(VAE, self).__init__()
        
        # 编码器
        self.fc1 = nn.Linear(input_size, hidden_size)  # 输入到隐藏层
        self.fc21 = nn.Linear(hidden_size, hidden_size)  # 均值
        self.fc22 = nn.Linear(hidden_size, hidden_size)  # 对数方差
        
        # 解码器
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_size)
  • VAE 类定义了编码器和解码器结构,包括均值和对数方差的输出。

2.重参数化技巧:

def reparameterize(self, mu, logvar):
    std = torch.exp(0.5 * logvar)  # 标准差
    eps = torch.randn_like(std)  # 随机噪声
    return mu + eps * std  # 重新参数化
  • 使用随机噪声与潜在变量均值和标准差结合,生成潜在表示。

3.数据生成:

X = np.random.rand(1000, 64)  # 生成 1000 个样本,每个样本有 64 维光谱特征
X = torch.tensor(X, dtype=torch.float32)
  • 模拟生成随机的遥感光谱数据。

4.数据加载器:

dataset = TensorDataset(X)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 使用 DataLoader 创建批处理数据集。

5.模型训练

for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data[0])
        loss = vae.loss_function(recon_batch, data[0], mu, logvar)
        loss.backward()
        optimizer.step()
  • 使用 50 个 epoch 进行训练,计算重构损失和 KL 散度,更新权重。

6.生成新数据:

with torch.no_grad():
    sample = torch.randn(64)  # 从标准正态分布生成潜在变量
    generated_data = vae.decode(sample).numpy()  # 解码生成新样本
  • 通过从潜在空间生成样本,使用解码器生成新数据。

7.可视化:

plt.subplot(1, 2, 1)
plt.title('Generated Data')
plt.imshow(generated_data[:10], aspect='auto', cmap='hot')
  • 可视化生成的数据与原始数据的对比。

3.总结

变分自编码器(VAE)是一种强大的生成模型,能够有效学习输入数据的潜在概率分布。在遥感领域,VAE 可以用于数据生成、特征提取和异常检测等任务。通过简单的 Python 实现,我们展示了如何使用 VAE 处理遥感数据,生成新样本,并可视化结果。


http://www.kler.cn/a/323031.html

相关文章:

  • 【Hadoop实训】Hive 数据操作②
  • lua实现雪花算法
  • MySQL缓存使用率超过80%的解决方法
  • C++:基于红黑树封装map和set
  • 「Mac玩转仓颉内测版5」入门篇5 - Cangjie控制结构(上)
  • Flink Source 详解
  • 【基础知识】Go中的同步机制
  • 基于yolov8的辣椒缺陷检测系统python源码+onnx模型+评估指标曲线+精美GUI界面
  • STM32G431RBT6 VREF+与VDDA引脚
  • 计算机性能指标之MIPS
  • 极狐GitLab 17.4 重点功能解读【九】
  • Windows安全日志7关键事件ID分析
  • 95分App引领绿色消费新潮流,闲置物品焕发新生机
  • 【JS】函数柯里化
  • 「数组」离散化 / Luogu B3694(C++)
  • 畅阅读微信小程序
  • RHCS认证-Linux(RHel9)-Ansible
  • 代码随想录Day 58|拓扑排序、dijkstra算法精讲,题目:软件构建、参加科学大会
  • Linux操作系统中MongoDB
  • 前端开发之代理模式
  • Cilium + ebpf 系列文章- (六)Cilium-BGP与分发-EXTERNAL-IP
  • vue3中< keep-alive >页面实现缓存及遇到的问题
  • 【深度学习】深度学习框架有哪些及其优劣势介绍
  • 【CSS】透明度 、过渡 、动画 、渐变
  • JAVAEE如何实现网页(jsp)间的数据传输?一文总结
  • 2024 icpc 第二场网络赛题解