DeepEP原理分析(二):low latency模式代码详解(更新中)


最近在做基于 DeepEP 和 SGLang 的推理优化。DeepEP 是一个用于取代 MoE 模型中 AlltoAll 操作、非常精妙和细致的点对点通信 kernel,好好研读其实现过程对 CUDA 编程、MoE rank 间通信都有巨大好处(下面的分析基于0625 main版本,详细代码请自行上 github 上搜索)

1 dispatch

1.1 入参详解

1、接收与打包参数:主要用于在目标 GPU 上存储从其他 GPU 接收到的数据

- packed_recv_x:用于存储接收到的已打包好的 token 特征数据,这些数据将直接用于本地专家的计算。这是数据在接收方的最终存储位置

- packed_recv_x_scales:如果使用 FP8 量化传输,这个指针指向用于存储 packed_recv_x 对应 FP8 scale(缩放因子)的缓冲区

- packed_recv_src_info:存储每个接收到的 token 在其来源 GPU 上的原始索引 。这个信息在后续 combine 阶段至关重要,用于将计算结果正确地写回

- packed_recv_layout_range:记录由每个本地专家从各个源 rank 接收到的 token 的布局信息。它通常将来自某个源 rank 的 token 总数和 token 在 packed_recv_x 缓冲区中的起始偏移量打包成一个 64 位整数,其结构可参考下面这个二维数组示意图,这个结构使得 dispatch kernel 能够记录每个本地专家从哪个源 rank 接收了多少 token,以及这些 token 在接收缓冲区中的具体位置。随后 combine kernel 可以利用这些信息,将处理后的 token 结果正确地发送回它们最初的源 rank


- packed_recv_count:原子计数器,每个本地专家都有一个属于自己的该计数器,用来追踪该专家已经接收了多少 token,以便为新到达的 token 计算正确的存储偏移量

- cumulative_local_expert_recv_stats:用于累计每个本地专家接收到的 token 总数,作为统计信息

2、RDMA 缓冲区参数:RDMA 通信过程所需的空间和信号

- rdma_recv_x:RDMA 接收缓冲区。从其他节点通过网络发送过来的数据会先被存放在这里,之后再由接收端 kernel 解包并拷贝到 packed_recv_x 中

- rdma_recv_count:用于节点间通信的 计数器/标志位数组。发送方会通过原子操作更新这个数组中的值,以通知接收方它发送了多少个 token。接收方通过轮询这个值来判断数据是否到达以及到达了多少

- rdma_x:RDMA 发送缓冲区。在发送数据前,token 会在这里被准备好(例如进行 FP 8 类型转换),然后从这个地址启动 RDMA 传输

3、输入数据与路由信息

- x:指向当前 GPU 上的 MoE 层输入 token 的特征张量

- topk_idx:gating network 的输出结果,包含了每个 token 应该被发送到的 top-k 个专家的索引

4、kernel 同步与协调参数

- atomic_counter_per_expert:原子计数器工作区。当一个 warp 准备发送 token 到某个专家时,它会原子地增加对应专家的计数器,从而获得一个唯一的 "slot"(槽位)索引,以避免在写入 rdma_recv_x 缓冲区时发生冲突

- atomic_finish_counter_per_expert:一个原子计数器工作区,用于更复杂的同步。它被用来确保一个 token 的所有数据都准备好发送后,才更新最终的计数值,并协调发送 token 数量的逻辑

- next_clean:用于流水线执行时的缓冲区管理。next_clean 指向下一次迭代需要使用的缓冲区,本内核会负责将其清零,为下一次 MoE 计算做准备

- num_next_clean_int:next_clean 所指向的缓冲区的大小

5、MoE 模型结构和分布式环境参数

- num_tokens:当前 GPU 上的输入 token 总数

- num_max_dispatch_tokens_per_rank:单个 GPU 会发送给另一个 GPU 的最大 token 数量,用于预分配缓冲区大小

- num_topk:Top-k 路由中的 k 值

- num_experts:专家总数(涵盖所有 GPU)

- rank 和 num_ranks:当前 GPU 的 rank编号和分布式组中的 GPU 总数

- num_warp_groups 和 num_warps_per_group:决定 warp 如何分组来分工处理不同的专家,以及每个 group 内的 warp 数量

- round_scale:FP8 量化时是否使用四舍五入

- phases:控制 kernel 需要执行的阶段。可以是 LOW_LATENCY_SEND_PHASE (只发送)、 LOW_LATENCY_RECV_PHASE (只接收),或两者皆有。这使得同一个内核可以灵活地用于不同场景

1.2 变量初始化

1、量化有关的变量

- kNumPerChannels:在 FP8 量化中,通常不是为整个 kHidden 维度的张量计算一个单一的缩放因子(scale),而是将其划分为多个组(或通道),为每个组独立计算一个缩放因子。这种“per-channel”或“per-group”的量化方式可以更精确地保留原始数据的动态范围,从而提高量化后的模型精度。这里设定每个量化通道包含 128 个元素

