当前位置: 首页 > article >正文

controlnet 多 condition 融合

TL;DR

每个条件（condition）由各自对应的 ControlNet 处理，并通过各自的 conditioning_scale 控制权重；各 ControlNet 输出的残差逐项相加后，作为附加残差送入 UNet。

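示例（下面是单个 ControlNet 的基础用法；多条件时，将多个 ControlNet 以 list/tuple 形式传给 controlnet，并为每个 ControlNet 提供对应的条件图像和权重）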

# !pip install opencv-python transformers accelerate
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import numpy as np
import torch

import cv2
from PIL import Image

# Fetch the reference picture and hand it to OpenCV as a NumPy array.
source_pixels = np.array(
    load_image(
        "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
    )
)

# Canny yields a single-channel edge map; replicate it across three channels
# before turning it back into a PIL image used as the conditioning input.
edge_map = cv2.Canny(source_pixels, 100, 200)
canny_image = Image.fromarray(np.stack([edge_map, edge_map, edge_map], axis=2))

# Load the canny ControlNet and Stable Diffusion v1-5, both in float16.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

# Swap in the UniPC multistep scheduler, reusing the current scheduler's config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Memory optimizations. Drop the xformers call if xformers is not installed.
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()

# Generate one image with a fixed seed, conditioned on the canny edge image.
generator = torch.manual_seed(0)
result = pipe(
    "futuristic-looking woman",
    image=canny_image,
    num_inference_steps=20,
    generator=generator,
)
image = result.images[0]

核心代码

1. 模型封装机制

# Several ControlNets passed as a list/tuple are wrapped in a single
# MultiControlNetModel, so later code can call `controlnet` as one object
# whether one or many nets were supplied.
if isinstance(controlnet, (list, tuple)):
    controlnet = MultiControlNetModel(controlnet)

2. controlnet 调用

# Run the ControlNet(s) for the current timestep `t` and collect the residuals
# (per down block, plus one for the mid block).
# NOTE(review): when `self.controlnet` is a MultiControlNetModel, `image` and
# `cond_scale` are expected to be sequences with one entry per wrapped net.
down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    # A plain tuple is requested because the result is unpacked into two names.
                    return_dict=False,
                )

2.1. MultiControlNetModel 实现

class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        # Registered as a ModuleList so the wrapped nets are tracked as submodules.
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        r"""
        Run every wrapped ControlNet on its own condition and sum the resulting residuals.

        Args:
            sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask,
            added_cond_kwargs, cross_attention_kwargs, guess_mode:
                Forwarded unchanged to every wrapped net.
            controlnet_cond:
                One conditioning input per wrapped net, consumed in the same order as `self.nets`.
            conditioning_scale:
                One scale per wrapped net, passed to that net as its `conditioning_scale`.
            return_dict:
                Accepted for signature compatibility with `ControlNetModel`. The wrapped nets are always asked
                for tuple output, and this method always returns a plain tuple.

        Returns:
            A tuple `(down_block_res_samples, mid_block_res_sample)` holding the element-wise sums of the
            residuals produced by the wrapped nets.

        Raises:
            ValueError: If there is nothing to run, i.e. `controlnet_cond`, `conditioning_scale` or the list of
                wrapped nets is empty.

        Note:
            The three sequences are iterated with `zip`, so iteration stops at the shortest one and any surplus
            entries are ignored.
        """
        down_block_res_samples = None
        mid_block_res_sample = None

        for image, scale, controlnet in zip(controlnet_cond, conditioning_scale, self.nets):
            down_samples, mid_sample = controlnet(
                sample=sample,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=image,
                conditioning_scale=scale,
                class_labels=class_labels,
                timestep_cond=timestep_cond,
                attention_mask=attention_mask,
                added_cond_kwargs=added_cond_kwargs,
                cross_attention_kwargs=cross_attention_kwargs,
                guess_mode=guess_mode,
                # Always request a tuple: the result is unpacked into two names right here,
                # independent of what the caller asked for via `return_dict`.
                return_dict=False,
            )

            # merge samples
            if down_block_res_samples is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                # Out-of-place add, consistent with the down samples above; avoids mutating
                # the tensor returned by the first net.
                mid_block_res_sample = mid_block_res_sample + mid_sample

        if down_block_res_samples is None:
            raise ValueError(
                "MultiControlNetModel.forward requires at least one ControlNet, one `controlnet_cond` "
                "and one `conditioning_scale`."
            )

        return down_block_res_samples, mid_block_res_sample

3. 噪声预测

                # predict the noise residual
                # The residuals collected from the ControlNet call are handed to the UNet
                # through `down_block_additional_residuals` and `mid_block_additional_residual`.
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]  # with return_dict=False the first element of the result is used as the noise prediction

http://www.kler.cn/a/513047.html

相关文章:

  • Level2逐笔成交逐笔委托毫秒记录:今日分享优质股票数据20250121
  • 基于python+Django+mysql鲜花水果销售商城网站系统设计与实现
  • Chrome 132 版本新特性
  • redis性能优化参考——筑梦之路
  • Python 正则表达式
  • 第17章:Python TDD回顾与总结货币类开发
  • 网安篇(一)日志分析——从给的登录日志中找出攻击IP和使用的用户名
  • 数据结构学习记录-树和二叉树
  • 堆的实现(C语言详解版)
  • yolo系列模型为什么坚持使用CNN网络?
  • LeetCode:37. 解数独
  • [Easy] leetcode-500 键盘行
  • Pix2Pix:图像到图像转换的条件生成对抗网络深度解析
  • 实现一个自己的spring-boot-starter,基于SQL生成HTTP接口
  • 分布式系统通信解决方案:Netty 与 Protobuf 高效应用
  • 如何打造高效同城O2O平台?外卖跑腿系统源码选型与开发指南
  • 新能源工厂如何借助防静电手环监控系统保障生产安全
  • 0基础跟德姆(dom)一起学AI 自然语言处理19-输出部分实现
  • .NET Core 中如何构建一个弹性HTTP 请求机制
  • Linux应用编程(五)USB应用开发-libusb库
  • 力扣-数组-350 两个数组的交集Ⅱ
  • 连接池偶现15分钟超时问题
  • 数组-二分查找
  • qt中透明度表示
  • 如何使用 Python 进行文件读写操作?
  • 【Linux】Socket编程-TCP构建自己的C++服务器