当前位置: 首页 > article >正文

llama源码学习·model.py[6]TransformerBlock类

一、源码摘录

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

二、Transformer Block作用

这个TransformerBlock类的设计允许多个这样的块可以堆叠在一起,形成一个深度的Transformer网络。每个块的输出会被用作下一个块的输入,这样的设计使得Transformer能够处理非常复杂的序列建模任务。

三、代码注释

在这里插入图片描述

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // self.n_heads
        # 一个多头注意力模块,用于对输入执行自注意力操作。
        # 这个模块会计算输入的每个元素与其他元素之间的相互关系,并将这些关系用于更新输入。
        self.attention = Attention(args)

在这里插入图片描述

        # 一个前馈神经网络模块,包含一个 SwiGLU 激活函数 和一个线性层。
        self.feed_forward = FeedForward(
            dim = args.dim,
            hidden_dim = 4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id

在这里插入图片描述

        # RMS归一化层,对注意力的输出进行归一化。
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)

在这里插入图片描述

        # RMS归一化层,对前馈神经网络的输出进行归一化。
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

在这里插入图片描述

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor, # 旋转矩阵
        mask: Optional[torch.Tensor],
    ):
        # 残差连接
        # 将注意力模块的输出与原始输入x相加,形成一个残差连接。这是一种常见的深度学习技术,
        # 可以帮助减少训练深层网络时的梯度消失问题。
        h = x + self.attention.forward(
            self.attention_norm(x), # 对输入x进行归一化,然后将归一化的x传递给注意力模块。
            start_pos, # 开始的位置
            freqs_cis, # 频率
            mask,
        )

在这里插入图片描述

        # 对结果h进行归一化,然后传递给前馈神经网络模块。
        # 前馈神经网络模块对其输入进行进一步的转换,并将输出与h相加,形成另一个残差连接。
        out = h + self.feed_forward.forward(self.ffn_norm(h))

在这里插入图片描述

        # 这个out将被用作下一个Transformer块的输入
        return out

http://www.kler.cn/a/598170.html

相关文章:

  • uni-app 与webView 互相传值
  • 内网渗透技术 Docker逃逸技术(提权)研究 CSMSF
  • IDEA批量替换项目下所有文件中的特定内容
  • 监控易运维管理软件:轻松部署,高效运维
  • mysql中的游标是什么?作用是什么?
  • 地理编码/经纬度解析/经纬度地址转换接口如何用JAVA对接
  • 大模型在非小细胞肺癌预测及治疗方案制定中的应用研究报告
  • 算力100问☞第93问:算力资源为何更分散了?
  • 算法-分治
  • Linux内核,内存分布
  • 应用程序安全趋势:左移安全、人工智能和开源恶意软件
  • 游戏引擎学习第176天
  • 修改服务器windows远程桌面默认端口号
  • 2025.03.21首板涨停股票分析
  • 机器学习-聚类模型
  • 一加13T手机三证齐全:骁龙8至尊版小屏机、80W快充
  • 5G智慧工厂专网部署:IPLOOK助力制造业数字化转型
  • 第二届图像处理与人工智能国际学术会议(ICIPAI2025)
  • setenv ethaddr b8:ae:1d:01:00:00失效错误怎么解决❌
  • Python环境安装