当前位置: 首页 > article >正文

[DL]深度学习_扩散模型正弦时间编码

1  扩散模型时间步嵌入

1.1  时间步正弦编码

        在扩散模型按时间步 t 进行加噪去噪过程时,需要包括反映噪声水平的时间步长 t 作为噪声预测器的额外输入。但是最初与图像配套的时间步 t 是数字,需要将代表时间步 t 的数字编码为向量嵌入。嵌入时间向量的宽度dim是按照输出设定的。

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

参数定义

  • timesteps--时间步 t ,包含batch中每个元素的时间步信息,是一个一维张量,形状为[N];
  • dim--编码后时间嵌入的最后一维大小,即输出多宽的向量;
  • max_period--控制嵌入的最小频率,值为10000。 

        假设此时的batch中有4个图像,则此时4个图像有4个对应的扩散时间步,所需编码的时间嵌入宽度是8。timesteps的形状为[4] 。

timesteps = [t1, t2, t3, t4], dim = 8

编码过程

    half = dim // 2

        计算所需输出向量的一半尺寸,以供后续将正弦嵌入和余弦嵌入按照最后维度concat拼接。

    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)

        用了固定化编码方式,计算出的频率向量,且该频率向量的维度是由dim的一半half决定的。arange(start=0,end=half)则代表arange(0,4),若以i为代表,则此时 i = 0,1,2,3。公式如下:

freqs=e^{(-\frac{i\times log(maxperiod)}{half})} 

则此时计算出的 freqs = [freqs1,freqs2,freqs3,freqs4],是一个一维张量,形状为[4]。

    args = timesteps[:, None].float() * freqs[None]

        首先,timesteps[:, None]作用是将timesteps的形状从一维张量[4]扩展为2维张量[4x1],之前是 timesteps = [ t1, t2, t3, t4 ] 扩展为 timesteps = [ [ t1 ], [ t2 ], [ t3 ], [ t4 ] ]。

        其次将freqs的形状从一维张量[4]扩展为2维张量[1x4],之前是 freqs = [freqs1,freqs2,freqs3,freqs4] 扩展为 freqs = [ [freqs1,freqs2,freqs3,freqs4] ]。

        这样做的目的是为了将时间步 timesteps 代表的batch内每个图片的时间步与频率 freqs 做乘法,将时间步广播进去。得出args的形状是 [4x4] ,即 [N x half] ,每个时间步对应得到half列,有N个时间步。

    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)

        将得到的args求cos值和sin值,此时得到的 cos(args) 与 sin(args) 的形状依旧都是 [N x half],按照最后一维拼接 cos(args) 与 sin(args) ,则拼接后得到的embedding的形状是 [4 x 8], [N x 2half] ,即 [N x dim]。对应开始定义的 dim 代表的是编码后的最后维度大小。 

    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

        最后加一个判断,当 dim % 2 ==1 即dim是奇数时,embedding的形状是 [N x dim-1],需要用零向量补充1个维度,变为 [N x dim] 。

1.2  馈入多层感知机 

        当对时间步长执行正弦编码之后,需要将其馈送到多层感知机中获得隐式时间嵌入。 

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        多层感知机由两个具有倒置瓶颈结构的线性层和一个SiLU激活函数组成。其中dim = channels = 8,time_embed_dim = model_channels * 4 = 32。需要注意是 N 为 batch 大小。

  • linear(model_channels, time_embed_dim)中,输入形状为 [4 x 8] ,[N x dim] 即 [N x model_channels] ,输出形状为 [4 x 32] , [N x time_embed_dim]。
  • 激活函数不会改变形状。
  • linear(time_embed_dim, time_embed_dim)中,输入形状为 [4 x 32] ,[N x time_embed_dim],输出形状为 [4 x 32],[N x time_embed_dim]。

        多层感知机主要作用是将输入的时间步正弦编码向量变得更宽。通过激活函数引入非线性。 


http://www.kler.cn/a/412703.html

相关文章:

  • spring boot框架漏洞复现
  • k8s 架构详解
  • 解决首次加载数据空指针异常
  • [代码随想录Day24打卡] 93.复原IP地址 78.子集 90.子集II
  • wordpress二开-WordPress新增页面模板-说说微语
  • fastjson不出网打法—BCEL链
  • 关于如何在k8s中搭建一个nsfw黄图鉴定模型
  • Spring |(四)IoC/DI配置管理第三方bean
  • NLP 2、机器学习简介
  • Dart 中 initializer lists
  • 【MySQL】自动刷新flush privileges命令
  • 【技术文档:技术传播的灯塔】
  • Python学习——猜拳小游戏
  • 组会 | 大语言模型 + LoRA
  • chrome允许http网站打开摄像头和麦克风
  • C++优质学习资源汇总
  • 【开源项目】ChinaAddressCrawler 中国行政区划数据(1980-2023年)采集及转换(Java版),含SQL格式及JSON格式
  • python+django自动化部署日志采用‌WebSocket前端实时展示
  • 第76期 | GPTSecurity周报
  • 23种设计模式-生成器(Builder)设计模式
  • 解决首次加载数据空指针异常
  • scp比rz sz传文件更好
  • 海康VsionMaster学习笔记(学习工具+思路)
  • faiss库中ivf-sq(ScalarQuantizer,标量量化)代码解读-6
  • C#面向对象,封装、继承、多态、委托与事件实例
  • Linux环境实现c语言编程