
position embedding

Table of contents

  • 1. Four kinds of position embedding
  • 2. PyTorch source code [to be tidied up later]

[Busy at the moment; this post will be tidied up later]

1. Four kinds of position embedding

Position Embedding
  1. Transformer
    1.1 1d absolute
    1.2 sin/cos constant
    1.3
  2. Vision Transformer
    2.1 1d absolute
    2.2 trainable
  3. Swin Transformer
    3.1 2d relative bias
    3.2 trainable
  4. Masked AutoEncoder
    4.1 2d absolute
    4.2 sin/cos constant
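
For reference, the "sin/cos constant" table named in entries 1.2 and 4.2 is usually written as follows. This is the standard Transformer formulation; the symbol names ($pos$ for the position, $i$ for the channel-pair index, $d$ for the embedding dimension) are chosen here for illustration and do not come from the original notes.

```latex
\omega_i = \frac{1}{10000^{2i/d}}, \qquad i = 0, 1, \dots, \tfrac{d}{2} - 1

PE(pos, 2i) = \sin(pos \cdot \omega_i), \qquad PE(pos, 2i + 1) = \cos(pos \cdot \omega_i)
```

Nothing in this table is learned, which is why the outline calls it "constant", in contrast to the trainable tables used by the Vision Transformer and Swin Transformer entries.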

2. PyTorch source code [to be tidied up later]

import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)


# ---------------------------------------------------------------------------------
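# 1. 1d absolute constant sin/cos position embedding (Transformer)
def create_1d_absolute_sincos_embeddings(n_pos_vec, dim):
    # n_pos_vec: 1d float tensor of positions, e.g. torch.arange(n_pos, dtype=torch.float)
    # dim must be even because every frequency fills one sin channel and one cos channel
    assert dim % 2 == 0, "dim must be even"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

    # omega[i] = 1 / 10000^(2i/dim), one frequency per (sin, cos) channel pair
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2.0
    omega = 1.0 / (10000 ** omega)

    # outer product of positions and frequencies: [n_pos, dim/2]
    out = n_pos_vec[:, None] @ omega[None, :]

    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)

    # even channels hold sin, odd channels hold cos
    position_embedding[:, 0::2] = emb_sin
    position_embedding[:, 1::2] = emb_cos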

    return position_embedding


# ---------------------------------------------------------------------------------

# ---------------------------------------------------------------------------------
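# 2. 1d absolute trainable embedding (Vision Transformer)
def create_1d_absolute_trainable_embeddings(n_pos_vec, dim):
    # n_pos_vec: torch.arange(n_pos, dtype=torch.float); only its length is used
    # returns an nn.Embedding whose weight [n_pos, dim] starts at zero and is learned during training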
    position_embedding = nn.Embedding(n_pos_vec.numel(), dim)
    nn.init.constant_(position_embedding.weight, 0.0)
    return position_embedding


# 3. 2d relative bias trainable embedding (Swin Transformer)
def create_2d_relative_bias_trainable_embeddings(n_heads, height, width, dim):
    # dim is unused; it is kept only so the existing call in __main__ stays valid
    # width=5  --> torch.arange(5)=[0,1,2,3,4] --> offsets in [-4, 4], i.e. 2*width-1 values
    # height=4 --> torch.arange(4)=[0,1,2,3]   --> offsets in [-3, 3], i.e. 2*height-1 values
    # one learnable bias per (row offset, column offset) pair and per head
    num_relative_positions = (2 * height - 1) * (2 * width - 1)
    position_embedding = nn.Embedding(num_relative_positions, n_heads)
    nn.init.constant_(position_embedding.weight, 0.0)

    def get_relative_position_index(height, width):
        # indexing="ij" keeps (row, column) order and silences the warning that
        # newer PyTorch versions emit when the argument is omitted
        coords = torch.stack(
            torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij"))  # [2,height,width]
        coords_flatten = torch.flatten(coords, 1)  # [2,height*width]
        relative_coords_bias = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2,height*width,height*width]
        # shift both offsets so that they start at 0
        relative_coords_bias[0, :, :] += height - 1
        relative_coords_bias[1, :, :] += width - 1

        # flatten the 2d offset (i, j) to a 1d index: B[i * n_cols + j] = A[i, j], n_cols = 2*width-1
        relative_coords_bias[0, :, :] *= relative_coords_bias[1, :, :].max() + 1
        return relative_coords_bias.sum(0)  # [height*width,height*width]

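    # look up one bias per head for every (query token, key token) pair
    relative_position_index = get_relative_position_index(height, width)
    bias_embedding = position_embedding(torch.flatten(relative_position_index)).reshape(
        height * width, height * width, n_heads)
    # [height*width, height*width, n_heads] --> [1, n_heads, height*width, height*width]
    bias_embedding = bias_embedding.permute(2, 0, 1).unsqueeze(0)
    return bias_embedding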


if __name__ == "__main__":
    n_pos = 4
    dim = 4
    n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    pe = create_1d_absolute_sincos_embeddings(n_pos_vec, dim)
    print(f"pe=\n{pe}")
    my_n_heads = 3
    my_height = 4
    my_width = 5
    my_dim = 6

    result = create_2d_relative_bias_trainable_embeddings(n_heads=my_n_heads, height=my_height, width=my_width,
                                                          dim=my_dim)
    print(f"result=\n{result}")
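
The listing above covers the first three entries of the outline but has no code for the fourth one (Masked AutoEncoder: 2d absolute, sin/cos constant). A minimal sketch of one way to build it is given here. It assumes that half of the channels encode the row index and the other half the column index, and it repeats the interleaved sin/cos layout of `create_1d_absolute_sincos_embeddings` so that the block runs on its own. The names `sincos_1d` and `create_2d_absolute_sincos_embeddings` are made up for this sketch, and the channel ordering may differ from the reference MAE code.

```python
import torch


def sincos_1d(positions: torch.Tensor, dim: int) -> torch.Tensor:
    """1d constant sin/cos table; same layout as the 1d function in the listing."""
    assert dim % 2 == 0, "dim must be even"
    omega = torch.arange(dim // 2, dtype=torch.float) / (dim / 2.0)
    omega = 1.0 / (10000 ** omega)                        # [dim/2]
    out = positions.reshape(-1, 1).float() * omega[None]  # [n, dim/2]
    emb = torch.zeros(positions.numel(), dim, dtype=torch.float)
    emb[:, 0::2] = torch.sin(out)                         # even channels: sin
    emb[:, 1::2] = torch.cos(out)                         # odd channels: cos
    return emb


def create_2d_absolute_sincos_embeddings(height: int, width: int, dim: int) -> torch.Tensor:
    """Hypothetical helper: 2d absolute sin/cos table of shape [height*width, dim]."""
    # Assumption: dim is split evenly, first half for rows, second half for columns.
    assert dim % 4 == 0, "dim must be divisible by 4"
    rows, cols = torch.meshgrid(
        torch.arange(height, dtype=torch.float),
        torch.arange(width, dtype=torch.float),
        indexing="ij",
    )                                                     # each [height, width]
    emb_h = sincos_1d(rows.flatten(), dim // 2)           # [height*width, dim/2]
    emb_w = sincos_1d(cols.flatten(), dim // 2)           # [height*width, dim/2]
    return torch.cat([emb_h, emb_w], dim=1)               # [height*width, dim]


if __name__ == "__main__":
    pe_2d = create_2d_absolute_sincos_embeddings(height=2, width=3, dim=8)
    print(f"pe_2d.shape={tuple(pe_2d.shape)}")
    print(f"pe_2d=\n{pe_2d}")
```

Tokens are numbered row by row, so tokens in the same row share the first `dim // 2` channels and tokens in the same column share the last `dim // 2` channels. Like the 1d version, the table is fixed and would typically be added to the patch tokens without being trained.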

