当前位置: 首页 > article >正文

【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)

系列文章目录

  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 本文则以 IP-Adapter Plus 训练代码为例,进行详细介绍。

文章目录

  • 系列文章目录
  • 整体训练框架
  • 一、训了哪些部分?
      • 第一块 - image_proj_model
      • 第二块 - adapter_modules
  • 二、训练目标


整体训练框架

在这里插入图片描述

一、训了哪些部分?

本文以原仓库 1 的 /path/IP-Adapter/tutorial_train_plus.py 为例,该文件为 SD1.5 IP-Adapter Plus 的训练代码。

从以下代码可以看出,IPAdapter 主要由 unet, image_proj_model, adapter_modules 3 个部分组成,而权重需要被优化的(训练到的)只有 ip_adapter.image_proj_model.parameters(), 和 ip_adapter.adapter_modules.parameters() 。

	ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)

    # optimizer
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(),  ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
    ...
    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

第一块 - image_proj_model

在 IP-Adapter Plus 中,采用的是 Resampler 作为img embedding 到 ip_tokens 的映射网络,对图像(image prompt)中信息的抽取更加细粒度。其他模块都不需要梯度下降,如下代码所示。

	# freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)
    
    #ip-adapter-plus
    image_proj_model = Resampler(
        dim=unet.config.cross_attention_dim,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=args.num_tokens,
        embedding_dim=image_encoder.config.hidden_size,
        output_dim=unet.config.cross_attention_dim,
        ff_mult=4
    )
    ...

第二块 - adapter_modules

Decoupled cross-attention 则在以下代码中进行初始化,关键是在特定的 unet 层中进行替换,详细位置可以参考前文中的图片,本文的重点是后续训练的实现。

	# init adapter modules
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

二、训练目标

每个 epoch 是遍历完一整个 dataset,我们直接从每个训练步的循环中来看:

  • latents 是通过 vae 将输入的 image prompt 压到了隐空间(latent space)中。
  • 准备相应的 noise 和 timesteps ,再通过 noise_scheduler 来制作出 noisy_latents。
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
  • clip_images 和 drop_image_embed 是在准备数据的过程中,做了一个随机 drop 的方式进行数据增强,提升模型鲁棒性。
    • 数据增强:通过随机丢弃一些图像,模型被迫学习从剩余的图像中提取信息,这可以增加模型的泛化能力。
    • 模型鲁棒性:训练模型以处理不完整的数据,使其在实际应用中对缺失数据更加鲁棒。
     clip_images = []
     for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
         if drop_image_embed == 1:
             clip_images.append(torch.zeros_like(clip_image))
         else:
             clip_images.append(clip_image)
     clip_images = torch.stack(clip_images, dim=0)
     with torch.no_grad():
         image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]
 
     with torch.no_grad():
         encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

  1. https://github.com/tencent-ailab/IP-Adapter/tree/main ↩︎


http://www.kler.cn/a/285950.html

相关文章:

  • 使用PostgreSQL的CLI客户端查询数据不显示问题
  • 计算机网络概述(协议层次与服务模型)
  • 爬虫引流推广使用IP
  • 公务员面试(c语言)
  • 【网络基础】探讨以太网:封装解包、MTU、MAC地址与碰撞
  • XR虚拟拍摄短剧 | 探索虚拟制作在短剧领域的应用与发展
  • 若依微服务ruoyi-auth在knife4j中不显示问题解决
  • .net dataexcel winform控件 更新 日志
  • JavaFX基本控件-Label
  • 概率论原理精解【10】
  • vscode c++和cuda开发环境配置
  • stm32-SD卡实验
  • 内存区域与内存溢出异常
  • flutter使用echarts
  • 优惠券的最佳利用策略:如何在Java代码中优化优惠券的使用
  • SpringSecurity Oauth2 - 密码模式完成身份认证获取令牌 [自定义UserDetailsService]
  • 9千含读音文件的中文汉语学习ACCESS\EXCEL数据库
  • 《JavaEE进阶》----6.<SpringMVC实践项目:【简易两整数加法计算器】>
  • 网络安全总结②
  • 远程调用以及注册中心Nacos