当前位置: 首页 > article >正文

llama源码学习·model.py[4]Attention注意力(2)源码分析

一、源码

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

二、代码注释

1.init方法

self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads

n_kv_heads 是键值头的数目。如果没传 args.n_kv_heads 参,n_kv_headsn_heads 注意力头的数目

model_parallel_size = fs_init.get_model_parallel_world_size()

fs_init.get_model_parallel_world_size() 获取本机GPU的数量

self.n_local_heads = args.n_heads // model_parallel_size

每个GPU上处理的注意力头的数目 = 总注意力头数目 // 总GPU数目

self.n_local_kv_heads = self.n_kv_heads // model_parallel_size

每个GPU上处理的键值头的数目 = 总键值头数目 // 总GPU数目

self.n_rep = self.n_local_heads // self.n_local_kv_heads

n_rep 表示每个 键值头 要重复的次数。

假如现在每块 GPU 上可以处理 4 个注意力头, 2 个键值头。那么 n_rep = 4 // 2 = 2 ,也就是每个键值头需要在两个不同的注意力头中被重复使用。

self.head_dim = args.dim // args.n_heads

每个头的维度 = 模型的维度 // 头的数量

self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )

self.wq 用于计算 Q u e r y Query Query 向量的权重矩阵。

ColumnParallelLinear 是一个并行版本的线性全连接层,用于在列方向上分割权重矩阵:假如有一个全连接层,权重矩阵的维度是[in_features(输入特征的数目), out_features(输出特征的数目)],ColumnParallelLinear 会将 out_features 划分成几个较小的部分,并在多个设备上分别进行计算。

args.dim 输入维度;args.n_heads * self.head_dim 输出维度(注意力头的数目乘以每个头的维度)

init_method=lambda x: x 初始化权重矩阵的函数,这里使用的是一个恒等函数表示权重矩阵在初始化后不会被改变。

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

初始化 cache_k

2.forward方法

x: torch.Tensor

输入的张量,通常是一个batch的序列数据

start_pos: int

处理非常长的序列的时候需要分批处理,start_pos 是每一批处理开始的位置

freqs_cis: torch.Tensor

旋转位置编码的旋转矩阵

mask: Optional[torch.Tensor]

用于遮住某些位置

bsz, seqlen, _ = x.shape

通过输入的x 可以获取到bsz 批次的大小,seqlen 序列的长度

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

输入 x 经过三个线性变换分别得到 xq Q u e r y Query Query 向量,xk K e y Key Key 向量,xv V a l u e Value Value 向量

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

xq\xk\xv 重塑为 四维张量,其中第三维是头的数量(对于 xq 是本地注意力头的数量,对于 xk\xv 是本地键值头的数量)第四维是每个头的维度,这样做的目的是为了后续的并行计算和注意力打分。

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

xq, xk 应用旋转位置编码,得到带有位置信息的 xq, xk

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

将 KV_Cache 的信息移动到 xq 所在的设备上去

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

注意力机制在进行计算时,会使用到全部的历史信息(即所有的 K e y Key Key V a l u e Value Value),因此需要将新的 xkxv 添加到缓存中。这里使用的是覆盖的方式,也就是说对于当前批次 bsz 的数据,从位置 start_pos 开始,长度为 seqlen 的位置,都会被新的 xkxv 覆盖。这样,当处理下一批数据时,缓存中就包含了所有已经处理过的历史信息,从而可以进行全历史范围的注意力计算。

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

先取出存储在 self.cache_kself.cache_v 中的 K e y Key Key V a l u e Value Value,这些键值包含了从序列开始到当前处理位置

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

使用 torch.matmul 计算 xqkeys 的乘积,这里 keys 需要在最后两个维度之间进行转置 keys.transpose(2, 3),因为在进行矩阵乘法时,我们需要最后一个维度和倒数第二个维度相匹配。也就是说,我们在进行 querykey 向量点积计算时,使用的是每个头部分的 querykey。然后再除以 math.sqrt(self.head_dim),这是一种常见的缩放操作,用于缓解点积可能导致的梯度消失或者梯度爆炸问题。

        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)

如果提供了 mask 将会将 mask 加在 score 上。

        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

归一化处理

        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output 原本的形状 [bsz, n_local_heads, seqlen, head_dim]

transpose(1, 2) 交换第二三维度,output 的形状变成 [bsz, seqlen, n_local_heads, head_dim]

contiguous() 让张量在内存中连续分布

view(bsz, seqlen, -1) output 的形状变成 [bsz, seqlen, n_local_heads*head_dim]

        return self.wo(output)

output经过定义的线性变化 wo 作为返回值

原文地址:https://blog.csdn.net/m0_72851153/article/details/146446325
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/596757.html

相关文章:

  • 洛谷 [语言月赛 202503] 题解(C++)
  • (滑动窗口)算法训练篇11--力扣3.无重复字符的最长字串(难度中等)
  • ROM(只读存储器) 、SRAM(静态随机存储器) 和 Flash(闪存) 的详细解析
  • Centos编译升级libcurl
  • DeepSeek自学手册:《从理论(模型训练)到实践(模型应用)》|73页|附PPT下载方法
  • NVM 多版本node.js管理工具
  • Linux用户管理实操指南
  • 【 <二> 丹方改良:Spring 时代的 JavaWeb】之 Spring Boot 中的异常处理:全局异常与自定义异常
  • Ubuntu 系统安装 Redis 的详细步骤
  • Android13音频子系统分析(四)---座舱的多音区框架
  • 亮相AWE2025,MOVA以科技重塑生活,以美学沟通世界
  • go:前后端分离
  • Agent Team 多智能体系统解析
  • 【redis】事务详解,相关命令multi、exec、discard 与 watch 的原理
  • 嵌入式系统的核心组成部分处理器、存储器、传感器和执行器
  • 正则表达式详解(regular expression)
  • 掌握 Zapier:从入门到精通的自动化指南
  • 企业选择网站服务器租用需要注意哪些?
  • iptables和netfilter内部报文处理
  • 好未来25校招Web前端开发工程师部分笔试题解析