- num_scales:计算总共需要多少个 FP8 缩放因子

- hidden_bytes:计算一个 token 的隐藏层向量(纯数据,不包括 metadata)在内存中占用的总字节数

- hidden_int4:将隐藏层向量占用的总字节数转换为以 int4 为单位的数量。int4 是一个 128 位(16 字节)的向量数据类型,通常用于在 CUDA kernel 中实现高效的向量化内存读写

- num_bytes_per_msg:在节点间进行 AlltoAll 通信时,单个 token 所打包成的 message 的总字节大小。这么做是为了将一个 token 的所有相关信息打包在一起,以便通过 RDMA 或 P2P 高效地发送给目标专家所在的 GPU

2、专家有关的变量

- shared_num_tokens_sent_per_expert:在同一个 SM 内,负责统计 token 数量的 warp 将计算结果传递给负责发送这些计数的 warp

- 写入:负责计数的 warp 在计算出每个它负责的专家需要接收的 token 总数后(通过 warp_reduce_sum ),会将这个总数写入到 shared_num_tokens_sent_per_expert 数组中的对应位置上

- 读取:每个 warp group 中 sub_warp_id == 0 且 lane_id == 0 的那个线程(即每个 warp group 的领导线程)会从 shared_num_tokens_sent_per_expert 中读取之前由计数 warp 存入的、对应位置上的 token 数量。该线程随后负责将这个 token 数量通过 nvshmemi_ibgda_amo_nonfetch_add 或 st_release_sys_global 发送给目标 Rank,以通知对方它即将接收多少个 token。

1.3 发送阶段

1.3.1 一个 SM 内的 warp 角色分配

1、负责数据处理的 warp (warp_id < num_warps - 1):这些 warp 负责读取原始 token 数据,进行 FP8 转换(如果需要),并将准备好的数据通过 RDMA 发送给目标专家

2、负责计数的 warp (warp_id == num_warps - 1):最后一个 warp 不处理 token 数据,它的任务是遍历 topk_idx 数组,统计出当前 SM 负责的那些专家(responsible_expert_idx)总共需要接收多少个 token ^9de6af

1.3.2 负责数据发送的 warp 涉及的主要流程

1、token 分配:每个 SM 负责处理一部分 token,通过 token_idx += num_sms 的步长实现网格跨步循环(grid-stride loop),确保所有 token 都能被处理,并且负载均衡

2、从 token 所在缓冲区中获取 token 元信息

- token 数据的张量

- token 在发送缓冲区(rdma_x)中的起始位置

- 存放 token 特征数据的位置

- 存放 FP8 缩放因子的位置

3、目标专家确定与 token 元数据写入

- 当前 SM 中的每个 warp 负责处理一个 top-k 专家,并通过 __ ldg (Load Global Device)高效地从 topk_idx 数组中读取当前 token 应该被发送到的目标专家全局索引 dst_expert_idx 。如果一个 warp 的 warp_id 大于等于 num_topk ,它就不需要发送这个 token, dst_expert_idx 被设为 -1

- 由 SM 中的第一个线程将当前 token 的索引 token_idx 写入发送缓冲区 rdma_x 对应的位置中。这个信息在 combine 阶段至关重要,用于将计算结果正确地放回其原始位置

4、数据转换:将输入的 bfloat16 数据转换为 fp8 (如果 kUseFP8 为 true)或者直接拷贝

- 当前 SM 中的所有线程协同处理一个 token 的整个 kHidden 维度(处理完一个 token 需要 hidden_bf16_int4 次)

- 如果 kUseFP 8 为 true 则进行 BF16 -> FP8 转换

- 加载数据 : auto int 4_value = __ ldg (x_int 4 + i); 使用向量加载指令 ldg 高效读取 128-bit (int 4) 的数据

- 计算 amax : 在 warp 内部,每个线程计算自己负责的几个 bfloat 16 值的绝对值的最大值 (amax)

- Warp-level amax reduction : amax = half_warp_reduce_max (amax); 使用 half_warp_reduce_max 在半个 warp (16 个线程) 内进行高效的 amax 归约。

- 计算 Scale : calculate_fp 8_scales (...) 根据归约后的 amax 计算出用于量化的缩放因子 scale 和用于反量化的 scale_inv

- 存储 Scale : lane_id == 0 or lane_id == 16 的线程将计算出的 scale_inv 写入发送缓冲区的 rdma_x_scales 部分

- 量化和存储 : 每个线程用 scale 将自己的 fp 32 值量化为 fp 8 ,并打包成 __ nv_fp 8 x 2_storage_t 类型,最后以 int 2 的形式写入发送缓冲区 rdma_x_vec

