
Improving PosWiseFFN: MoE, PKM, UltraMem

Let's start with PosWiseFFN

import torch
import torch.nn as nn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.GELU(),                        # note: the module is nn.GELU, not nn.GeLU
            nn.Linear(d_ff, d_model, bias=False))
        # instantiate LayerNorm once here; constructing it inside forward()
        # would create fresh, untrained parameters on every call
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):                                  # inputs: [batch_size, seq_len, d_model]
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)  # [batch_size, seq_len, d_model]
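
A quick shape check (the d_model and d_ff values here are arbitrary):

ffn = PoswiseFeedForwardNet(d_model=512, d_ff=2048)
x = torch.randn(2, 10, 512)   # [batch_size, seq_len, d_model]
print(ffn(x).shape)           # torch.Size([2, 10, 512])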

If the attention dimension is d_model, PosWiseFFN is usually just two matrices with a GELU in between, where d_ff is 4 times d_model: the first matrix's weight is [d_model, 4*d_model] and the second's is [4*d_model, d_model].

PosWiseFFN can also be read as a kind of q-k-v lookup. Treat the first matrix as the keys and the second as the values: the [batch_size, seq_len, d_model] input acts as the query and is first multiplied against the keys, giving a [batch_size, seq_len, 4*d_model] dots tensor; after the GELU, dots is multiplied by the second [4*d_model, d_model] matrix, which in effect aggregates the most important slots back down to d_model dimensions. The question then becomes: can we make d_ff much larger than 4*d_model? Figure 1 below is from Large Memory Layers with Product Keys; the |K| in the figure corresponds to 4*d_model in PosWiseFFN.
[Figure 1 from Large Memory Layers with Product Keys: the FFN viewed as a query against |K| keys and their values]
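
To make the lookup reading concrete, here is a minimal sketch that computes the FFN explicitly as a key/value memory (treating W1 as keys and W2 as values; the names are mine, not from the paper):

import torch
import torch.nn.functional as F

d_model, d_ff = 512, 2048              # d_ff = 4 * d_model
keys = torch.randn(d_model, d_ff)      # first FFN matrix, read as |K| = d_ff "keys"
values = torch.randn(d_ff, d_model)    # second FFN matrix, read as d_ff "values"

x = torch.randn(2, 10, d_model)        # the query: [batch_size, seq_len, d_model]
dots = x @ keys                        # [batch_size, seq_len, d_ff] similarity scores
out = F.gelu(dots) @ values            # weighted sum of values, back to d_model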

PKM, described next, essentially takes this qkv-lookup view and improves it with the idea behind PQ (Product Quantization).

PKM (Product Key Memory; the "Product" here is the same Product as in Product Quantization)

In Figure 1 of Large Memory Layers with Product Keys, q has shape [..., d_model] and k has shape [d_model, |K|]. Figure 2 shows how the paper handles an overly large |K|: split the d_model-dimensional q into q1 and q2, each of dimension d_model/2; likewise, replace the [d_model, |K|] keys with two much smaller sub-key sets of $\sqrt{|K|}$ keys each: sub-key set 1 ($c_1$, $c_2$, $c_3$ in the figure, without primes) of shape [d_model/2, $\sqrt{|K|}$], and sub-key set 2 ($c'_1$, $c'_2$, $c'_3$, with primes) of the same shape. Each half produces its own top-k, and the final top k is then selected from the $k^2$ combined candidates (with topk=32 that is $32^2 = 1024$ candidates). This is exactly the Product Quantization idea.
[Figure 2 from the paper: product keys formed from sub-key set 1 ($c_1, c_2, c_3$) and sub-key set 2 ($c'_1, c'_2, c'_3$)]
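
Here is a toy, single-head sketch of product-key retrieval (the variable names are hypothetical; the real implementation follows below):

import torch

d, n, k = 64, 128, 8                       # head dim, sub-keys per set, topk
subkeys1 = torch.randn(n, d // 2)          # sub-key set 1
subkeys2 = torch.randn(n, d // 2)          # sub-key set 2

q = torch.randn(d)
q1, q2 = q[: d // 2], q[d // 2 :]          # split the query in half

s1, i1 = (subkeys1 @ q1).topk(k)           # top-k against each sub-key set
s2, i2 = (subkeys2 @ q2).topk(k)

# combine: k x k candidate grid of scores over the n*n product keys
grid = s1[:, None] + s2[None, :]           # [k, k]
scores, flat = grid.flatten().topk(k)      # final top-k among the k^2 candidates
indices = i1[flat // k] * n + i2[flat % k] # flat index into the n*n value table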

Code walkthrough

The code comes from https://github.com/lucidrains/product-key-memory/tree/master; the einops usage is quite nice. Some annotations follow:

class PKM(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        num_keys = 128,
        topk = 32,
        dim_head = 128,
        input_dropout = 0.,
        query_dropout = 0.,
        value_dropout = 0.,
        attn_dropout = 0.,
        use_layernorm = True,
        pre_layernorm = False,
        differentiable_topk = False,
        concat_values_and_combine = False,
        norm_output = False,
        non_competitive_gates = False # Csordas et al. claims non-competitive gates work even better
    ):
        super().__init__()
        self.topk = topk
        self.heads = heads
        self.num_keys = num_keys
        dim_query = dim_head * heads * 2  # x2: one half-query per sub-key set
        self.to_queries = nn.Linear(dim, dim_query, bias = False)

        # pre-layernorm pattern
        self.pre_layernorm = nn.LayerNorm(dim) if pre_layernorm else nn.Identity()

        # batchnorm would break causality
        self.use_layernorm = use_layernorm

        if use_layernorm:
            self.norm = nn.LayerNorm(dim_head)
        else:
            self.norm = MaskedBatchNorm1D(nn.BatchNorm1d(dim_head))

        # keys
        self.keys = nn.Parameter(torch.zeros(heads, num_keys, 2, dim_head))
        init_(self.keys)

        # values
        self.concat_values_and_combine = concat_values_and_combine
        if concat_values_and_combine:
            values = nn.Embedding(num_keys ** 2, dim_head)

            self.values = nn.Sequential(
                values,
                Reduce('b (h k) d -> b h d', 'sum', h = heads),
                Rearrange('b n d -> b (n d)'),
                nn.Linear(dim_head * heads, dim, bias = False)
            )
        else:
            values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')
            self.values = values
        init_(values.weight)

        # dropouts
        self.input_dropout = nn.Dropout(input_dropout)
        self.query_dropout = nn.Dropout(query_dropout)
        self.value_dropout = nn.Dropout(value_dropout)
        self.attn_dropout = nn.Dropout(attn_dropout)

        # non competitive gates
        self.gate_activation = nn.Softmax(dim = -1) if not non_competitive_gates else nn.ReLU()
        # use a differentiable topk, based on coordinate descent
        self.differentiable_topk = differentiable_topk
        # https://arxiv.org/abs/2302.06461
        # claims to boost performance of softmax key / value networks by simply layernorming the output
        self.output_norm = nn.LayerNorm(dim) if norm_output else nn.Identity()

    def forward(
        self,
        x,
        input_mask = None,
        gumbel_noise_scale = 0.,
        **kwargs
    ):
        b, t, h = *x.shape[:2], self.heads

        x = self.pre_layernorm(x)
        x = self.input_dropout(x)

        queries = self.to_queries(x)

        # queries shape legend: b=batch_size, t=target_seq_len, p=partition (the two query halves), h=num_heads, d=head_dim
        queries = rearrange(queries, 'b t (p h d) -> (b p h) t d', p = 2, h = h)

        # norm and dropout queries
        norm_kwargs = dict(mask = input_mask) if not self.use_layernorm else dict()
        queries = self.norm(queries, **norm_kwargs)
        queries = self.query_dropout(queries)

        queries = rearrange(queries, '(b p h) t d -> p b t h d', p = 2, h = h)

        # similarity to keys
        # keys.shape: [heads, num_keys, 2, dim_head]; n below indexes the num_keys axis
        # conceptually the keys form a 2-D grid: 2 sub-key sets x num_keys sub-keys each
        dots = einsum('p b t h d, h n p d -> b t h p n', queries, self.keys)

        # gumbel noise
        if gumbel_noise_scale > 0.:
            dots = dots + gumbel_noise(dots) * gumbel_noise_scale

        # topk scores
        if self.differentiable_topk:
            scores, indices, *_ = coor_descent_topk(dots, k = self.topk, fused = True)
        else:
            scores, indices = dots.topk(k = self.topk, dim = -1)
        # scores are factorized
        (scores_x, scores_y), (indices_x, indices_y) = map(lambda t: t.chunk(2, dim = 3), (scores, indices))

        all_topk = self.topk ** 2
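
        # broadcasting trick below: scores_x[..., k, 1] + scores_y[..., 1, k] spans the
        # full k x k grid of candidate scores (k^2 = all_topk of them), and the matching
        # flat index into the num_keys^2 value table is indices_x * num_keys + indices_y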

        all_scores = rearrange((
            rearrange(scores_x, '... k -> ... k 1') +
            rearrange(scores_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        all_indices = rearrange((
            rearrange(indices_x, '... k -> ... k 1') * self.num_keys +
            rearrange(indices_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        final_topk, final_indices = all_scores.topk(self.topk, dim=-1)
        value_indices = all_indices.gather(-1, final_indices)

        # attention

        attn = self.gate_activation(final_topk)
        attn = self.attn_dropout(attn)

        value_indices, attn = map(lambda t: rearrange(t, 'b t h k -> (b t) (h k)'), (value_indices, attn))

        # aggregate

        if self.concat_values_and_combine:
            out = self.values(value_indices)
        else:
            out = self.values(value_indices, per_sample_weights = attn)

        out = self.value_dropout(out)

        # maybe layernorm the output

        out = self.output_norm(out)

        return rearrange(out, '(b t) d -> b t d', b = b)
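
Usage is straightforward; a PKM layer drops in where the FFN block would go (the import assumes the repo's usual packaging):

import torch
from product_key_memory import PKM   # assumption: installed from the repo above

pkm = PKM(dim = 512)                 # defaults: heads=4, num_keys=128 -> 128^2 = 16384 value slots
x = torch.randn(2, 1024, 512)        # [batch_size, seq_len, dim]
out = pkm(x)                         # [2, 1024, 512]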

UltraMem

This comes from ULTRA-SPARSE MEMORY NETWORK. When ByteDance announced it, the pitch was that it "effectively solves the heavy memory-access cost of MoE inference, achieving 2-6x faster inference than MoE architectures and up to 83% lower inference cost". At first glance that sounds like another 2-6x on top of DeepSeekMoE, but the MoE being compared against is actually the one from the paper in the next section. UltraMem is really a refinement of the PKM idea; ByteDance has not released source code (nor is it clear whether their Doubao model actually uses it), so for now I'll just excerpt the core ideas and read the code carefully once it's out.
[Figure from the UltraMem paper: the drawbacks of PKM that UltraMem targets]
To address drawback 1 and drawback 3, UltraMem replaces PQ with the TDQKR scheme below, a method based on SVD-style decomposition:
[Figure from the UltraMem paper: the TDQKR retrieval scheme]
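
With no released code, the following is only my reading of the idea. PKM scores the candidate grid additively as S(i, j) = s_row(i) + s_col(j); a Tucker/SVD-style retrieval instead scores it as a low-rank bilinear form. Everything in this sketch (the names, the rank r, the core matrix) is a hypothetical illustration, not ByteDance's implementation:

import torch

n, d, r, k = 128, 64, 2, 8                 # sub-keys per axis, half-query dim, assumed rank, topk
subkeys_row = torch.randn(n, r, d)         # hypothetical: r score components per row sub-key
subkeys_col = torch.randn(n, r, d)
core = torch.randn(r, r)                   # hypothetical learnable Tucker core

q_row = torch.randn(d)                     # the two halves of the query
q_col = torch.randn(d)
s_row = torch.einsum('nrd,d->nr', subkeys_row, q_row)   # [n, r]
s_col = torch.einsum('nrd,d->nr', subkeys_col, q_col)   # [n, r]

# rank-r score grid S = S_row @ C @ S_col^T, versus PKM's additive grid
grid = s_row @ core @ s_col.T              # [n, n]
scores, flat = grid.flatten().topk(k)      # exact top-k for clarity; materializing the full
                                           # n x n grid is what the paper's SVD-based trick avoids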

MoE

The MoE here is not the MoE of MoE-architecture LLMs but an earlier improvement to PosWiseFFN, from Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. The figure below is a screenshot from the paper; one glance conveys the idea:
[Figure from the Switch Transformers paper: a router sends each token to one of several expert FFNs]
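
A minimal sketch of Switch-style routing (top-1 gating; a simplified toy that omits the paper's load-balancing loss and capacity factor):

import torch
import torch.nn as nn

class SwitchFFN(nn.Module):
    def __init__(self, d_model, d_ff, num_experts):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts))

    def forward(self, x):                       # x: [batch_size, seq_len, d_model]
        gates = self.router(x).softmax(dim=-1)  # routing probabilities per token
        weight, expert_idx = gates.max(dim=-1)  # top-1 expert per token
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i              # tokens routed to expert i
            if mask.any():
                out[mask] = weight[mask, None] * expert(x[mask])
        return out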

Appendix:

  1. https://mp.weixin.qq.com/s/BPGbzAQ5AKPj7yqrOCCuGQ?token=2117558689&lang=zh_CN
  2. https://team.doubao.com/zh/publication/ultra-sparse-memory-network?view_from=research
  3. https://www.cls.cn/detail/1940788
