当前位置: 首页 > article >正文

Llama架构及代码详解

Llama的框架图如图:
在这里插入图片描述
源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下:

Llama的整体组成

由上图可知,Llama整体是由1个embedding层,n个transformer层,和1个RMSNorm层组成的,所以顶层代码如下:
顶层

class Llama(torch.nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
       self.config = config
        # embedding层
        self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim)
        # RMSNorm
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # n层Transformer
        self.layers = torch.nn.ModuleList()
        for i in range(self.config.n_layers):
            self.layers.append(TransformerBlock(config))


    def forward(self, tokens):
        # 进行token的嵌入编码
        h = self.tok_embeddings(tokens)
        # decoder架构需要生成一个mask
        seqlen = h.shape[1]
        mask = torch.full((seqlen, seqlen), float('-inf'), device=tokens.device)
        mask = torch.triu(mask, diagonal=1)
        # 进行n层Transformer
        for i in range(self.config.n_layers):
            h = self.layers[i](h, mask)
        # 进行RMSNorm
        token_embeddings = self.norm(h)
        return token_embeddings

中层
我们首先进行RMSNorm的复现

class RMSNorm(torch.nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, tensor):
        return tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, tensor):
        output = self._norm(tensor)
        return output * self.weight

然后对Transformer进行复现,在Transformer中,Transformer包括两个RMSNorm层,一个多头attention层,一个全连接层。

class TransformerBlock(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 多头注意力层
        self.attention = Attention(config)
        # Norm层
        self.attention_normal = RMSNorm(config.dim, config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        # 全连接层
        self.ffn = FeedForwad(self.config.dim, self.config.dim * 4)

    def forward(self, embeddings, mask):
        # norm
        h = self.attention_normal(embeddings)
        # attention
        h = self.attention(h, mask)
        # add & norm
        h = self.ffn_norm(h + embeddings)
        # fnn
        f = self.ffn(h)
        # add
        return f + h

底层
在多头attention中,首先需要对token的嵌入进行空间映射,多头拆分,旋转位置编码,分数计算等操作

class Attention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_head = config.n_heads
        self.dim = config.dim // self.n_head

        self.k = torch.nn.Linear(config.dim, config.dim)
        self.q = torch.nn.Linear(config.dim, config.dim)
        self.v = torch.nn.Linear(config.dim, config.dim)

    def forward(self, embeddings, mask):
        bsz, seq_len, dim = embeddings.shape

        k_embeddings = self.k(embeddings)
        q_embeddings = self.q(embeddings)
        v_embeddings = self.v(embeddings)
        n_q_embeddings = q_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)
        n_k_embeddings = k_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)
        n_v_embeddings = v_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)

        rotated_n_q_embeddings = compute_rotated_embedding(n_q_embeddings, self.dim, seq_len, self.config.rope_theta)
        rotated_n_k_embeddings = compute_rotated_embedding(n_k_embeddings, self.dim, seq_len, self.config.rope_theta)

        scores = torch.nn.functional.softmax(mask + rotated_n_q_embeddings @ rotated_n_k_embeddings.transpose(-1, -2)
                               / math.sqrt(self.dim), dim=-1)

        n_embeddings = scores @ n_v_embeddings
        embeddings = n_embeddings.permute(0, 2, 1, 3).reshape(bsz, -1, self.config.dim)

        return embeddings
class FeedForwad(torch.nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, hidden_dim)
        self.linear2 = torch.nn.Linear(dim, hidden_dim)
        self.linear3 = torch.nn.Linear(hidden_dim, dim)

    def forward(self, embeddings):
        gate = torch.nn.functional.silu(self.linear1(embeddings))
        up_proj = self.linear2(embeddings) * gate
        return self.linear3(up_proj)

最后,我们复现旋转位置编码,至此我们捋清了llama的所有结构!

def compute_rotated_embedding(embedding, dim, m, base):
    # 计算所有嵌入位置的旋转角度
    all_theta = compute_all_theta(dim, m, base)
    # 旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标
    # 1、将嵌入投影到复数平面
    embedding_real_pair = embedding.reshape(*embedding.shape[:-1], -1, 2)
    embedding_complex_pair = torch.view_as_complex(embedding_real_pair)
    # 2、将旋转角度投影到复数平面
    all_theta = all_theta[: embedding.shape[-2]]
    theta_complex_pair = torch.polar(torch.ones_like(all_theta), all_theta)
    # 3、旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标
    rotated_complex_embedding = embedding_complex_pair * theta_complex_pair
    # 4、将复数平面的嵌入投影到实数平面
    rotated_real_embedding = torch.view_as_real(rotated_complex_embedding)
    rotated_real_embedding = rotated_real_embedding.reshape(*embedding.shape[:-1], -1)
    return rotated_real_embedding

def compute_all_theta(dim, m, base):
    theta = 1 / (base ** (torch.arange(0, dim / 2).float() / (dim / 2)))
    m = torch.arange(0, m)
    all_theta = torch.outer(m, theta)
    return all_theta

附录:llama的config参数

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_batch_size: int = 32
    max_seq_len: int = 2048
    use_scaled_rope: bool = True

http://www.kler.cn/a/392597.html

相关文章:

  • java 上传txt json等类型文件解析后返回给前端
  • 大型ERP系统GL(总账管理)模块需求分析
  • node-sass安装报错,换成sass
  • SAKO搜索帮助增强(FB02科目搜索帮助)
  • 开启Rancher学习之旅:从入门到精通的成长计划
  • k8s-1.28.2 部署prometheus
  • 平衡二叉树、红黑树、B树、B+树
  • 鸿蒙next版开发:相机开发-会话管理(ArkTS)
  • HTB:Precious[WriteUP]
  • 计算机网络——1.2计算机网络的组成
  • SpringBoot赋能的共享汽车业务管理系统
  • LeetCode【0022】括号生成
  • 腾讯云产品推荐----域名的使用
  • 【时间之外】IT人求职和创业应知【31】
  • 万字长文解读深度学习——ViT、ViLT、DiT
  • 【go从零单排】Text Templates
  • 单体架构VS微服务架构
  • 高阶函数全解析(定义、应用 -- 函数柯理化 反柯理化 发布订阅模式 观察者模式)
  • 执行npm run build -- --report后,生产report.html文件是什么?
  • kafka是如何处理数据乱序问题的?
  • Java代码操作ZooKeeper(使用原生 ZooKeeper 客户端库)
  • UE5 设置Sequence播完后返回起始位置
  • hadoop报错找不到主类
  • 苹果低价版Vision Pro 推迟至2027年发布:XR领域的变局与挑战
  • TypeORM在Node.js中的应用
  • 缓存雪崩问题及解决方法