当前位置: 首页 > article >正文

【扩散模型(十)】IP-Adapter 源码详解 4 - 训练细节、具体训了哪些层?

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)
  • 【扩散模型(九)】IP-Adapter 与 IP-Adapter Plus 的具体区别是什么?

文章目录

  • 系列文章目录
    • adapter_modules 分为两类
  • 总结


通过前面的系列文章,很清楚要训练的就是 image_proj_model(或者对于 plus 来说是 resampler) 和 adapter_modules 两块。

而 image_proj_model 这块比较简单,原码如下所示

    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)
    
    #ip-adapter
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=4,
    )

adapter_modules 分为两类

  1. AttnProcessor 对应 self attention
  2. IPAttnProcessor 对应 cross attention

按理说 self attention 对应的 AttnProcessor 应该不会被训练,但是 training = True,便让人非常费解。
在这里插入图片描述
进一步查看 AttnProcessor2_0 和 IPAttnProcessor2_0 后,就清楚了,因为从 AttnProcessor2_0 的构造函数(init)中并没有参数,就算是 trianing = True 也并不影响训练,实际训练的模块还是 IPAttnProcessor2_0 构造函数中的 to_k_ip 和 to_v_ip 两层 linear!

class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
...

class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(

总结

  1. IP-Adapter 训的就是 image_proj_model(或者对于 plus 来说是 resampler) 和 adapter_modules 两块。
  2. 在 adapter_modules 中,实际只训了 IPAttnProcessor2_0 的 to_k_ip 和 to_v_ip。
  3. adapter_modules 是在每个有含有 cross attention 的 unet block 里进行的替换,如下图所示。

在这里插入图片描述


http://www.kler.cn/a/294172.html

相关文章:

  • 【问卷调研】HarmonyOS SDK开发者社区用户需求有奖调研
  • Android 13 实现屏幕熄屏一段时候后关闭 Wi-Fi 和清空多任务列表
  • 火车车厢重排问题,C++详解
  • 【面试题】发起一次网络请求,当请求>=1s,立马中断
  • 【Golang】Channel的ring buffer实现
  • ArkTs简单入门案例:简单的图片切换应用界面
  • 新加坡裸机云多IP服务器特性
  • java-在idea中antrl的hello world
  • 63、Python之函数高级:装饰器缓存实战,优化递归函数的性能
  • Spring Boot启动卡在Root WebApplicationContext: initialization completed in...
  • TulingMember进销存系统
  • Save OpenAI response in Azure function to Blob storage
  • 简单上手 PIPENV
  • 2024高教社杯数学建模国赛ABCDE题选题建议+初步分析
  • 计算机网络-VRRP工作原理
  • kubelet 探针
  • Vue3:实现路径变量
  • 同时播放多个视频
  • Spring Cloud Gateway整合基于STOMP协议的WebSocket实战及遇到问题解决
  • 基于单片机的家居环境监测系统的设计
  • 项目7-音乐播放器7(测试报告)
  • MATLAB 中的矩阵拼接技巧
  • bash反弹shell分析
  • C#编程语言及.NET 平台快速入门指南
  • Facebook群控系统,零门槛营销
  • 基于人工智能的聊天情感分析系统