- SM 级别的 barrier 同步:调用 asm volatile("bar.sync1, %0;" :: "r"(num_threads)) 来确保 SM 内的所有线程都完成了对当前 token 的数据转换和写入发送缓冲区的操作,之后才能进入发送阶段。bar.sync1 是一个有编号的 barrier,避免与 kernel 中其他 barrier 混淆

5、数据发送:如果当前 warp 确定了有效的目标专家后就会执行发送操作

- 获取发送槽位 (slot_idx):atomicAdd 用于为当前 token 在目标专家的接收缓冲区中原子地申请一个槽位;而 __ shfl_sync 则将 lane_id == 0 的线程获取到的 slot_idx 广播给 warp 内的所有其他线程

- 计算目标地址 (dst_ptr) : 根据目标 rank、目标专家本地索引和 slot_idx ,精确计算出数据在远程 GPU 上的目标内存地址。这个地址结构是为了避免不同源 rank、不同 token 之间的写入冲突。

- 发送方式选择

- P2P (Peer-to-Peer) : nvshmemi_get_p2p_ptr 检查目标地址是否可以通过 P2P 直接访问。如果是 (dst_p2p_ptr != 0),说明源和目标 GPU 之间有高速的 NVLink 连接,可以直接通过 UNROLLED_WARP_COPY 进行一次内存拷贝(最高效的方式)

- IBGDA(GPUDirect-RDMA):如果 P2P 不可用,则使用 RDMA 的方式进行发送。 nvshmemi_ibgda_put_nbi_warp 会发起一个非阻塞的 RDMA put 操作,由整个 warp 协同将数据从本地发送缓冲区 rdma_x 推送到远程 GPU 的接收缓冲区 rdma_recv_x

- 更新完成计数器 : atomic_add_release_global(...) 在发送操作发起后,原子地增加目标专家的完成计数器。这个计数器用于确保在通知目的专家其要接收的 token 总数之前,所有发往该专家的 token 数据都已经成功发起了 RDMA 或 P2P 传输

1.3.3 负责控制发送行为的 warp(最后一个 warp)涉及的主要流程

1、该 warp 中的第一个 SM 需要负责(除了控制行为外):清理下一轮的缓冲区 (next_clean),将下一轮 dispatch 或 combine 要使用到的缓冲区清零

2、SM 负责统计要发往每个专家的 token 总数

- 每个 SM 被分配负责一部分专家(计算出当前 SM 负责的专家索引范围)

- Warp 内的所有 32 个线程同时并行不重复遍历整个 topk_idx 数组来获取 token 要发往的某些专家的索引。如果某个索引 idx 恰好落在当前 SM 的负责范围内,对应专家的本地计数器(expert_count)就加一

- 在每个线程完成各自的局部计数后,通过 warp_reduce_sum(底层调用的是 [[Warp Shuffle Functions]] 中的 shfl_xor_sync 函数)将结果汇总得到专家需要接收的 token 总数(因为每个 warp 都有一份属于自己的 expert_count,所以需要汇总)

- 第一个 lane 会将汇总的结果更新到 shared memory 中的某块空间,使 SM 内(也就是 block)的其他 warp 能够读取到这个统计结果;同时更新全局完成计数器

1.3.4 将发送统计数据通知目的 rank

1、每个 SM 中负责某个专家的 warp group 里,只有第一个 warp 的第一个线程来执行该通知任务,避免重复发送和清理

2、从 shared memory 中用来保存上一步汇总结果的内存中读取要发往给每个 SM 和每个 warp 所负责的专家的 token 总数

3、等待所有发往某个专家的 token 均已发起传输,且统计该专家 token 总数的汇总工作也已完成,来保证即将告知该专家要接收的 token 数量是准确的(此时数据已经通过 NVLink 或者 RDMA 在传输)。这里需要注意的是通知统计数据时发送的值并非正数而是一个负数(-num_tokens_sent - 1),接收方则可以通过 value = -value - 1 这一操作来解码得到原始的 num_tokens_sent 。使用负数可以很容易地区分“尚未收到”的状态(通常为0)和“已收到但数量为0”的状态(值为-1)

4、根据 1.3.2 中提到的的不同发送方式通过 P2P 或 RDMA 原子加操作将汇总结果写入对应的远端内存中

5、计数器清理

- 在发送完计数值之后,该线程会立即清理当前 rank 为这个专家所使用的两个原子计数器

- 如果目标 rank 是 rank 0 ,当前线程还会负责清理 rank 0 上的 packed_recv_count 。这可能是一个特殊的优化或约定,即由发送方来帮助 rank 0 完成一部分清理工作

评论

此博客中的热门博文

NVSHMEM官方文档部分内容总结

《笔记的方法》简单总结

DeepEP原理分析(一):low latency模式特点总结