Zero Bubble 论文分析

关键词:输入梯度,权重梯度,Pipeline Parallelism,Zero Bubble,1F1B

Deepseek-V3 technical report 中提到了一项很重要的训练优化技术:DualPipe,该技术的目的是为了尽可能地实现 computation 和 communication 两者的 overlapping,DS 作者提到 DualPipe 的设计参考了 ZeroBubble 策略,在 attention&MLP 梯度反向传播过程中区分 input backward 以及 weights backward。这一思想就是来源于 Zero Bubble 论文,这种细粒度梯度反向传播优化也让 Pipeline Parallelism 的效率得到了进一步的提升。

一、MLP 的反向传播过程

MLP(多层感知机)的反向传播过程可以分为两个部分:输入梯度计算(用 B 表示)和权重梯度计算(用 W 表示),前者为基于损失函数对上一层输出 x 进行微分后的结果,返回给上一层并用于该层的权重更新;后者用于本层权重的更新。

以往的设计中 B 和 W 被封装为同一个 backward function 提供给用户,这种设计对用户比较友好,且不会影响 DP 流程的效率(W 的通信和 B 的反向传播两者可以 overlap),但会影响 PP 流程的效率,因为上一层 B 的计算需要等待本层 W 计算完毕。因此 ZB 的做法是 splitting B & W 这两个流程。


二、ZB 涉及的 PP 流程和 1F1B 的区别

1、F 和 B 保持 sequentially dependent 的关系,但 W 可以灵活安排来尽可能让 W 的计算时间填补 pipeline bubbles

2、在假设 F/B/W 三个流程耗时均相同的情况下 1F1B、ZB-H1 和 ZB-H2 三者对比如下图所示

  • ZB-H1:Memory efficient schedule,B 先于 W 从而保证在 maximum peak memory usage 不超过 1F1B 的情况下 bubble size 减少为 1F1B 的三分之一

  • ZB-H2:Zero bubble schedule,在 warm-up 阶段引入更多的 micro-batches 来填补 ZB-H1 中仍然存在的 bubble

三、定量分析 

1、参数含义
  • h 为 hidden layer dimension size
  • a 为 attention heads 数量
  • s 为 sequence length
  • b 为 microbatch 大小
  • p 为 PP stages 的数量
  • MB/MW 为单次 B 和 W 计算过程各自所需的内存
2、FLOPs 和激活值内存占用分析。只考虑 matmul operations,因其贡献了 transformer 模型主要的运算量,则 transformer layer 中单次 F/B/W 的 FLOPs 和内存占用如下表所示(具体计算过程可参考这篇博客得出结论如下

  • FLOPs 的角度来看,三个不同计算过程的运行时间彼此的关系是 TW < TF < TB,且 TB + TW = 2 TF
  • 激活值占用内存的角度来看,F 为前向计算过程不需要保存激活值;B 过程激活值占用的内存参考论文 Reducing activation recomputation in large transformer models中的第 4.3 节;当 B 过程结束后可以释放一定的内存但仍要保留用于 W 的激活值,因此 W 所需激活值的内存占用要小于 B

3、Bubble size 和 Peak activations memory 分析


四、Heuristic strategy for optimal schedule params

五、Post-validation 策略

1、策略背景:PP 不同 stages 间的同步一般发生在 optimizer step,但 ZB-H 2 不满足这一要求,所以 ZB 通过 post-validation 策略绕开这一限制

2、策略原理

  • 当前迭代:每个 stage 进入 optimizer step 之前,首先从上一个 stage 获取 partially reduced global state,并将获取到的 state 与自身 stage 的 local state 一起传给下一个 stage,直到最后一个 stage 计算得到 fully reduced global state(下图中的 4 号方块);同时每个 stage 对应的 optimizer step 由 partially reduced state 决定;
  • 下轮迭代:warm-up phase 的过程中,将上一次迭代得到的 fully reduced global state(下图中的 5 号方块)会从 last stage 开始依次传回到第一个 stage,该过程中每个 stage 都会检查上一个 stage 的 optimizer step 是否合法,如果需要修正梯度则进行 rollback 操作(红色窄方框)

六、代码走读

待续

评论

此博客中的热门博文

Reasonable Faith:Chap1 How Do I Know Christianity Is True?

《笔记的方法》简单总结

APRE训练计划