使用Python和PyTorch实现了一个简单的生成对抗网络(GAN)用于生成应力值图像
以下是一个使用Python和PyTorch实现了一个简单的生成对抗网络(GAN)用于生成应力值图像,并添加了显示正确颜色条的功能。
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
import matplotlib.colors as mcolors
# 定义生成器
class Generator(nn.Module):
def __init__(self, input_size, output_size):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_size, 128),
nn.ReLU(True),
nn.Linear(128, output_size),
nn.Tanh()
)
def forward(self, x):
return self.model(x)
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, input_size):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_size, 128),
nn.ReLU(True),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# 超参数设置
input_size = 100 # 生成器输入大小
output_size = 64 * 64 # 生成图像的大小(假设为64x64)
num_epochs = 50
batch_size = 64
lr = 0.0002
beta1 = 0.5
# 初始化生成器和判别器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(input_size, output_size).to(device)
discriminator = Discriminator(output_size).to(device)
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# 生成随机噪声
def generate_noise(batch_size, input_size):
return torch.randn(batch_size, input_size).to(device)
# 加载真实数据(假设为应力值图像数据)
# 这里需要你根据实际情况加载你的训练集数据
# 示例代码中用随机数据代替
real_data = torch.randn(1000, output_size).to(device)
# 训练过程
for epoch in range(num_epochs):
# 训练判别器
optimizer_D.zero_grad()
real_labels = torch.ones(batch_size, 1).to(device)
real_outputs = discriminator(real_data[:batch_size])
d_loss_real = criterion(real_outputs, real_labels)
noise = generate_noise(batch_size, input_size)
fake_data = generator(noise)
fake_labels = torch.zeros(batch_size, 1).to(device)
fake_outputs = discriminator(fake_data.detach())
d_loss_fake = criterion(fake_outputs, fake_labels)
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
optimizer_D.step()
# 训练生成器
optimizer_G.zero_grad()
noise = generate_noise(batch_size, input_size)
fake_data = generator(noise)
fake_outputs = discriminator(fake_data)
g_loss = criterion(fake_outputs, real_labels)
g_loss.backward()
optimizer_G.step()
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
# 保存训练模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
# 加载训练好的生成器并生成图像
generator.load_state_dict(torch.load('generator.pth'))
generator.eval()
noise = generate_noise(1, input_size)
generated_image = generator(noise).cpu().view(1, 1, 64, 64)
# 显示生成的图像和颜色条
plt.imshow(generated_image.squeeze(), cmap='viridis')
plt.colorbar()
plt.show()
请注意,上述代码仅为示例代码,实际应用中需要根据你的具体需求进行调整,包括数据加载、网络结构调整等。同时,应力值图像的颜色条显示部分使用了 matplotlib
的默认颜色映射 viridis
,你可以根据实际需求选择合适的颜色映射。