当前位置: 首页 > article >正文

Python使用AI animegan2-pytorch制作属于你的漫画头像/风景图片

Python使用AI animegan2-pytorch制作属于你的漫画头像

    • 1. 效果图
    • 2. 原理
    • 3. 源码
    • 参考

git clone https://github.com/bryandlee/animegan2-pytorch
cd ./animegan2-pytorch
python test.py --photo_path images/photo_test.jpg --save_path images/animegan2_result.png

1. 效果图

官方效果图如下:

效果图v2 512模型如下:
在这里插入图片描述

效果图v1 512模型如下:
在这里插入图片描述

效果图v1 效果不太好如下:
在这里插入图片描述

效果图rece如下
人物会有一种病态的美,过于白了,风景上效果更好一些;
人物与photo2cartoon的效果图有点像;
在这里插入图片描述

在这里插入图片描述

效果图paprika 模型如下
人物纹理痕迹太过明显,更适合风景
下一张明兰的效果还不错,不同的模型在不同的图像上也会有些微差别;
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll 风景效果对比图如下:

在这里插入图片描述
在这里插入图片描述

origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll 人物效果对比图如下:
在这里插入图片描述
在这里插入图片描述

2. 原理

人像/风景卡通风格渲染的目标是,在保持原图像 ID 信息和纹理细节的同时,将真实照片转换为卡通风格的非真实感图像。

3. 源码

源码及示例文件模型等见资源:https://download.csdn.net/download/qq_40985985/87739198

# animegan2-pytroch 生成漫画头像或者风景图
# python test.py --checkpoint weights/face_paint_512_v2.pt --input_dir samples/faces/ --device cpu --output_dir samples/resv2
# model loaded: weights/face_paint_512_v2.pt

import os
import argparse

from PIL import Image
import numpy as np

import torch
from torchvision.transforms.functional import to_tensor, to_pil_image

from model import Generator


torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def load_image(image_path, x32=False):
    img = Image.open(image_path).convert("RGB")

    if x32:
        def to_32s(x):
            return 256 if x < 256 else x - x % 32
        w, h = img.size
        img = img.resize((to_32s(w), to_32s(h)))

    return img


def test(args):
    device = args.device
    
    net = Generator()
    net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
    net.to(device).eval()
    print(f"model loaded: {args.checkpoint}")
    
    os.makedirs(args.output_dir, exist_ok=True)

    for image_name in sorted(os.listdir(args.input_dir)):
        if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]:
            continue
            
        image = load_image(os.path.join(args.input_dir, image_name), args.x32)

        with torch.no_grad():
            image = to_tensor(image).unsqueeze(0) * 2 - 1
            out = net(image.to(device), args.upsample_align).cpu()
            out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
            out = to_pil_image(out)

        out.save(os.path.join(args.output_dir, image_name))
        print(f"image saved: {image_name}")


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='./weights/paprika.pt',
    )
    parser.add_argument(
        '--input_dir', 
        type=str, 
        default='./samples/inputs',
    )
    parser.add_argument(
        '--output_dir', 
        type=str, 
        default='./samples/results',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda:0',
    )
    parser.add_argument(
        '--upsample_align',
        type=bool,
        default=False,
        help="Align corners in decoder upsampling layers"
    )
    parser.add_argument(
        '--x32',
        action="store_true",
        help="Resize images to multiple of 32"
    )
    args = parser.parse_args()
    
    test(args)

# 原图VS效果图绘制
# python plot_sample.py

# 获取输入路径的所有图像
import cv2
import imutils
import numpy as np
from imutils import paths

imagePaths = sorted(list(paths.list_images("samples")))

list = [x for x in imagePaths if x.find('inputs') > 0]
print(list)

resv1 = [x for x in imagePaths if x.find("resv1") > 0]
resv2 = [x for x in imagePaths if x.find("resv2") > 0]
cele = [x for x in imagePaths if x.find("cele") > 0]
pap = [x for x in imagePaths if x.find("paprika") > 0]

