llama源码学习·model.py[4]Attention注意力(2)源码分析
一、源码
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
keys, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
values, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
二、代码注释
1.init方法
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
n_kv_heads
是键值头的数目。如果没传 args.n_kv_heads
参,n_kv_heads
为 n_heads
注意力头的数目
model_parallel_size = fs_init.get_model_parallel_world_size()
fs_init.get_model_parallel_world_size()
获取本机GPU的数量
self.n_local_heads = args.n_heads // model_parallel_size
每个GPU上处理的注意力头的数目 = 总注意力头数目 // 总GPU数目
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
每个GPU上处理的键值头的数目 = 总键值头数目 // 总GPU数目
self.n_rep = self.n_local_heads // self.n_local_kv_heads
n_rep
表示每个 键值头 要重复的次数。
假如现在每块 GPU 上可以处理 4 个注意力头, 2 个键值头。那么 n_rep = 4 // 2 = 2
,也就是每个键值头需要在两个不同的注意力头中被重复使用。
self.head_dim = args.dim // args.n_heads
每个头的维度 = 模型的维度 // 头的数量
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wq
用于计算
Q
u
e
r
y
Query
Query 向量的权重矩阵。
ColumnParallelLinear
是一个并行版本的线性全连接层,用于在列方向上分割权重矩阵:假如有一个全连接层,权重矩阵的维度是[in_features(输入特征的数目), out_features(输出特征的数目)],ColumnParallelLinear 会将 out_features 划分成几个较小的部分,并在多个设备上分别进行计算。
args.dim
输入维度;args.n_heads * self.head_dim
输出维度(注意力头的数目乘以每个头的维度)
init_method=lambda x: x
初始化权重矩阵的函数,这里使用的是一个恒等函数表示权重矩阵在初始化后不会被改变。
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
初始化 cache_k
2.forward方法
x: torch.Tensor
输入的张量,通常是一个batch的序列数据
start_pos: int
处理非常长的序列的时候需要分批处理,start_pos
是每一批处理开始的位置
freqs_cis: torch.Tensor
旋转位置编码的旋转矩阵
mask: Optional[torch.Tensor]
用于遮住某些位置
bsz, seqlen, _ = x.shape
通过输入的x
可以获取到bsz
批次的大小,seqlen
序列的长度
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
输入 x
经过三个线性变换分别得到 xq
Q
u
e
r
y
Query
Query 向量,xk
K
e
y
Key
Key 向量,xv
V
a
l
u
e
Value
Value 向量
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
将xq\xk\xv
重塑为 四维张量,其中第三维是头的数量(对于 xq
是本地注意力头的数量,对于 xk\xv
是本地键值头的数量)第四维是每个头的维度,这样做的目的是为了后续的并行计算和注意力打分。
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
对xq, xk
应用旋转位置编码,得到带有位置信息的 xq, xk
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
将 KV_Cache 的信息移动到 xq
所在的设备上去
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
注意力机制在进行计算时,会使用到全部的历史信息(即所有的
K
e
y
Key
Key 和
V
a
l
u
e
Value
Value),因此需要将新的 xk
和 xv
添加到缓存中。这里使用的是覆盖的方式,也就是说对于当前批次 bsz
的数据,从位置 start_pos
开始,长度为 seqlen
的位置,都会被新的 xk
和 xv
覆盖。这样,当处理下一批数据时,缓存中就包含了所有已经处理过的历史信息,从而可以进行全历史范围的注意力计算。
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
先取出存储在 self.cache_k
和 self.cache_v
中的
K
e
y
Key
Key 和
V
a
l
u
e
Value
Value,这些键值包含了从序列开始到当前处理位置
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
使用 torch.matmul
计算 xq
和 keys
的乘积,这里 keys
需要在最后两个维度之间进行转置 keys.transpose(2, 3)
,因为在进行矩阵乘法时,我们需要最后一个维度和倒数第二个维度相匹配。也就是说,我们在进行 query
和 key
向量点积计算时,使用的是每个头部分的 query
和 key
。然后再除以 math.sqrt(self.head_dim)
,这是一种常见的缩放操作,用于缓解点积可能导致的梯度消失或者梯度爆炸问题。
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
如果提供了 mask
将会将 mask
加在 score
上。
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
归一化处理
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
output
原本的形状 [bsz, n_local_heads, seqlen, head_dim]
transpose(1, 2)
交换第二三维度,output
的形状变成 [bsz, seqlen, n_local_heads, head_dim]
contiguous()
让张量在内存中连续分布
view(bsz, seqlen, -1)
output
的形状变成 [bsz, seqlen, n_local_heads*head_dim]
return self.wo(output)
将output
经过定义的线性变化 wo
作为返回值
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/596757.html 如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!