当前位置: 首页 > article >正文

Python和PyTorch库实现基于生成对抗网络(GAN)将小纹理合成大纹理的详细步骤及代码示例

以下是使用Python和PyTorch库实现基于生成对抗网络(GAN)将小纹理合成大纹理的详细步骤及代码示例。

思路概述

我们将使用生成对抗网络(GAN)来完成小纹理到大纹理的合成任务。GAN由生成器(Generator)和判别器(Discriminator)组成。生成器的目标是生成逼真的大纹理图像,而判别器的任务是区分生成的图像和真实的大纹理图像。通过两者的对抗训练,最终生成器能够学习到如何合成高质量的大纹理图像。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os

# 定义数据集类
class TextureDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# 定义生成器
class Generator(nn.Module):
    def __init__(self, z_dim, img_channels):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            self._block(z_dim, 1024, 4, 1, 0),
            self._block(1024, 512, 4, 2, 1),
            self._block(512, 256, 4, 2, 1),
            self._block(256, 128, 4, 2, 1),
            nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.gen(x)

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, img_channels):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(64, 128, 4, 2, 1),
            self._block(128, 256, 4, 2, 1),
            self._block(256, 512, 4, 2, 1),
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.disc(x)

# 超参数设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 0.0002
batch_size = 32
image_size = 64
z_dim = 100
img_channels = 3
num_epochs = 50

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载数据集
dataset = TextureDataset(root_dir='path/to/your/texture/images', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化生成器和判别器
gen = Generator(z_dim, img_channels).to(device)
disc = Discriminator(img_channels).to(device)

# 定义优化器和损失函数
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# 训练循环
for epoch in range(num_epochs):
    for i, real_images in enumerate(dataloader):
        real_images = real_images.to(device)

        ### 训练判别器
        opt_disc.zero_grad()
        noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
        fake_images = gen(noise)
        disc_real = disc(real_images).reshape(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake_images.detach()).reshape(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        lossD.backward()
        opt_disc.step()

        ### 训练生成器
        opt_gen.zero_grad()
        output = disc(fake_images).reshape(-1)
        lossG = criterion(output, torch.ones_like(output))
        lossG.backward()
        opt_gen.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}")

# 生成大纹理图像
num_samples = 1
noise = torch.randn(num_samples, z_dim, 1, 1).to(device)
generated_images = gen(noise)
generated_images = (generated_images + 1) / 2  # 反归一化
generated_images = generated_images.cpu().detach().permute(0, 2, 3, 1).numpy()

# 显示生成的图像
plt.imshow(generated_images[0])
plt.axis('off')
plt.show()

代码说明

  1. 数据集类 TextureDataset:用于加载纹理图像数据集,并进行必要的预处理。
  2. 生成器 Generator:通过一系列反卷积层将随机噪声向量转换为大纹理图像。
  3. 判别器 Discriminator:使用卷积层来区分真实的大纹理图像和生成的图像。
  4. 训练循环:交替训练判别器和生成器,通过对抗训练不断提高生成器的性能。
  5. 生成大纹理图像:训练完成后,使用生成器生成大纹理图像并显示。

使用方法

  1. 将代码中的 'path/to/your/texture/images' 替换为你实际的小纹理图像文件夹路径。
  2. 确保你已经安装了PyTorch和相关的依赖库。
  3. 运行代码,等待训练完成,最后会显示生成的大纹理图像。

通过以上步骤,你就可以使用生成对抗网络将小纹理合成成大纹理图像。


http://www.kler.cn/a/587826.html

相关文章:

  • HOT100——链表篇Leetcode234. 回文链表
  • 自动化立体库的规划设计
  • [Hello-CTF]RCE-Labs超详细WP-Level13Level14(PHP下的0/1构造RCE命令简单的字数限制RCE)
  • 当内核调试过程中出现bug的调试流程
  • GEN3C:具有精确相机控制的3D信息化世界一致视频生成
  • Spring Boot使用线程池创建多线程
  • 3.3 Spring Boot多数据源动态切换:AbstractRoutingDataSource实战
  • 软件环境安装-通过Docker安装Elasticsearch和Kibana【保姆级教程、内含图解】
  • 关于深度学习参数寻优的一些介绍
  • Tcp网络通信的基本流程梳理
  • 当今前沿技术:人工智能与区块链的未来发展
  • 科大讯飞嵌入式软件开发面试总结
  • Vue与Django是如何传递参数的?
  • python-53-分别使用flask和streamlit进行向量存储和检索的服务开发实战
  • C语言中的指针与函数
  • 【PyMySQL】Python操作MySQL
  • 利用Python爬虫根据关键词获取商品列表
  • OpenHarmony 5.0 MP4封装的H265视频播放失败的解决方案
  • idea 2023社区版自动生成 serialVersionUID
  • 洛谷P11043