ShareExpert SparseMoE的学习
根据您提供的查询和图片描述,我将用中文详细解释 ShareExpertMOE
类及其训练代码,并结合 DeepSeekMoE 架构图进行说明。以下内容将完整且独立,帮助您理解模型的实现和架构。
ShareExpertMOE 类详解
ShareExpertMOE
是一个基于 PyTorch 的模块,实现了混合专家(Mixture of Experts, MoE)模型。它结合了稀疏专家(通过路由选择性地处理输入)和共享专家(对所有输入 token 进行统一处理),这种设计提高了 Transformer 架构在序列数据处理中的灵活性和效率。
初始化 (__init__
)
class ShareExpertMOE(nn.Module):
def __init__(self, config):
super().__init__()
self.moe_model = SparseMOE(config)
self.shared_experts = nn.ModuleList(
[
BasicExpert(
config.hidden_dim, config.hidden_dim
) for _ in range(config.shared_experts_number)
]
)
- 目的:定义模型结构并初始化核心组件。
- 参数:
config
:一个配置对象(MOEConfig
),包含超参数,例如:hidden_dim
:隐藏状态的维度。expert_number
:稀疏专家的数量。top_k
:每个 token 选择的专家数量。shared_experts_number
:共享专家的数量。
- 组件:
self.moe_model
:一个SparseMOE
实例,负责通过路由机制将 token 分配给特定专家,对应架构图中的 Routed Experts(蓝色矩形)。self.shared_experts
:一个nn.ModuleList
,包含多个BasicExpert
实例,每个专家是一个从hidden_dim
到hidden_dim
的简单神经网络(例如全连接层),对应架构图中的 Shared Experts(绿色矩形)。
- 细节:共享专家对所有 token 进行处理,而稀疏专家只处理路由器分配的 token,这种混合设计增强了模型的表达能力。
前向传播 (forward
)
def forward(self, x):
# x shape 是 (b, s, hidden_dim)
# 首先过 moe 模型
sparse_moe_out, router_logits = self.moe_model(x)
# 然后过 shared experts
shared_experts_out = [
expert(x) for expert in self.shared_experts
] # 每一个 expert 的输出 shape 是 (b, s, hidden_dim)
shared_experts_out = torch.stack(
shared_experts_out, dim=0
).sum(dim=0)
# 把 sparse_moe_out 和 shared_experts_out 加起来
return sparse_moe_out + shared_experts_out, router_logits
- 输入:
x
:形状为(b, s, hidden_dim)
的张量,其中:b
:批次大小。s
:序列长度。hidden_dim
:隐藏状态维度。- 对应架构图中的 Input Hidden ( u_t )。
- 处理流程:
- 稀疏 MoE 处理:
sparse_moe_out, router_logits = self.moe_model(x)
:sparse_moe_out
:稀疏专家的输出,形状为(b, s, hidden_dim)
。router_logits
:路由器的原始得分,形状为(b * s, expert_number)
,用于决定哪些专家处理哪些 token。- 对应架构图中的 Router 和 Routed Experts(蓝色)。
- 共享专家处理:
shared_experts_out = [expert(x) for expert in self.shared_experts]
:每个共享专家处理整个输入x
,输出形状为(b, s, hidden_dim)
。torch.stack(shared_experts_out, dim=0).sum(dim=0)
:将所有共享专家的输出堆叠并沿专家维度求和,最终形状仍为(b, s, hidden_dim)
。- 对应架构图中的 Shared Experts(绿色)。
- 合并输出:
sparse_moe_out + shared_experts_out
:将稀疏专家和共享专家的输出相加,生成最终隐藏状态,对应架构图中的 Output Hidden ( \hat{h}_t )。
- 稀疏 MoE 处理:
- 输出:
- 一个元组:
(sparse_moe_out + shared_experts_out, router_logits)
。- 第一个元素:最终隐藏状态,形状为
(b, s, hidden_dim)
。 - 第二个元素:路由器得分,形状为
(b * s, expert_number)
,用于后续训练(如负载均衡)。
- 第一个元素:最终隐藏状态,形状为
- 一个元组:
测试函数 (test_share_expert_moe
)
def test_share_expert_moe():
x = torch.rand(2, 4, 16)
config = MOEConfig(16, 2, 2)
share_expert_moe = ShareExpertMOE(config)
out = share_expert_moe(x)
print(out[0].shape, out[1].shape)
test_share_expert_moe()
- 目的:验证
ShareExpertMOE
的功能是否正确。 - 步骤:
- 输入:
x = torch.rand(2, 4, 16)
,生成随机张量,批次大小为 2,序列长度为 4,隐藏维度为 16。 - 配置:
config = MOEConfig(16, 2, 2)
,设置hidden_dim=16
、expert_number=2
、top_k=2
。 - 模型:实例化
ShareExpertMOE
。 - 输出:
out[0].shape
:(2, 4, 16)
,最终隐藏状态。out[1].shape
:(8, 2)
,路由器得分(b * s = 2 * 4 = 8
,expert_number=2
)。
- 输入:
- 意义:验证了从输入 ( u_t ) 到输出 ( \hat{h}_t ) 的流程与架构图一致。
训练代码详解
负载均衡损失 (switch_load_balancing_loss
)
def switch_load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
router_probs = torch.softmax(router_logits, dim=-1)
_, selected_experts = torch.topk(router_probs, k=2, dim=-1)
mask = torch.nn.functional.one_hot(selected_experts, num_experts).float()
expected_load = torch.ones_like(router_probs) / num_experts
actual_load = mask.mean(dim=0)
aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts
z_loss = torch.mean(torch.square(router_logits))
z_loss_weight = 0.001
total_loss = aux_loss + z_loss * z_loss_weight
return total_loss
- 目的:确保稀疏专家的负载均衡,避免某些专家被过度使用或闲置。
- 输入:
router_logits
:形状为(b * s, num_experts)
的路由器得分。num_experts
:稀疏专家总数。
- 步骤:
- 概率计算:
router_probs = torch.softmax(router_logits, dim=-1)
,将得分转为概率。 - Top-K 选择:选择每个 token 的 top-2 专家。
- 掩码:生成 one-hot 掩码,表示实际选择的专家。
- 负载计算:
- 预期负载:均匀分布(
1/num_experts
)。 - 实际负载:每个专家被选择的平均比例。
- 预期负载:均匀分布(
- 损失:
aux_loss
:惩罚负载不均衡。z_loss
:惩罚过大的路由得分,权重为 0.001。- 总损失:两者加权和。
- 概率计算:
- 架构图关联:对应 Router 的 Top-K 机制和黄色虚线,表示负载均衡的优化目标。
训练循环 (test_moe_training
)
def test_moe_training():
batch_size = 32
seq_len = 16
hidden_dim = 32
num_batches = 100
config = MOEConfig(hidden_dim=32, expert_number=4, top_k=2, shared_experts_number=2)
model = ShareExpertMOE(config)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
for batch in range(num_batches):
x = torch.randn(batch_size, seq_len, hidden_dim)
target = torch.randn(batch_size, seq_len, hidden_dim)
output, router_logits = model(x)
mse_loss = F.mse_loss(output, target)
aux_loss = switch_load_balancing_loss(router_logits, config.expert_number)
total_loss = mse_loss + 0.01 * aux_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
if batch % 10 == 0:
print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")
test_moe_training()
- 目的:训练
ShareExpertMOE
模型。 - 设置:
- 超参数:
batch_size=32
、seq_len=16
、hidden_dim=32
、num_batches=100
。 - 配置:
expert_number=4
、top_k=2
、shared_experts_number=2
。 - 优化器:Adam,学习率 0.001。
- 超参数:
- 训练过程:
- 数据:随机生成输入
x
和目标target
,均为(32, 16, 32)
。 - 前向传播:计算输出和路由得分。
- 损失:
mse_loss
:输出与目标的均方误差。aux_loss
:负载均衡损失。total_loss
:mse_loss + 0.01 * aux_loss
。
- 优化:反向传播并更新参数。
- 数据:随机生成输入
- 输出:每 10 个 batch 打印损失,显示训练进展。
DeepSeekMoE 架构图与代码对应
根据图片描述,DeepSeekMoE 架构图展示了 Transformer 和 MoE 的集成。以下是代码与图的对应关系:
Transformer Block(左侧)
- 输入 ( h_t ):经过 RMS Norm 和 Multi-Head Attention 处理,生成 ( u_t ),对应代码中的输入
x
。 - 组件:前馈网络(Feed-Forward Network)、RMS Norm 和注意力机制是 Transformer 的标准层,为 MoE 准备输入。
DeepSeekMoE 模块(右侧)
- 输入 ( u_t ):进入 MoE 系统(
x
)。 - Router:通过 Top-K 机制分配 token 给 Routed Experts(蓝色),对应
self.moe_model
。 - Routed Experts:生成
sparse_moe_out
。 - Shared Experts(绿色):处理所有 token,生成
shared_experts_out
。 - 输出 ( \hat{h}_t ):合并输出(
sparse_moe_out + shared_experts_out
)。 - 负载均衡:黄色虚线与
switch_load_balancing_loss
相关。
总结
ShareExpertMOE
:通过稀疏和共享专家的结合,实现了高效的 token 处理,前向传播输出最终隐藏状态和路由得分。- 训练:结合 MSE 损失和负载均衡损失,确保预测准确性和专家利用率。
- 架构图:清晰展示了 Transformer 和 DeepSeekMoE 的数据流,与代码逻辑高度一致。
希望这个解释对您理解模型和代码有所帮助!如果有进一步的问题,请随时提问。