当前位置: 首页 > article >正文

使用Python和PyTorch实现了一个简单的生成对抗网络(GAN)用于生成应力值图像

以下是一个使用Python和PyTorch实现了一个简单的生成对抗网络(GAN)用于生成应力值图像,并添加了显示正确颜色条的功能。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
import matplotlib.colors as mcolors

# 定义生成器
class Generator(nn.Module):
    def __init__(self, input_size, output_size):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(True),
            nn.Linear(128, output_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, input_size):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# 超参数设置
input_size = 100  # 生成器输入大小
output_size = 64 * 64  # 生成图像的大小(假设为64x64)
num_epochs = 50
batch_size = 64
lr = 0.0002
beta1 = 0.5

# 初始化生成器和判别器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(input_size, output_size).to(device)
discriminator = Discriminator(output_size).to(device)

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# 生成随机噪声
def generate_noise(batch_size, input_size):
    return torch.randn(batch_size, input_size).to(device)

# 加载真实数据(假设为应力值图像数据)
# 这里需要你根据实际情况加载你的训练集数据
# 示例代码中用随机数据代替
real_data = torch.randn(1000, output_size).to(device)

# 训练过程
for epoch in range(num_epochs):
    # 训练判别器
    optimizer_D.zero_grad()
    real_labels = torch.ones(batch_size, 1).to(device)
    real_outputs = discriminator(real_data[:batch_size])
    d_loss_real = criterion(real_outputs, real_labels)

    noise = generate_noise(batch_size, input_size)
    fake_data = generator(noise)
    fake_labels = torch.zeros(batch_size, 1).to(device)
    fake_outputs = discriminator(fake_data.detach())
    d_loss_fake = criterion(fake_outputs, fake_labels)

    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    optimizer_D.step()

    # 训练生成器
    optimizer_G.zero_grad()
    noise = generate_noise(batch_size, input_size)
    fake_data = generator(noise)
    fake_outputs = discriminator(fake_data)
    g_loss = criterion(fake_outputs, real_labels)
    g_loss.backward()
    optimizer_G.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

# 保存训练模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

# 加载训练好的生成器并生成图像
generator.load_state_dict(torch.load('generator.pth'))
generator.eval()
noise = generate_noise(1, input_size)
generated_image = generator(noise).cpu().view(1, 1, 64, 64)

# 显示生成的图像和颜色条
plt.imshow(generated_image.squeeze(), cmap='viridis')
plt.colorbar()
plt.show()

请注意,上述代码仅为示例代码,实际应用中需要根据你的具体需求进行调整,包括数据加载、网络结构调整等。同时,应力值图像的颜色条显示部分使用了 matplotlib 的默认颜色映射 viridis,你可以根据实际需求选择合适的颜色映射。


http://www.kler.cn/a/598247.html

相关文章:

  • 正则表达式基本语法和Java中的简单使用
  • fastapi 实践(三)Swagger Docs
  • STM32基础教程——PWM驱动LED呼吸灯
  • AIGC 新势力:探秘海螺 AI 与蓝耘 MaaS 平台的协同创新之旅
  • 【Jwt】详解认证登录的数字签名
  • 牛客网【模板】二维差分(详解)c++
  • 【JavaEE】网络编程socket
  • Java学习路线(便于理解)
  • PostgreSQL_数据使用与日数据分享
  • C语言-访问者模式详解与实践
  • Enovia许可分析的自动化解决方案
  • 程序代码篇---Pyqt的密码界面
  • Agent TARS开源多模态 AI 代理的革命性突破
  • B树和 B+树
  • Security如何复制粘贴
  • Scikit-learn模型构建全流程解析:从数据预处理到超参数调优
  • 矩阵键盘原理与单片机驱动设计详解—端口反转法(下) | 零基础入门STM32第七十八步
  • 可视化操作界面,工程项目管理软件让复杂项目管理变简单
  • AWS SAP学习笔记-概念
  • 2025最新docker教程(四)