SparseMoE-2
SparseMOE代码详解
这段代码实现了一个稀疏混合专家模型(Sparse Mixture of Experts, MoE),这是一种在大型语言模型中常用的技术,可以在不显著增加推理计算量的情况下提高模型容量。下面我将详细解析代码的每个部分及其工作原理。
基本概念
混合专家模型的核心思想是:不是让所有输入都经过相同的网络层,而是有多个"专家"网络,每个输入会被动态路由到最适合它的一部分专家网络中。这样可以:
- 增加模型参数量而不增加推理计算量
- 让不同专家处理不同类型的输入,提高模型能力
代码结构与流程
1. 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
from moeconfig import MOEConfig as config
这里导入PyTorch相关库和配置文件,为后续实现做准备。
### 2. BasicExpert类
```python
class BasicExpert(nn.Module):
def __init__(self, feature_in, feature_out):
super().__init__()
self.fc = nn.Linear(feature_in, feature_out)
def forward(self, x):
return self.fc(x)
这是最基础的专家网络实现:
- 每个专家就是一个简单的线性层(全连接层)
- 输入维度和输出维度相同,保持特征维度不变
- 在实际应用中,专家网络可以是更复杂的结构,如MLP或Transformer层
3. MOERouter类
class MOERouter(nn.Module):
"""为每个token计算专家选择的概率"""
def __init__(self, hidden_dim, expert_number, top_k):
super().__init__()
self.gate = nn.Linear(hidden_dim, expert_number)
self.expert_number = expert_number
self.top_k = top_k
路由器是MoE的核心组件,负责决定每个token应该由哪些专家处理:
- hidden_dim :输入特征的维度
- expert_number :专家的总数量
- top_k :每个token要选择的专家数量
- self.gate :一个线性层,将token特征映射为专家选择的logits
def forward(self, hidden_states):
# 计算每个token的logits
router_logits = self.gate(hidden_states) # (batchsize*seqlen, expert_number)
# 计算专家经过softmax后的概率
routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
路由过程第一步:
- 对每个token计算路由logits
- 通过softmax将logits转换为概率分布
# 选择top-k个专家
router_weights, selected_experts = torch.topk(
routing_probs, self.top_k, dim=-1 # (b*s, expert_number)->(b*s, top_k)
)
# 专家权重归一化
router_weights = router_weights / router_weights.sum(
dim=-1, keepdim=True
) # (b*s, top_k):合为1
router_weights = router_weights.to(hidden_states.dtype)
路由过程第二步:
- 使用 torch.topk 为每个token选择概率最高的k个专家
- 对选中的专家权重重新归一化,确保权重和为1
- 确保权重的数据类型与输入一致
# 生成专家掩码
expert_mask = F.one_hot(
selected_experts, num_classes=self.expert_number
) # (b*s, top_k, expert_number)
expert_mask = expert_mask.permute(
2, 1, 0
) # (expert_number, top_k, b*s)
路由过程第三步:
- 将选中的专家索引转换为one-hot编码
- 调整维度顺序,便于后续处理
- 最终形状为(专家数量, top_k, batch_size*seq_len),表示每个专家负责处理哪些token
4. MOEConfig类
class MOEConfig:
def __init__(
self,
hidden_dim,
expert_number,
top_k,
shared_experts_number=2,
):
self.hidden_dim = hidden_dim
self.expert_number = expert_number
self.top_k = top_k
self.shared_experts_number = shared_experts_number
这是一个配置类,存储MoE模型的超参数:
- hidden_dim :隐藏层维度
- expert_number :专家数量
- top_k :每个token选择的专家数量
- shared_experts_number :共享专家的数量(当前代码中未使用)
5. SparseMOE类
class SparseMOE(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_dim
self.expert_number = config.expert_number
self.top_k = config.top_k
self.experts = nn.ModuleList(
[
BasicExpert(self.hidden_dim, self.hidden_dim)
for _ in range(self.expert_number)
]
)
self.router = MOERouter(self.hidden_dim, self.expert_number, self.top_k)
SparseMOE是整个模型的主体:
- 从配置中读取参数
- 创建多个专家网络
- 创建一个路由器
def forward(self, x):
batch_size, seq_len, hidden_dim = x.size() # x (b, s, hidden_dim)
hidden_states = x.view(-1, hidden_dim) # 展平:(b*s, hidden_dim)
router_logits, router_weights, selected_experts_indicates, expert_mask = (
self.router(hidden_states)
)
前向传播第一步:
- 获取输入的形状信息
- 将输入展平为二维张量,便于处理
- 调用路由器,获取路由信息
final_hidden_states = torch.zeros(
(batch_size * seq_len, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
前向传播第二步:
- 创建一个全零张量,用于存储最终输出
- 确保数据类型和设备与输入一致
for expert_idx in range(self.expert_number):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(
expert_mask[expert_idx]
) # expert_mask[expert_idx]:(top_k,b*s)
current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(
-1, hidden_dim
)
current_hidden_states = expert_layer(current_state) * router_weights[top_x, idx].unsqueeze(-1)
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
前向传播第三步(核心处理):
- 遍历每个专家
- 使用 torch.where 找出当前专家负责处理的token索引
- 从输入中提取这些token的特征
- 将特征送入专家网络处理
- 将专家输出乘以对应的路由权重
- 使用 index_add_ 将处理结果累加到最终输出中
final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
return final_hidden_states, router_logits
前向传播最后一步:
- 将输出重新整形为原始形状
- 返回最终输出和路由logits(可用于计算路由损失)
6. 测试函数
def test_token_level_moe():
x = torch.rand(2, 4, 16)
config = MOEConfig(16, 2, 2)
token_level_moe = SparseMOE(config)
out = token_level_moe(x)
print(out[0].shape, out[1].shape)
这是一个简单的测试函数:
- 创建一个随机输入张量,形状为(2, 4, 16),表示2个样本,每个样本4个token,每个token16维特征
- 创建配置:16维隐藏层,2个专家,每个token选择2个专家
- 创建SparseMOE模型
- 运行模型并打印输出形状
整体Pipeline
-
初始化阶段 :
- 创建多个专家网络
- 创建路由器网络
-
前向传播阶段 :
- 输入:(batch_size, seq_len, hidden_dim)的张量
- 路由:为每个token计算应该由哪些专家处理
- 分发:将token分发给对应的专家处理
- 聚合:将专家处理结果按权重聚合
- 输出:与输入相同形状的张量
-
为什么使用稀疏MoE :
- 增加模型参数量而不增加推理计算量
- 每个token只经过少量专家,而不是所有专家
- 不同专家可以学习处理不同类型的输入
- 在大型语言模型中,这种方法可以显著提高模型性能
关键技术点
- Top-K路由 :每个token只选择概率最高的k个专家
- 权重归一化 :确保选中专家的权重和为1
- 稀疏计算 :每个专家只处理分配给它的token,而不是所有token
- 索引加法 :使用 index_add_ 高效地将专家输出聚合到最终结果
这种实现方式是对Switch Transformers和GShard等论文中提出的MoE方法的简化实现,适合用于理解MoE的基本原理和工作流程。
稀疏混合专家模型(SparseMOE)的数字化流程详解
我将用具体的数字示例来详细阐述SparseMOE的整个工作流程,这样能更直观地理解代码的执行过程。
1. 输入数据
假设我们有以下输入:
- 批次大小(batch_size): 2
- 序列长度(seq_len): 4
- 隐藏层维度(hidden_dim): 16
- 专家数量(expert_number): 2
- 每个token选择的专家数量(top_k): 2
输入张量x的形状为: (2, 4, 16),表示2个样本,每个样本有4个token,每个token是16维向量。
2. 数据预处理
batch_size, seq_len, hidden_dim = x.size() # (2, 4, 16)
hidden_states = x.view(-1, hidden_dim) # (8, 16)
将输入展平为(8, 16)的二维张量,其中8 = 2×4,代表总共8个token。
3. 路由过程
3.1 计算路由logits
router_logits = self.gate(hidden_states) # (8, 2)
假设gate层的权重矩阵为W(16×2),偏置为b(2),则:
- 对于第1个token: logits_1 = hidden_states_1 × W + b = [0.5, 1.2]
- 对于第2个token: logits_2 = hidden_states_2 × W + b = [1.0, 0.3]
- …
- 对于第8个token: logits_8 = hidden_states_8 × W + b = [0.7, 0.9]
router_logits的形状为(8, 2),表示8个token对2个专家的原始分数。
3.2 计算路由概率
routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float) # (8, 2)
对logits应用softmax函数:
- 对于第1个token: probs_1 = softmax([0.5, 1.2]) = [0.32, 0.68]
- 对于第2个token: probs_2 = softmax([1.0, 0.3]) = [0.67, 0.33]
- …
- 对于第8个token: probs_8 = softmax([0.7, 0.9]) = [0.45, 0.55]
routing_probs的形状为(8, 2),表示8个token选择2个专家的概率。
3.3 选择Top-K专家
router_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1) # (8, 2)
由于top_k=2且专家数量也是2,所以每个token会选择所有专家:
- 对于第1个token: weights_1 = [0.32, 0.68], experts_1 = [0, 1]
- 对于第2个token: weights_2 = [0.67, 0.33], experts_2 = [0, 1]
- …
- 对于第8个token: weights_8 = [0.45, 0.55], experts_8 = [0, 1]
router_weights的形状为(8, 2),表示选中专家的权重。
selected_experts的形状为(8, 2),表示选中的专家索引。
3.4 权重归一化
router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True) # (8, 2)
由于我们选择了所有专家,权重和已经为1,所以归一化不会改变权重值。
3.5 生成专家掩码
expert_mask = F.one_hot(selected_experts, num_classes=self.expert_number) # (8, 2, 2)
expert_mask = expert_mask.permute(2, 1, 0) # (2, 2, 8)
对于selected_experts,生成one-hot编码:
- 对于第1个token: mask_1 = [[[1, 0], [0, 1]]]
- 对于第2个token: mask_2 = [[[1, 0], [0, 1]]]
- …
经过permute后,expert_mask的形状为(2, 2, 8),表示2个专家,每个专家处理哪些token。
4. 专家处理
首先创建输出张量:
final_hidden_states = torch.zeros((8, 16), dtype=hidden_states.dtype, device=hidden_states.device)
4.1 专家0的处理
expert_idx = 0
expert_layer = self.experts[0]
idx, top_x = torch.where(expert_mask[0]) # idx=[0,0,...], top_x=[0,1,...,7]
expert_mask[0]表示专家0处理的token,由于我们选择了所有专家,所以专家0处理所有8个token。
current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_dim) # (8, 16)
提取专家0需要处理的token特征。
current_hidden_states = expert_layer(current_state) * router_weights[top_x, idx].unsqueeze(-1) # (8, 16)
假设专家0的处理结果为:
- 对于第1个token: output_1 = [0.1, 0.2, …, 0.1] * 0.32 = [0.032, 0.064, …, 0.032]
- 对于第2个token: output_2 = [0.3, 0.1, …, 0.2] * 0.67 = [0.201, 0.067, …, 0.134]
- …
final_hidden_states.index_add_(0, top_x, current_hidden_states)
将专家0的处理结果加到final_hidden_states中。
4.2 专家1的处理
expert_idx = 1
expert_layer = self.experts[1]
idx, top_x = torch.where(expert_mask[1]) # idx=[0,0,...], top_x=[0,1,...,7]
同样,专家1也处理所有8个token。
current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_dim) # (8, 16)
current_hidden_states = expert_layer(current_state) * router_weights[top_x, idx].unsqueeze(-1) # (8, 16)
假设专家1的处理结果为:
- 对于第1个token: output_1 = [0.2, 0.3, …, 0.2] * 0.68 = [0.136, 0.204, …, 0.136]
- 对于第2个token: output_2 = [0.1, 0.2, …, 0.3] * 0.33 = [0.033, 0.066, …, 0.099]
- …
final_hidden_states.index_add_(0, top_x, current_hidden_states)
将专家1的处理结果加到final_hidden_states中。
5. 输出处理
final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim) # (2, 4, 16)
将结果重新整形为原始输入形状(2, 4, 16)。
最终,对于第1个token的输出是专家0和专家1的加权和:
- output_1 = [0.032, 0.064, …, 0.032] + [0.136, 0.204, …, 0.136] = [0.168, 0.268, …, 0.168]
6. 返回结果
return final_hidden_states, router_logits # (2, 4, 16), (8, 2)
返回最终的隐藏状态和路由logits。
数值流程总结
- 输入 : (2, 4, 16)的张量,表示2个样本,每个4个token,每个token16维
- 展平 : 变为(8, 16)的张量,8个token
- 路由 :
- 计算logits: (8, 2),每个token对每个专家的原始分数
- 计算概率: (8, 2),每个token选择每个专家的概率
- 选择专家: 每个token选择2个专家(本例中是所有专家)
- 专家处理 :
- 专家0处理所有token,权重不同
- 专家1处理所有token,权重不同
- 结果加权求和
- 输出 : (2, 4, 16)的张