DeepSeek DeepEP Notes (1): low latency dispatch
Background
To optimize latency, the low-latency mode sends data (cast to fp8) directly between GPUs, instead of the normal kernels' two-stage scheme, which first sends over the network to the GPU with the same local index on the remote node and then forwards over NVLink. Furthermore, the normal dispatch consists of two kernels, notify_dispatch for transferring metadata and dispatch for transferring the actual data, while the low-latency mode also skips the notify step. The price is much higher memory usage, and it also needs to be paired with DeepSeek's version of GEMM.
Usage
Taking the demo from GitHub as an example, usage is fairly simple:
- First call get_low_latency_rdma_size_hint to get num_rdma_bytes, i.e., how large the Buffer needs to be allocated later.
- Then initialize the Buffer.
- Finally run low_latency_dispatch.
def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:
    # NOTES: the low-latency mode will consume much more space than the normal mode
    # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
    global _buffer
    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)

    # Allocate a buffer if not existed or not enough buffer size
    if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
        # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
        assert num_experts % group.size() == 0
        _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
    return _buffer
def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):
    global _buffer

    # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
    recv_hidden_states, recv_expert_count, handle, event, hook = \
        _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                     async_finish=False, return_recv_hook=True)

    # NOTES: the actual tensor will not be received only if you call `hook()`,
    # it is useful for double-batch overlapping, but **without any SM occupation**
    # If you don't want to overlap, please set `return_recv_hook=False`
    # Later, you can use our GEMM library to do the computation with this specific format
    return recv_hidden_states, recv_expert_count, handle, event, hook
Getting the buffer size
num_max_dispatch_tokens_per_rank is the maximum number of tokens one rank will dispatch. Since tokens are cast to fp8, one token's message consists of the embedding (hidden bytes of fp8 data) plus its scales and the source token index, for a total of num_bytes_per_dispatch_msg bytes.
Look first at send_buffer_bytes. The factor of 2 on total_bytes is for overlapping different micro-batches. The send buffer is straightforward: it just reserves space for the maximum number of tokens, i.e. num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg.
Because low_latency sends and receives directly without first computing the actual layout, the worst case on the receive side is that every rank's tokens all land on the same expert, so the receive buffer has to hold num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg bytes in total.
size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
    auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
    return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
    const int num_scales = hidden / 128;
    const int num_local_experts = num_experts / num_ranks;
    EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);

    size_t num_bytes_per_dispatch_msg = hidden + num_scales * sizeof(float) + sizeof(int4);
    size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);

    // Send buffer
    size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
    size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
    size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
    EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
    total_bytes += send_buffer_bytes * 2;

    size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
    size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
    size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
    EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
    total_bytes += recv_buffer_bytes * 2;
    ...
}
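To make the numbers concrete, here is a small sketch (plain Python, not DeepEP code) that mirrors the arithmetic above with assumed example parameters: hidden = 7168, 16 ranks, 256 experts, and at most 128 dispatch tokens per rank. It ignores the smaller count/signal buffers hidden behind the trailing `...`.
# Rough size estimate mirroring LowLatencyLayout; the parameter values below are
# assumptions for illustration, not values taken from the DeepEP source.
hidden, num_ranks, num_experts, num_max_tokens = 7168, 16, 256, 128
num_scales = hidden // 128
dispatch_msg = hidden + num_scales * 4 + 16      # fp8 payload + float scales + one int4
combine_msg = 16 + hidden * 2                    # one int4 + bf16 payload
send_bytes = max(num_max_tokens * dispatch_msg,
                 num_experts * num_max_tokens * combine_msg)
recv_bytes = max(num_experts * num_max_tokens * dispatch_msg,
                 num_experts * num_max_tokens * combine_msg)
total = (send_bytes + recv_bytes) * 2            # x2 for double buffering across batches
print(f"~{total / 2**30:.2f} GiB")               # about 1.75 GiB with these parameters
This is why the low-latency mode is described above as having a high memory cost: the receive side always provisions for the worst case where every token lands on the same expert.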
Buffer creation and initialization
First a Buffer object is constructed. This mainly allocates some memory that the low-latency path does not care about, so we can skip it for now.
In low_latency mode, IBGDA also has to be enabled. Rank 0 then obtains the NVSHMEM unique ID via get_local_nvshmem_unique_id (roughly analogous to an NCCL communicator handle), and every rank obtains this handle through an allgather.
def __init__(self, group: dist.ProcessGroup,
             num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
             low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
    self.rank = group.rank()
    self.group_size = group.size()
    self.group = group
    self.num_nvl_bytes = num_nvl_bytes
    self.num_rdma_bytes = num_rdma_bytes
    self.low_latency_mode = low_latency_mode
    self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)

    root_unique_id = None
    if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
        # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
        if low_latency_mode:
            assert num_qps_per_rank > 0
            os.environ['NVSHMEM_DISABLE_P2P'] = '1'
            os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
            os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
            os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
            # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
            os.environ['NVSHMEM_QP_DEPTH'] = '1024'
            # NOTES: NVSHMEM initialization requires at least 256 MiB
            os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'

        # NOTES: make sure AR (Adaptive Routing) is turned off while running normal kernels, as we cannot verify AR status in the code
        # Synchronize using the root ID
        nvshmem_unique_ids = [None, ] * self.group_size
        if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
            root_unique_id = self.runtime.get_local_nvshmem_unique_id()
        dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
        root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]

    # Make CPP runtime available
    self.runtime.sync(device_ids, ipc_handles, root_unique_id)
    assert self.runtime.is_available()
Then sync runs. Since num_nvl_bytes is 0, the NVLink-related code is omitted here. Inside sync, internode::init initializes NVSHMEM; once it returns, the inter-node RDMA connections have been established.
void Buffer::sync(const std::vector<int> &device_ids,
                  const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
                  const std::optional<pybind11::bytearray>& root_unique_id_opt) {
    EP_HOST_ASSERT(not is_available());

    // Sync NVSHMEM handles and allocate memory
    if (num_rdma_bytes > 0) {
        // Initialize NVSHMEM
        EP_HOST_ASSERT(root_unique_id_opt.has_value());
        std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
        auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
        std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
        auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
        auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
        EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
        internode::barrier();

        // Allocate
        rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);

        // Clean buffer (mainly for low-latency mode)
        CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));

        // Barrier
        internode::barrier();
        CUDA_CHECK(cudaDeviceSynchronize());
    }

    // Ready to use
    available = true;
}
Then alloc runs, which allocates NVSHMEM symmetric memory via nvshmem_align. After that the ibgda_initialize_recv_queue kernel is launched to initialize the receive side of each connection and post the receive WRs.
dispatch
As a running example, consider two nodes, two GPUs per node, and two experts per GPU (so four ranks and eight experts in total).
Role assignment
The dispatch kernel assigns roles to warps. Except for the data-sending part, the role assignment is the same everywhere: each SM contains kNumWarpGroups warp groups, each group has kNumWarpsPerGroup warps, and each warp group is responsible for one expert.
So the expert a warp is responsible for is computed as:
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
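As a quick illustration of this mapping, here is a sketch with assumed launch constants (kNumWarpGroups = 2, kNumWarpsPerGroup = 4; these are example values, not DeepEP's actual configuration):
# Enumerate the warp -> expert mapping above for the first two SMs.
kNumWarpGroups, kNumWarpsPerGroup = 2, 4
for sm_id in range(2):
    for warp_id in range(kNumWarpGroups * kNumWarpsPerGroup):
        warp_group_id = warp_id // kNumWarpsPerGroup
        sub_warp_id = warp_id % kNumWarpsPerGroup
        responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id
        print(sm_id, warp_id, warp_group_id, sub_warp_id, responsible_expert_idx)
# SM 0's two warp groups own experts 0 and 1, SM 1's own experts 2 and 3, and so on.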
Data sending
During data sending the thread roles are slightly different: within each SM, the last warp is responsible for sending counts (call it the count warp), while the remaining warps are responsible for sending data (call them data warps).
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    if (warp_id < num_warps - 1) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
        EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
        EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
        const auto num_threads = (num_warps - 1) * 32;
        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_int2 = reinterpret_cast<int2*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_int2) + kHidden);
            const auto rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_scales + num_scales);
            ...
        }
    }
    ...
}
The current SM picks up token token_idx (SMs stride over the tokens by num_sms); that token's embedding is read through x_int4. As described above, num_topk warps inside the SM read the topk matrix to determine which experts the current token will be sent to, and rdma_x_int2 is the region of rdma_x that stages the embedding after it has been cast to fp8.
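The pointer casts above imply the following per-token message layout inside rdma_x. This is a sketch derived only from this snippet, with hidden = 7168 as an example value:
# Byte offsets of one staged dispatch message, as implied by the casts above.
hidden = 7168                                     # example value
num_scales = hidden // 128
data_offset = 0                                   # hidden bytes of fp8 data (written as int2/int4)
scales_offset = hidden                            # num_scales float32 scales
src_idx_offset = hidden + num_scales * 4          # the source token index (an int)
num_bytes_per_msg = hidden + num_scales * 4 + 16  # total size reserves a full int4 for the tail
print(scales_offset, src_idx_offset, num_bytes_per_msg)   # 7168 7392 7408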
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    if (dst_expert_idx >= 0) {
        int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
        slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
        const auto dst_rank = dst_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
        const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_int2);
        const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                             dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                             rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                             slot_idx * num_bytes_per_msg;
        if (dst_rank != rank) {
            nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
        } else {
            // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
            const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
            const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
            UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
        }

        // Increase counter after finishing
        __syncwarp();
        lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
    }
    ...
}
dst_expert_idx is the expert this token has to be sent to. Since different SMs process different tokens and may send to the same expert concurrently, lane 0 of the warp claims a slot index for that expert via atomicAdd on atomic_counter_per_expert, which prevents writes from overwriting each other; __shfl_sync then broadcasts slot_idx to all threads of the warp.
Each rank's receive buffer is laid out as [num_local_experts][num_ranks] blocks, each block being num_max_dispatch_tokens_per_rank * num_bytes_per_msg bytes. Suppose the current rank is rank 2 and the token goes to expert 3; with two local experts per rank, the destination is rank 1 and dst_expert_local_idx is 1.
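Plugging that example into the dst_ptr formula gives the byte offset inside the destination's rdma_recv_x. Only the 4-rank / 2-local-expert topology comes from the example; the other sizes below are assumptions for illustration:
# Where rank 2's token lands inside rank 1's rdma_recv_x, per the formula above.
num_ranks, num_local_experts = 4, 2
num_max_dispatch_tokens_per_rank = 128      # assumed
num_bytes_per_msg = 7408                    # assumed, see the layout sketch earlier
rank, dst_expert_idx, slot_idx = 2, 3, 5    # slot_idx as returned by the atomicAdd

dst_rank = dst_expert_idx // num_local_experts             # 1
dst_expert_local_idx = dst_expert_idx % num_local_experts  # 1
offset = (dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank
          + rank * num_max_dispatch_tokens_per_rank
          + slot_idx) * num_bytes_per_msg
print(dst_rank, dst_expert_local_idx, offset)              # 1 1 5726384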
Send-side synchronization
Because the data is sent with RDMA write, the remote side cannot tell when it has arrived, so the data write has to be followed by an RDMA write with immediate that notifies the receiver. However, the data for one expert is sent by multiple SMs concurrently, and SMs cannot synchronize with each other through anything like __syncthreads, so DeepEP designs its own synchronization scheme: the notification is only sent after both the data sends have been issued and some memory initialization has been completed. The data sends are done by the data warps, and the initialization is done by the count warp of SM 0.
atomic_finish_counter_per_expert was mentioned earlier: atomic_finish_counter_per_expert[x] tracks how far the work for expert x has progressed and is initialized to 0. Each time a data warp finishes one send, it atomically adds 1 to atomic_finish_counter_per_expert[x]; the count warp atomically adds FINISHED_SUM_TAG once its initialization is done, then computes sum, the total number of tokens this rank sends to expert x, and atomically adds (FINISHED_SUM_TAG - sum). FINISHED_SUM_TAG is guaranteed to be larger than sum, so regardless of the order in which the warps run, once the polled value equals 2 * FINISHED_SUM_TAG both the data sends and the memory initialization are complete.
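In scalar form the invariant looks like this (a sketch; the FINISHED_SUM_TAG value here is just an assumed upper bound on the per-expert token count):
# Whatever order the three contributions arrive in, the counter hits
# 2 * FINISHED_SUM_TAG exactly when the sends and the cleanup are both done.
FINISHED_SUM_TAG = 1024          # assumed; must exceed any possible per-expert count
num_tokens_to_expert = 37        # example
counter = 0
counter += num_tokens_to_expert                        # data warps: +1 per issued send
counter += FINISHED_SUM_TAG                            # count warp: after the cleanup
counter += FINISHED_SUM_TAG - num_tokens_to_expert     # count warp: after counting topk_idx
assert counter == 2 * FINISHED_SUM_TAG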
Now look at the count warp's logic. SM 0 performs the initialization of next_clean, and once that is done it adds FINISHED_SUM_TAG to atomic_finish_counter_per_expert.
Then the number of tokens going to each expert is counted. One SM is responsible for kNumWarpGroups experts, and from sm_id it knows which expert range it owns. It walks topk_idx to count how many tokens belong to its experts, accumulating into expert_count[kNumWarpGroups], where expert_count[x] holds how many of this rank's tokens go to that expert. After each lane has counted its share, warp_reduce_sum produces the warp-wide total for each owned expert, and lane 0 atomically adds FINISHED_SUM_TAG - sum to atomic_finish_counter_per_expert.
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    int expert_count[kNumWarpGroups] = {0};
    const auto expert_begin_idx = sm_id * kNumWarpGroups;
    const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);

    // Per lane count
    #pragma unroll 8
    for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
        auto idx = static_cast<int>(__ldg(topk_idx + i));
        if (idx >= expert_begin_idx and idx < expert_end_idx)
            expert_count[idx - expert_begin_idx] ++;
    }

    // Warp reduce
    #pragma unroll
    for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
        auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
        if (lane_id == 0) {
            shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
            atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
        }
    }
    ...
}
Once the data sends are done, each rank notifies the others that it has finished: thread 0 of warp 0 in each warp group polls atomic_finish_counter_per_expert, and when it reads FINISHED_SUM_TAG * 2 (data sends issued and cleanup done), it calls nvshmemi_ibgda_rma_p to write the actual number of tokens sent, num_tokens_sent, to the remote side via an RDMA write with immediate.
Notifying the receiver
As before, each warp group owns one expert. Thread 0 of warp 0 of the group runs the notification logic: it polls atomic_finish_counter_per_expert for its expert responsible_expert_idx, and once the value equals FINISHED_SUM_TAG * 2 the sends are complete, so it calls nvshmemi_ibgda_rma_p to write the token count over to the destination (encoded as -num_tokens_sent - 1), and then posts receive WRs via nvshmemi_ibgda_prepare_recvs.
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];

        // Wait local sends issued and send expert counts
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
        if (dst_rank != rank) {
            nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
                                 dst_rank, dst_expert_local_idx, 0);
            nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
        } else {
            st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
        }

        // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
    }
    ...
}
Data reception
Then the data reception begins. On the receive side, following the recv buffer layout described above, a warp group is again assigned via responsible_expert_idx, which here decomposes into (src_rank, local_expert_idx) and selects one block of the recv buffer. Thread 0 of warp 1 in each warp group polls the CQ, so warp 1 can overlap with warp 0, which is busy notifying receivers. It waits for the corresponding rank's write-with-immediate to arrive; once the CQ entry is polled, it reads the number of tokens to receive, num_recv_tokens.
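Note the encoding of the count flag: it is written as -num_tokens_sent - 1 so that a polled value of 0 can only mean "not arrived yet", even when an expert receives zero tokens. A tiny sketch of the encode/decode pair:
# rdma_recv_count starts at 0; the sender writes -(count + 1), the receiver
# spins while the slot is still 0 and then decodes it back.
def encode_count(num_tokens_sent: int) -> int:
    return -num_tokens_sent - 1

def decode_count(flag: int) -> int:
    return -flag - 1

assert encode_count(0) == -1          # distinguishable from "not written yet" (0)
assert decode_count(encode_count(37)) == 37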
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    // Shared between sub-warps in warp groups
    __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];

    // Wait tokens to arrive
    // NOTES: using sub-warp 1 to overlap with sub-warp 0
    int num_recv_tokens, recv_token_begin_idx;
    EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
    if (sub_warp_id == 1 and lane_id == 0) {
        if (src_rank != rank) {
            nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
            num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
            EP_DEVICE_ASSERT(num_recv_tokens != 0);
        } else {
            while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
        }
        num_recv_tokens = -num_recv_tokens - 1;
        recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens);
        shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
        shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
        recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
    }
    asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
    num_recv_tokens = shared_num_recv_tokens[warp_group_id];
    recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
    ...
}
After num_recv_tokens is obtained, it is added to atomic_counter_per_local_expert[local_expert_idx] with an atomicAdd; the returned old value reserves a contiguous range of num_recv_tokens slots for storing the received data.
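A scalar sketch of this slot reservation (the arrival order and counts are made up): each warp group gets back the old counter value as the start of its private output range.
# How atomicAdd on atomic_counter_per_local_expert hands out disjoint output ranges
# to the warp groups receiving from different source ranks (example counts only).
counter = 0
ranges = {}
for src_rank, cnt in [(1, 5), (3, 2), (0, 7)]:   # arrival order is arbitrary
    begin, counter = counter, counter + cnt      # atomicAdd returns the old value
    ranges[src_rank] = (cnt, begin)              # what recv_range packs per source rank
print(ranges)   # {1: (5, 0), 3: (2, 5), 0: (7, 7)}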
Then the data copy starts, one warp per token. recv_x_int4 is the output address, and recv_token_begin_idx is the start of the range reserved by the atomicAdd. The source address is computed the same way as on the send side: responsible_expert_idx determines src_rank, which locates the corresponding block in the recv buffer. After the payload is copied, the source token index is written into recv_src_info.
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
        // Copy data
        // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
        const auto src = reinterpret_cast<int4*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
        const auto dst = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
        UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst, src, ld_nc_global, st_na_global);

        // Copy scales
        const auto src_scales = reinterpret_cast<float*>(rdma_recv_x_uint8 + i * num_bytes_per_msg + kHidden);
        const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
        const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
        auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
        auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
        lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
        (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;

        // Copy source info
        const auto src_src_idx = reinterpret_cast<int*>(src_scales + num_scales);
        if (lane_id == 0)
            recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
        __syncwarp();
    }
}
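One detail worth calling out in the scale copy above: dst_scales is indexed with a stride of num_ranks * num_max_dispatch_tokens_per_rank, so the output scales are stored scale-major (all token slots' scale 0, then all token slots' scale 1, and so on). A sketch of the implied indexing, with assumed example sizes:
# recv_x_scales[j * scale_stride + t] holds the j-th 128-channel scale of output slot t.
num_ranks, num_max_dispatch_tokens_per_rank = 4, 128   # assumed example values
hidden = 7168
num_scales = hidden // 128
scale_stride = num_ranks * num_max_dispatch_tokens_per_rank

def scale_index(token_slot: int, scale_idx: int) -> int:
    return scale_idx * scale_stride + token_slot

print(scale_index(10, 0), scale_index(10, 1))   # 10 522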