当前位置: 首页 > article >正文

deepseek v3网络结构源码分析笔记

 1.网络主结构代码:主要是循环n_layers个TransformerBlock,在self.layers内构建

class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        '''
        params tokens: 输入文本内容的id表示 shape(batch_size, seq_len).
        return:输出文本词的logits表示 shape(batch_size, vocab_size)
        '''
        seqlen = tokens.size(1) # tokens数目
        h = self.embed(tokens) # tokens需要embedding转换成词向量
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        for layer in self.layers:# 多个TransformerBlock计算
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)[:, -1]
        logits = self.head(h)
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits

2. TransformerBlock结构:和上图类似

class Block(nn.Module):
    """
    论文中TransformerBlock的结构
    Attention部分即self.attn,采用了MLA技术
    Feed-Forward Network部分即self.ffn用的是MLP或者MOE,刚开始几个是dense_layer使用MLP,
    之后就是transerformerlayer使用MOE.  
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x

3.MLP详解

class MLP(nn.Module):
    '''
    MLP就是denslayer的ffn部分,就是一系列线性变换大致是
    W2@(SILU(W1@x)*(W3@x))
    '''
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim)
        self.w3 = ColumnParallelLinear(dim, inter_dim)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

4.MLA详解

简单说来就是一种新的计算QKV的方式,原始的QKV计算是通过3个矩阵运算对hidden state分别计算QKV,KV需要缓存在网络中,现在通过一个中间步骤计算QKV,不直接缓存KV而是缓存下图阴影部分,减少了 K 和 V 矩阵的存储和计算开销

计算最终输出的时候有“navie”和“absorb”方式,代码实际用的是absorb方式,区别在于navie模式模型存贮cache的是k和v,而absorb方式存储的是kv_cache和pe_cache

原始的attention计算如下,K,V需要cache在内存中

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        # 1. 计算q
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        # 2. 拆分q
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        # 3.计算kv
        kv = self.wkv_a(x)
        # 4.拆分k
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive": # 正常的kv cache
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else: # 实际运行的是这里,可以cache的不再是完整的kv结果
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x

5.MOE详解


http://www.kler.cn/a/538744.html

相关文章:

  • Python基础-元组tuple的学习
  • Oracle迁移到MySQL
  • 多光谱成像技术在华为Mate70系列的应用
  • LIMO:上海交大的工作 “少即是多” LLM 推理
  • 轻松理解CSS中的float浮动元素
  • Android性能优化
  • 网络基础之IP
  • NUMA 配置对 Redis 使用的影响:提升性能的秘密武器
  • 【PyQt5 12】如何加载QT designer 设计的界面
  • docker /var/lib/docker/overlay2目录把磁盘空间占满问题
  • 【WebLogic】Linux图形化界面创建WebLogic应用域
  • 25/2/7 <机器人基础> 牛顿-欧拉递推公式,开闭环
  • 常用在线工具
  • 无人机方位感知器官磁力传感器!
  • 【数据结构】链表应用-链表重新排序
  • 【后端java】构建工具maven
  • 使用云效解决docker官方镜像拉取不到的问题
  • react 19 useOptimistic 竞争更新乐观值时阻塞
  • Qt的QTableWidget类的声明定义和使用
  • Android13-系统服务大管家-ServiceManager进程-启动篇
  • 具身智能学习规划
  • 【LeetCode: 525. 连续数组 + 前缀和 + 哈希表】
  • CodeGPT + IDEA + DeepSeek,在IDEA中引入DeepSeek实现AI智能开发
  • android动态设置是否允许应用卸载
  • ES管理器焕新升级:紫色银狼主题来袭!
  • 在 Navicat 17 中扩展 PostgreSQL 数据类型 | 复合类型