当前位置: 首页 > article >正文

SparseMoE-2

SparseMOE代码详解

这段代码实现了一个稀疏混合专家模型(Sparse Mixture of Experts, MoE),这是一种在大型语言模型中常用的技术,可以在不显著增加推理计算量的情况下提高模型容量。下面我将详细解析代码的每个部分及其工作原理。

基本概念

混合专家模型的核心思想是:不是让所有输入都经过相同的网络层,而是有多个"专家"网络,每个输入会被动态路由到最适合它的一部分专家网络中。这样可以:

  1. 增加模型参数量而不增加推理计算量
  2. 让不同专家处理不同类型的输入,提高模型能力

代码结构与流程

1. 导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
from moeconfig import MOEConfig as config


这里导入PyTorch相关库和配置文件,为后续实现做准备。

### 2. BasicExpert类
```python
class BasicExpert(nn.Module):
    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.fc = nn.Linear(feature_in, feature_out)

    def forward(self, x):
        return self.fc(x)

这是最基础的专家网络实现:

  • 每个专家就是一个简单的线性层(全连接层)
  • 输入维度和输出维度相同,保持特征维度不变
  • 在实际应用中,专家网络可以是更复杂的结构,如MLP或Transformer层

3. MOERouter类

class MOERouter(nn.Module):
    """为每个token计算专家选择的概率"""
    
    def __init__(self, hidden_dim, expert_number, top_k):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, expert_number)
        self.expert_number = expert_number
        self.top_k = top_k

路由器是MoE的核心组件,负责决定每个token应该由哪些专家处理:

  • hidden_dim :输入特征的维度
  • expert_number :专家的总数量
  • top_k :每个token要选择的专家数量
  • self.gate :一个线性层,将token特征映射为专家选择的logits
def forward(self, hidden_states):
    # 计算每个token的logits
    router_logits = self.gate(hidden_states)  # (batchsize*seqlen, expert_number)
    
    # 计算专家经过softmax后的概率
    routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

路由过程第一步:

  1. 对每个token计算路由logits
  2. 通过softmax将logits转换为概率分布
    # 选择top-k个专家
    router_weights, selected_experts = torch.topk(
        routing_probs, self.top_k, dim=-1  # (b*s, expert_number)->(b*s, top_k)
    )
    
    # 专家权重归一化
    router_weights = router_weights / router_weights.sum(
        dim=-1, keepdim=True
    )  # (b*s, top_k):合为1
    router_weights = router_weights.to(hidden_states.dtype)

路由过程第二步:

  1. 使用 torch.topk 为每个token选择概率最高的k个专家
  2. 对选中的专家权重重新归一化,确保权重和为1
  3. 确保权重的数据类型与输入一致
    # 生成专家掩码
    expert_mask = F.one_hot(
        selected_experts, num_classes=self.expert_number
    )  # (b*s, top_k, expert_number)
    expert_mask = expert_mask.permute(
        2, 1, 0
    )  # (expert_number, top_k, b*s)

路由过程第三步:

  1. 将选中的专家索引转换为one-hot编码
  2. 调整维度顺序,便于后续处理
  3. 最终形状为(专家数量, top_k, batch_size*seq_len),表示每个专家负责处理哪些token

4. MOEConfig类

class MOEConfig:
    def __init__(
        self,
        hidden_dim,
        expert_number,
        top_k,
        shared_experts_number=2,
    ):
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_experts_number = shared_experts_number

这是一个配置类,存储MoE模型的超参数:

  • hidden_dim :隐藏层维度
  • expert_number :专家数量
  • top_k :每个token选择的专家数量
  • shared_experts_number :共享专家的数量(当前代码中未使用)

5. SparseMOE类

class SparseMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number
        self.top_k = config.top_k
        self.experts = nn.ModuleList(
            [
                BasicExpert(self.hidden_dim, self.hidden_dim)
                for _ in range(self.expert_number)
            ]
        )
        self.router = MOERouter(self.hidden_dim, self.expert_number, self.top_k)

SparseMOE是整个模型的主体:

  1. 从配置中读取参数
  2. 创建多个专家网络
  3. 创建一个路由器
def forward(self, x):
    batch_size, seq_len, hidden_dim = x.size()  # x (b, s, hidden_dim)
    hidden_states = x.view(-1, hidden_dim)  # 展平:(b*s, hidden_dim)
    
    router_logits, router_weights, selected_experts_indicates, expert_mask = (
        self.router(hidden_states)
    )

前向传播第一步:

  1. 获取输入的形状信息
  2. 将输入展平为二维张量,便于处理
  3. 调用路由器,获取路由信息
    final_hidden_states = torch.zeros(
        (batch_size * seq_len, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

前向传播第二步:

  1. 创建一个全零张量,用于存储最终输出
  2. 确保数据类型和设备与输入一致
    for expert_idx in range(self.expert_number):
        expert_layer = self.experts[expert_idx]
        idx, top_x = torch.where(
            expert_mask[expert_idx]
        )  # expert_mask[expert_idx]:(top_k,b*s)
        current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(
            -1, hidden_dim
        )
        current_hidden_states = expert_layer(current_state) * router_weights[top_x, idx].unsqueeze(-1)
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

前向传播第三步(核心处理):

  1. 遍历每个专家
  2. 使用 torch.where 找出当前专家负责处理的token索引
  3. 从输入中提取这些token的特征
  4. 将特征送入专家网络处理
  5. 将专家输出乘以对应的路由权重
  6. 使用 index_add_ 将处理结果累加到最终输出中
    final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
    return final_hidden_states, router_logits

前向传播最后一步:

  1. 将输出重新整形为原始形状
  2. 返回最终输出和路由logits(可用于计算路由损失)

6. 测试函数

def test_token_level_moe():
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)
    token_level_moe = SparseMOE(config)
    out = token_level_moe(x)
    print(out[0].shape, out[1].shape)

这是一个简单的测试函数:

  1. 创建一个随机输入张量,形状为(2, 4, 16),表示2个样本,每个样本4个token,每个token16维特征
  2. 创建配置:16维隐藏层,2个专家,每个token选择2个专家
  3. 创建SparseMOE模型
  4. 运行模型并打印输出形状

整体Pipeline

  1. 初始化阶段 :

    • 创建多个专家网络
    • 创建路由器网络
  2. 前向传播阶段 :

    • 输入:(batch_size, seq_len, hidden_dim)的张量
    • 路由:为每个token计算应该由哪些专家处理
    • 分发:将token分发给对应的专家处理
    • 聚合:将专家处理结果按权重聚合
    • 输出:与输入相同形状的张量
  3. 为什么使用稀疏MoE :

    • 增加模型参数量而不增加推理计算量
    • 每个token只经过少量专家,而不是所有专家
    • 不同专家可以学习处理不同类型的输入
    • 在大型语言模型中,这种方法可以显著提高模型性能

关键技术点

  1. Top-K路由 :每个token只选择概率最高的k个专家
  2. 权重归一化 :确保选中专家的权重和为1
  3. 稀疏计算 :每个专家只处理分配给它的token,而不是所有token
  4. 索引加法 :使用 index_add_ 高效地将专家输出聚合到最终结果
    这种实现方式是对Switch Transformers和GShard等论文中提出的MoE方法的简化实现,适合用于理解MoE的基本原理和工作流程。

稀疏混合专家模型(SparseMOE)的数字化流程详解

我将用具体的数字示例来详细阐述SparseMOE的整个工作流程,这样能更直观地理解代码的执行过程。

1. 输入数据

假设我们有以下输入:

  • 批次大小(batch_size): 2
  • 序列长度(seq_len): 4
  • 隐藏层维度(hidden_dim): 16
  • 专家数量(expert_number): 2
  • 每个token选择的专家数量(top_k): 2
    输入张量x的形状为: (2, 4, 16),表示2个样本,每个样本有4个token,每个token是16维向量。

2. 数据预处理

batch_size, seq_len, hidden_dim = x.size()  # (2, 4, 16)
hidden_states = x.view(-1, hidden_dim)  # (8, 16)

将输入展平为(8, 16)的二维张量,其中8 = 2×4,代表总共8个token。

3. 路由过程

3.1 计算路由logits

router_logits = self.gate(hidden_states)  # (8, 2)

假设gate层的权重矩阵为W(16×2),偏置为b(2),则:

  • 对于第1个token: logits_1 = hidden_states_1 × W + b = [0.5, 1.2]
  • 对于第2个token: logits_2 = hidden_states_2 × W + b = [1.0, 0.3]
  • 对于第8个token: logits_8 = hidden_states_8 × W + b = [0.7, 0.9]
    router_logits的形状为(8, 2),表示8个token对2个专家的原始分数。

3.2 计算路由概率

routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)  # (8, 2)

对logits应用softmax函数:

  • 对于第1个token: probs_1 = softmax([0.5, 1.2]) = [0.32, 0.68]
  • 对于第2个token: probs_2 = softmax([1.0, 0.3]) = [0.67, 0.33]
  • 对于第8个token: probs_8 = softmax([0.7, 0.9]) = [0.45, 0.55]
    routing_probs的形状为(8, 2),表示8个token选择2个专家的概率。

3.3 选择Top-K专家

router_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)  # (8, 2)

由于top_k=2且专家数量也是2,所以每个token会选择所有专家:

  • 对于第1个token: weights_1 = [0.32, 0.68], experts_1 = [0, 1]
  • 对于第2个token: weights_2 = [0.67, 0.33], experts_2 = [0, 1]
  • 对于第8个token: weights_8 = [0.45, 0.55], experts_8 = [0, 1]
    router_weights的形状为(8, 2),表示选中专家的权重。
    selected_experts的形状为(8, 2),表示选中的专家索引。

3.4 权重归一化

router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)  # (8, 2)

由于我们选择了所有专家,权重和已经为1,所以归一化不会改变权重值。

3.5 生成专家掩码

expert_mask = F.one_hot(selected_experts, num_classes=self.expert_number)  # (8, 2, 2)
expert_mask = expert_mask.permute(2, 1, 0)  # (2, 2, 8)

对于selected_experts,生成one-hot编码:

  • 对于第1个token: mask_1 = [[[1, 0], [0, 1]]]
  • 对于第2个token: mask_2 = [[[1, 0], [0, 1]]]

  • 经过permute后,expert_mask的形状为(2, 2, 8),表示2个专家,每个专家处理哪些token。

4. 专家处理

首先创建输出张量:

final_hidden_states = torch.zeros((8, 16), dtype=hidden_states.dtype, device=hidden_states.device)

4.1 专家0的处理

expert_idx = 0
expert_layer = self.experts[0]
idx, top_x = torch.where(expert_mask[0])  # idx=[0,0,...], top_x=[0,1,...,7]

expert_mask[0]表示专家0处理的token,由于我们选择了所有专家,所以专家0处理所有8个token。

current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_dim)  # (8, 16)

提取专家0需要处理的token特征。

current_hidden_states = expert_layer(current_state) * router_weights[top_x, idx].unsqueeze(-1)  # (8, 16)

假设专家0的处理结果为:

  • 对于第1个token: output_1 = [0.1, 0.2, …, 0.1] * 0.32 = [0.032, 0.064, …, 0.032]
  • 对于第2个token: output_2 = [0.3, 0.1, …, 0.2] * 0.67 = [0.201, 0.067, …, 0.134]
final_hidden_states.index_add_(0, top_x, current_hidden_states)

将专家0的处理结果加到final_hidden_states中。

4.2 专家1的处理

expert_idx = 1
expert_layer = self.experts[1]
idx, top_x = torch.where(expert_mask[1])  # idx=[0,0,...], top_x=[0,1,...,7]

同样,专家1也处理所有8个token。

current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_dim)  # (8, 16)
current_hidden_states = expert_layer(current_state) * router_weights[top_x, idx].unsqueeze(-1)  # (8, 16)

假设专家1的处理结果为:

  • 对于第1个token: output_1 = [0.2, 0.3, …, 0.2] * 0.68 = [0.136, 0.204, …, 0.136]
  • 对于第2个token: output_2 = [0.1, 0.2, …, 0.3] * 0.33 = [0.033, 0.066, …, 0.099]
final_hidden_states.index_add_(0, top_x, current_hidden_states)

将专家1的处理结果加到final_hidden_states中。

5. 输出处理

final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)  # (2, 4, 16)

将结果重新整形为原始输入形状(2, 4, 16)。

最终,对于第1个token的输出是专家0和专家1的加权和:

  • output_1 = [0.032, 0.064, …, 0.032] + [0.136, 0.204, …, 0.136] = [0.168, 0.268, …, 0.168]

6. 返回结果

return final_hidden_states, router_logits  # (2, 4, 16), (8, 2)

返回最终的隐藏状态和路由logits。

数值流程总结

  1. 输入 : (2, 4, 16)的张量,表示2个样本,每个4个token,每个token16维
  2. 展平 : 变为(8, 16)的张量,8个token
  3. 路由 :
    • 计算logits: (8, 2),每个token对每个专家的原始分数
    • 计算概率: (8, 2),每个token选择每个专家的概率
    • 选择专家: 每个token选择2个专家(本例中是所有专家)
  4. 专家处理 :
    • 专家0处理所有token,权重不同
    • 专家1处理所有token,权重不同
    • 结果加权求和
  5. 输出 : (2, 4, 16)的张

http://www.kler.cn/a/574437.html

相关文章:

  • 数一考研复习之拉格朗日中值定理在求解函数极限中的应用,
  • 贪心算法二
  • DAIR-V2X-R数据集服务器下载
  • 社区智慧养老标准规范全解析
  • 电力杆塔倾斜监测装置:守护电网安全的智能卫士
  • 算法-回溯篇07-复原 IP 地址
  • 基于Spring Boot的健美操评分管理系统设计与实现(LW+源码+讲解)
  • DeepSeek + 沉浸式翻译 打造智能翻译助手
  • ctf网络安全比赛有一张图片怎么查看
  • 在Blender中给SP分ID通道图
  • [Python入门学习记录(小甲鱼)]第4章 分支与循环
  • Python学习第十天
  • 【搜索】P3654 First Step (ファーストステップ)
  • transformer架构解析{掩码,(自)注意力机制,多头(自)注意力机制}(含代码)-3
  • 个人博客自动化测试报告
  • Rust语言基础知识详解【七】
  • 自然语言处理:k均值聚类算法
  • nodejs去除本地文件html字符
  • 数据结构拓展:详解realloc(C++)
  • 【RabbitMQ】Spring Boot 结合 RabbitMQ 完成应用间的通信