position embedding
Table of Contents
- 1. Four kinds of position embedding
- 2. PyTorch source code [to be organized later]
[I am fairly busy at the moment; this part will be organized later.]
1. Four kinds of position embedding
Position embedding variants by model:
1. Transformer
   - 1.1 1d absolute
   - 1.2 sin/cos constant (see the formula below)
2. Vision Transformer
   - 2.1 1d absolute
   - 2.2 trainable
3. Swin Transformer
   - 3.1 2d relative bias
   - 3.2 trainable
4. Masked AutoEncoder
   - 4.1 2d absolute
   - 4.2 sin/cos constant
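For reference, the fixed sinusoidal embedding of the original Transformer (the same construction MAE reuses per grid axis) assigns each position pos and feature-pair index i the values

PE(pos, 2i) = sin(pos / 10000^(2i/d)),  PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

where d is the embedding dimension; the first function in the code below implements exactly this.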
2. PyTorch source code [to be organized later]
import torch
import torch.nn as nn
torch.set_printoptions(precision=3, sci_mode=False)
# ---------------------------------------------------------------------------------
# 1. Transformer constant sin/cos position embedding
def create_1d_absolute_sincos_embeddings(n_pos_vec, dim):
    # n_pos_vec: torch.arange(n_pos, dtype=torch.float)
    assert dim % 2 == 0, "wrong dimension"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

    # omega_i = 1 / 10000^(2i/dim): one frequency per sin/cos pair
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2.0
    omega = 1.0 / (10000 ** omega)

    # outer product pos * omega_i -> [n_pos, dim//2]
    out = n_pos_vec[:, None] @ omega[None, :]
    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)

    # even feature columns take sin, odd columns take cos
    position_embedding[:, 0::2] = emb_sin
    position_embedding[:, 1::2] = emb_cos
    return position_embedding
# ---------------------------------------------------------------------------------
# ---------------------------------------------------------------------------------
# 2. 1d absolute trainable embedding
def create_1d_absolute_trainable_embeddings(n_pos_vec, dim):
    # n_pos_vec: torch.arange(n_pos, dtype=torch.float)
    # one learnable vector per position (ViT-style); initialized to zeros here
    position_embedding = nn.Embedding(n_pos_vec.numel(), dim)
    nn.init.constant_(position_embedding.weight, 0.0)
    return position_embedding
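# Usage sketch: unlike the sin/cos helper, this one returns an nn.Embedding module,
# which is looked up with integer position ids, e.g.:
#   table = create_1d_absolute_trainable_embeddings(torch.arange(4, dtype=torch.float), dim=8)
#   pe = table(torch.arange(4))  # [4, 8] tensor, added to the token embeddings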
# 3. 2d relative bias trainable embedding (Swin Transformer style)
def create_2d_relative_bias_trainable_embeddings(n_heads, height, width, dim):
    # width=5 --> torch.arange(5)=[0,1,2,3,4] --> relative offsets span [-4,...,4],
    # i.e. 2*width-1 distinct values per axis (and likewise 2*height-1 for rows).
    # Note: dim is unused here; the bias table stores one scalar per head.
    ps_height = (2 * height - 1) * (2 * width - 1)
    ps_width = n_heads
    position_embedding = nn.Embedding(ps_height, ps_width)
    nn.init.constant_(position_embedding.weight, 0.0)

    def get_relative_position_index(height, width):
        # grid coordinates of every patch
        coords = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij"))  # [2,height,width]
        coords_flatten = torch.flatten(coords, 1)  # [2,height*width]
        # pairwise coordinate differences between all patches
        relative_coords_bias = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2,height*width,height*width]
        # shift offsets so they start at 0
        relative_coords_bias[0, :, :] += height - 1
        relative_coords_bias[1, :, :] += width - 1
        # flatten the 2d offset A into a 1d index B: B[i*cols + j] = A[i, j]
        relative_coords_bias[0, :, :] *= relative_coords_bias[1, :, :].max() + 1
        return relative_coords_bias.sum(0)  # [height*width,height*width]

    relative_position_bias = get_relative_position_index(height, width)
    bias_embedding = position_embedding(torch.flatten(relative_position_bias)).reshape(
        height * width, height * width, n_heads)
    # [1, n_heads, height*width, height*width], ready to be added to attention logits
    bias_embedding = bias_embedding.permute(2, 0, 1).unsqueeze(0)
    return bias_embedding
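# ---------------------------------------------------------------------------------
# 4. 2d absolute sin/cos embedding -- a minimal MAE-style sketch: build a 1d sin/cos
# embedding for the row index and one for the column index (dim//2 each) and
# concatenate them per grid position; the function name and signature are illustrative.
def create_2d_absolute_sincos_embeddings(height, width, dim):
    assert dim % 4 == 0, "dim must be divisible by 4"
    rows = torch.arange(height, dtype=torch.float)
    cols = torch.arange(width, dtype=torch.float)
    emb_h = create_1d_absolute_sincos_embeddings(rows, dim // 2)  # [height, dim//2]
    emb_w = create_1d_absolute_sincos_embeddings(cols, dim // 2)  # [width,  dim//2]
    # every grid position (i, j) receives the concatenation [emb_h[i], emb_w[j]]
    emb_h = emb_h[:, None, :].expand(height, width, dim // 2)
    emb_w = emb_w[None, :, :].expand(height, width, dim // 2)
    return torch.cat([emb_h, emb_w], dim=-1).reshape(height * width, dim)  # [height*width, dim]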
if __name__ == "__main__":
    run_code = 0
    # 1d absolute sin/cos embedding: 4 positions, dim=4
    n_pos = 4
    dim = 4
    n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    pe = create_1d_absolute_sincos_embeddings(n_pos_vec, dim)
    print(f"pe=\n{pe}")

    # 2d relative bias embedding: 4x5 grid, 3 attention heads
    my_n_heads = 3
    my_height = 4
    my_width = 5
    my_dim = 6
    result = create_2d_relative_bias_trainable_embeddings(n_heads=my_n_heads, height=my_height,
                                                          width=my_width, dim=my_dim)
    print(f"result=\n{result}")