img = None
for i in list:
    if (i.find("ml2.jpg") < 0): continue
    img = None
    for j in resv1:
        if (j.split("\\")[2].__eq__(i.split("\\")[2])):
            origin = cv2.imread(i)
            res = cv2.imread(j)
            if (origin.shape[0] != res.shape[0] or origin.shape[1] != res.shape[1]):
                res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
            # print(origin.shape, res.shape)
            # print('origin vs ' + j.split("\\")[1].replace("res", "") + 'Res')
            cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'Res',
                       imutils.resize(np.hstack([origin, res]), width=300))
            if (img is None):
                img = imutils.resize(np.hstack([origin, res]), width=300)
            else:
                imgA = np.vstack([img, imutils.resize(np.hstack([origin, res]), width=300)])

                img = imgA
            cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'ResAll',
                       img)
            # cv2.waitKey(0)
    for j in resv2:
        if (j.split("\\")[2].__eq__(i.split("\\")[2])):
            origin = cv2.imread(i)
            res = cv2.imread(j)
            if (origin.shape[0] != res.shape[0] or origin.shape[1] != res.shape[1]):
                res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
            # cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'Res',
            #            imutils.resize(np.hstack([origin, res]), width=300))
            if (img is None):
                img = imutils.resize(np.hstack([origin, res]), width=300)
            else:
                imgA = np.vstack([img, imutils.resize(np.hstack([origin, res]), width=300)])

                img = imgA
            # cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'ResAll',
            #            img)
            # cv2.waitKey(0)
    for j in pap:
        if (j.split("\\")[2].__eq__(i.split("\\")[2])):
            # print('--------------\t', i, j)
            origin = cv2.imread(i)
            res = cv2.imread(j)
            if (origin.shape[0] != res.shape[0] or origin.shape[1] != res.shape[1]):
                res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
            # print(origin.shape, res.shape)
            # print('origin vs ' + j.split("\\")[1].replace("res", "") + 'Res')
            # cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'Res',
            #            imutils.resize(np.hstack([origin, res]), width=300))
            # list.append(imutils.resize(np.hstack([origin, res]), width=300))
            if (img is None):
                img = imutils.resize(np.hstack([origin, res]), width=300)
            else:
                imgA = np.vstack([img, imutils.resize(np.hstack([origin, res]), width=300)])

                img = imgA
            # cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'ResAll',
            #            img)
            # cv2.waitKey(0)
    for j in cele:
        if (j.split("\\")[2].__eq__(i.split("\\")[2])):
            # print('--------------\t', i, j)
            origin = cv2.imread(i)
            res = cv2.imread(j)
            if (origin.shape[0] != res.shape[0] or origin.shape[1] != res.shape[1]):
                res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
            # print(origin.shape, res.shape)
            # print('origin vs ' + j.split("\\")[1].replace("res", "") + 'Res')
            # cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'Res',
            #            imutils.resize(np.hstack([origin, res]), width=300))
            # list.append(imutils.resize(np.hstack([origin, res]), width=300))
            if (img is None):
                img = imutils.resize(np.hstack([origin, res]), width=300)
            else:
                imgA = np.vstack([img, imutils.resize(np.hstack([origin, res]), width=300)])

                img = imgA
            cv2.imshow('origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll',
                       img)
            cv2.waitKey(0)

参考

  • https://alltodata.blog.csdn.net/article/details/125183830
  • https://github.com/bryandlee/animegan2-pytorch

http://www.kler.cn/news/16599.html

相关文章:

  • 3.3 泰勒公式例题分析
  • c++ 11标准模板(STL) std::vector (三)
  • 同时使用注解和 xml 的方式引用 dubbo 服务产生的异常问题排查实战
  • 抓马,互联网惊现AI鬼城:上万个AI发帖聊天,互相嗨聊,人类被禁言
  • ASIC-WORLD Verilog(6)运算符
  • 【.net core 自动生成数据库】
  • 认识Cookie和Session
  • 【算法】求最短路径算法
  • react之按钮鉴权
  • Java微服务商城高并发秒杀项目--013.SentinelResource的使用
  • 算法刷题|392.判断子序列、115.不同的子序列
  • 大型医院影像PACS系统三维重建技术(获取数据、预处理、配准、重建和可视化)
  • Stable Diffusion 本地部署教程不完全指南
  • 第18章 项目风险管理
  • javascript中的严格模式
  • 自动驾驶行业观察之2023上海车展-----车企发展趋势(1)
  • Python基础合集 练习19(类与对象3(多态))
  • Chapter4:频率响应法(上)
  • Linux套接字编程-2
  • Packet Tracer - 静态路由故障排除
  • 如何学习python?
  • 【C++】右值引用完美转发
  • 什么是 Docker?它能用来做什么?
  • ChatGPT常见问题及其解决方法汇总
  • 微软正式宣布 Win10 死刑,Win11 LTSC要来了
  • 使用 ESP32 设计智能手表 – 第 1 部分制作表盘
  • Shell编程之循环语句
  • osg操控器之动画路径操控器osgGA::AnimationPathManipulator分析
  • 代码随想录算法训练营第四十五天|70. 爬楼梯 (进阶)、322. 零钱兑换、279.完全平方数
  • MySQL基本操作