当前位置: 首页 > article >正文

使用PyTorch进行图像风格迁移:基于VGG19实现

图像风格迁移(Neural Style Transfer, NST)是深度学习中一个令人着迷的应用,它能够将一张图像的风格应用到另一张图像上。例如,能够将梵高的画风应用到一张普通照片上。本文将详细解释如何使用PyTorch进行风格迁移,逐步分析代码,并讲解其中的关键技术。

1. 环境准备

在开始之前,确保安装了必要的库:

pip install torch torchvision pillow

2. 模型缓存目录设置

为了加速模型的加载,我们可以通过设置环境变量TORCH_HOME来指定模型缓存目录,避免每次运行代码时重新下载模型:

os.environ['TORCH_HOME'] = './model_directory'  # 你可以根据需要自定义目录

3. 加载图像

加载图像并进行预处理是风格迁移中的重要步骤。我们需要将图像转换为张量并进行归一化处理,以便与预训练的VGG19模型匹配:

def load_image(image_path, max_size=400):
    image = Image.open(image_path).convert('RGB')
    size = min(max_size, max(image.size))

    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)

在这里,我们将图像调整为不大于400像素的正方形,并将其转换为适合VGG19模型输入的格式。

4. VGG19模型的特征提取

风格迁移的核心思想是将内容图像的高层次特征与风格图像的低层次特征结合。我们使用VGG19模型的前21层来提取图像的特征:

class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.features = vgg19(pretrained=True).features[:21].eval()

    def forward(self, x):
        features = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in {0, 5, 10, 19, 21}:
                features.append(x)
        return features

5. 内容与风格损失

内容损失衡量生成图像与内容图像的特征差异,而风格损失则是基于Gram矩阵来衡量生成图像与风格图像的差异。

  • 内容损失:
class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        self.target = target.detach()

    def forward(self, input):
        return nn.functional.mse_loss(input, self.target)
  • 风格损失:
class StyleLoss(nn.Module):
    def __init__(self, target):
        super(StyleLoss, self).__init__()
        self.target = self.gram_matrix(target).detach()

    def gram_matrix(self, input):
        batch_size, channels, height, width = input.size()
        features = input.view(batch_size * channels, height * width)
        G = torch.mm(features, features.t())
        return G.div(batch_size * channels * height * width)

    def forward(self, input):
        G = self.gram_matrix(input)
        return nn.functional.mse_loss(G, self.target)

6. 图像风格迁移算法

核心算法将内容图像初始化为输入图像,并通过多次迭代优化,使其逐步接近目标风格图像,同时保持内容的完整性。我们使用LBFGS优化器来实现这一过程:

def style_transfer(content_img, style_img, num_steps=1000, style_weight=1e9, content_weight=1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_img = content_img.to(device)
    style_img = style_img.to(device)

    model = VGG().to(device)

    style_features = model(style_img)
    content_features = model(content_img)

    input_img = content_img.clone().requires_grad_(True).to(device)

    optimizer = optim.LBFGS([input_img])

    style_losses = []
    content_losses = []

    for sf, cf in zip(style_features, content_features):
        content_losses.append(ContentLoss(cf))
        style_losses.append(StyleLoss(sf))

    run = [0]
    while run[0] <= num_steps:

        def closure():
            optimizer.zero_grad()

            input_features = model(input_img)
            content_loss = 0
            style_loss = 0

            for cl, input_f in zip(content_losses, input_features):
                content_loss += content_weight * cl(input_f)

            for sl, input_f in zip(style_losses, input_features):
                style_loss += style_weight * sl(input_f)

            loss = content_loss + style_loss
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print(f'Step {run[0]}, Content Loss: {content_loss.item():4f}, Style Loss: {style_loss.item():4f}')

            return loss

        optimizer.step(closure)

    return input_img

7. 结果保存

生成的图像需要去除归一化并保存为常规图片格式:

def save_image(tensor, path):
    image = tensor.clone().detach()
    image = image.squeeze(0)
    image = transforms.ToPILImage()(image)
    image.save(path)

8. 主函数执行

整个过程可以通过主函数来执行,加载图像、进行风格迁移并保存结果:

if __name__ == '__main__':
    content_image_path = 'content_image.png'
    style_image_path = 'style_image.png'
    output_image_path = 'output_image.jpg'

    content_img = load_image(content_image_path)
    style_img = load_image(style_image_path)

    result = style_transfer(content_img, style_img)

    save_image(result, output_image_path)
    print(f"风格迁移完成,图像已保存为 {output_image_path}")

总结

本文展示了如何使用PyTorch和VGG19模型实现图像风格迁移。通过合理设置内容和风格损失的权重,我们可以生成既保留内容图像结构又具有风格图像艺术风格的全新图像。

完整代码

github:https://github.com/Yolumia/Image_style_transfer_base_vgg19/


http://www.kler.cn/news/304480.html

相关文章:

  • Git工作流程
  • Nacos 与 Eureka 的区别
  • vue3 使用 codemirror 实现yaml文件的在线编辑
  • 点餐|基于java的电子点餐系统小程序(源码+数据库+文档)
  • 【Excel 表打印基本操作】
  • 【图像识别】摄像头捕捉运动到静止视频帧(免费源码分享)
  • Maven从入门到精通(三)
  • IVF 视频文件格式
  • [网络][CISCO]CISCO_华为网络设备端口镜像配置
  • Cache Aside pattern
  • EG边缘计算网关连接纵横云3.0物联网平台(MQTT协议)
  • Notepad++插件:TextFX 去除重复行
  • 快速理解Redis
  • 【系统规划与管理师】【案例分析】【考点】【问题篇】第5章 IT服务部署实施
  • MiniCPM-V: A GPT-4V Level MLLM on Your Phone
  • Ansys HFSS的边界条件与激励端口
  • 【Linux入门】iptables的安装与配置应用实例
  • pg \d 在不同模式下有同名表时注意事项
  • 828华为云征文|华为云Flexus X实例docker部署Jitsi构建属于自己的音视频会议系统
  • 软件工程毕业设计开题汇总
  • 如何在 DigitalOcean Droplet 云服务器上部署 Next.js 应用
  • 技术周刊 | Vue3.5、Replit Agent、Cursor 使用技巧、React 19 中的新功能、8 月 Web 平台的新功能
  • 9.11 QT ( Day 4)
  • oracle数据库安装和配置详细讲解
  • 个人学习笔记6-2:动手学深度学习pytorch版-李沐
  • Qt使用UDP进行单波通信
  • 实习项目|苍穹外卖|day9
  • mmseqs2进行pdb蛋白质序列聚类分析
  • 在云服务器上安装 RabbitMQ:从零到一的最佳实践
  • 机械学习—零基础学习日志(Python做数据分析03)