The Role of Multiple Heads in Multi-Head Self-Attention, with Some Reflections
Table of contents
- 1. num_heads
- 2. Walking through the PyTorch code
1. num_heads
The last dimension of the matrix (of size model_dim) is split into num_heads equal chunks, each of size head_dim = model_dim // num_heads.
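To make the split concrete, here is a minimal sketch. The tensor `x` and the sizes are illustrative assumptions, not taken from any particular model. It checks that reshaping the last dimension into (num_heads, head_dim) gives the same result as slicing that dimension into num_heads contiguous chunks.

```python
import torch

batch_size, seq_len, model_dim, num_heads = 2, 4, 6, 3  # illustrative sizes
head_dim = model_dim // num_heads  # model_dim must be divisible by num_heads

x = torch.arange(batch_size * seq_len * model_dim).reshape(batch_size, seq_len, model_dim)

# Split the last dimension into (num_heads, head_dim).
x_heads = x.reshape(batch_size, seq_len, num_heads, head_dim)

# Head h holds columns [h * head_dim, (h + 1) * head_dim) of the original last dimension.
for h in range(num_heads):
    chunk = x[:, :, h * head_dim:(h + 1) * head_dim]
    assert torch.equal(x_heads[:, :, h, :], chunk)

print(x_heads.shape)  # torch.Size([2, 4, 3, 2])
```

The reshape only regroups the features of each token; no element moves to a different token or sample.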
2. Walking through the PyTorch code
- PyTorch code
import torch

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    batch_size = 2
    seq_len = 4
    model_dim = 6
    num_heads = 3
    # Fill a (batch_size, seq_len, model_dim) tensor with 0..47 so every element is traceable.
    mat_total = batch_size * seq_len * model_dim
    mat1 = torch.arange(mat_total).reshape((batch_size, seq_len, model_dim))
    print(f"mat1=\n{mat1}")
    # Split the last dimension into num_heads chunks of size head_dim.
    head_dim = model_dim // num_heads
    mat2 = mat1.reshape((batch_size, seq_len, num_heads, head_dim))
    print(f"mat2=\n{mat2}")
    # Swap the seq_len and num_heads axes: (batch_size, num_heads, seq_len, head_dim).
    mat3 = mat2.transpose(1, 2)
    print(f"mat3=\n{mat3}")
    # Merge the batch and head axes so each head becomes its own batch entry.
    mat4 = mat3.reshape((batch_size * num_heads, seq_len, head_dim))
    print(f"mat1.shape=\n{mat1.shape}")
    print(f"mat1=\n{mat1}")
    print(f"mat4.shape=\n{mat4.shape}")
    print(f"mat4=\n{mat4}")
- Output:
mat1=
tensor([[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]],
[[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41],
[42, 43, 44, 45, 46, 47]]])
mat2=
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27],
[28, 29]],
[[30, 31],
[32, 33],
[34, 35]],
[[36, 37],
[38, 39],
[40, 41]],
[[42, 43],
[44, 45],
[46, 47]]]])
mat3=
tensor([[[[ 0, 1],
[ 6, 7],
[12, 13],
[18, 19]],
[[ 2, 3],
[ 8, 9],
[14, 15],
[20, 21]],
[[ 4, 5],
[10, 11],
[16, 17],
[22, 23]]],
[[[24, 25],
[30, 31],
[36, 37],
[42, 43]],
[[26, 27],
[32, 33],
[38, 39],
[44, 45]],
[[28, 29],
[34, 35],
[40, 41],
[46, 47]]]])
mat1.shape=
torch.Size([2, 4, 6])
mat1=
tensor([[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]],
[[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41],
[42, 43, 44, 45, 46, 47]]])
mat4.shape=
torch.Size([6, 4, 2])
mat4=
tensor([[[ 0, 1],
[ 6, 7],
[12, 13],
[18, 19]],
[[ 2, 3],
[ 8, 9],
[14, 15],
[20, 21]],
[[ 4, 5],
[10, 11],
[16, 17],
[22, 23]],
[[24, 25],
[30, 31],
[36, 37],
[42, 43]],
[[26, 27],
[32, 33],
[38, 39],
[44, 45]],
[[28, 29],
[34, 35],
[40, 41],
[46, 47]]])
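The output above suggests a simple index rule: entry `b * num_heads + h` of mat4 is head `h` of sample `b`. The sketch below checks that rule and, as an illustrative contrast that is not part of the original code, shows what happens if the transpose is skipped (the name `naive` is hypothetical).

```python
import torch

batch_size, seq_len, model_dim, num_heads = 2, 4, 6, 3
head_dim = model_dim // num_heads

mat1 = torch.arange(batch_size * seq_len * model_dim).reshape(batch_size, seq_len, model_dim)

# Same steps as above: split heads, move the head axis forward, merge batch and head axes.
mat4 = (
    mat1.reshape(batch_size, seq_len, num_heads, head_dim)
    .transpose(1, 2)
    .reshape(batch_size * num_heads, seq_len, head_dim)
)

# Entry b * num_heads + h of mat4 is head h of sample b: all tokens, one slice of features.
for b in range(batch_size):
    for h in range(num_heads):
        expected = mat1[b, :, h * head_dim:(h + 1) * head_dim]
        assert torch.equal(mat4[b * num_heads + h], expected)

# Without the transpose, the reshape just cuts the flat memory order into blocks,
# so a single "head" mixes different feature slices of neighbouring tokens.
naive = mat1.reshape(batch_size * num_heads, seq_len, head_dim)
print(naive[0])
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]])
```

Here `naive[0]` contains all six features of token 0 plus the first two features of token 1, which is why the transpose has to come before the final reshape.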
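- Thought: when we write y = Ax and A taken as a whole does not represent y well, we can split the columns of A into blocks A1, A2, A3; writing y = (A1, A2, A3)x is then a more reasonable representation.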
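One way to read this remark, offered as an interpretation rather than the author's own derivation: if A is split by columns into A1, A2, A3 and x is split into matching segments x1, x2, x3, then Ax = A1 x1 + A2 x2 + A3 x3. The sketch below checks this identity numerically; the sizes `m`, `n`, `num_blocks` and the random data are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
m, n, num_blocks = 5, 6, 3  # illustrative sizes; n must be divisible by num_blocks

A = torch.randn(m, n, dtype=torch.float64)
x = torch.randn(n, dtype=torch.float64)

# Split A by columns into A1, A2, A3 and x into the matching segments x1, x2, x3.
A_blocks = torch.chunk(A, num_blocks, dim=1)
x_blocks = torch.chunk(x, num_blocks, dim=0)

# Each block contributes its own partial product; their sum recovers y = A x.
partials = [A_i @ x_i for A_i, x_i in zip(A_blocks, x_blocks)]
y_blocks = torch.stack(partials).sum(dim=0)

assert torch.allclose(A @ x, y_blocks)
```

Under this reading, each block A_i only ever sees its own slice of x, which loosely parallels how each attention head works on its own head_dim-sized slice of the features.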