⏶15
StreamBP:LLM 长序列训练的内存高效精确反向传播
发表
由
Mengqi Li 提交
作者: Qijun Luo,
Mengqi Li, Lei Zhao, Xiao Li
摘要
在长序列数据上训练语言模型是提高模型在复杂任务(例如,长链推理)上能力的一项严苛要求。然而,随着序列长度的增加,即使应用了梯度检查点技术,在反向传播(BP)过程中存储激活值的内存成本也变得巨大。为了解决这个挑战,我们提出了一种内存高效且精确的 BP 方法,称为 StreamBP,它以层级方式沿着序列维度对链式法则进行线性分解,显著降低了激活值和 logits 的内存成本。所提出的方法适用于常见的优化目标,如 SFT、GRPO 和 DPO。从实现角度来看,StreamBP 通过利用语言模型的因果结构,实现了更少的计算 FLOPs 和更快的 BP 速度。与梯度检查点相比,StreamBP 将 BP 的最大序列长度扩大了 2.8-5.5 倍,同时使用相当甚至更少的 BP 时间。需要注意的是,StreamBP 的序列长度缩放能力可以直接转换为批次大小缩放,以加速训练。我们进一步开发了一种通信高效的分布式 StreamBP,以有效支持多 GPU 训练并扩大其适用性。我们的代码可以轻松集成到任何 Transformer 模型的训练管道中,并可在 https://github.com/Ledzy/StreamBP 获取。
项目页面:https://github.com/Ledzy/StreamBP
StreamBP 大幅降低了激活值的内存开销,并将最大序列长度扩展到比梯度检查点大 2.8-5.5 倍,同时使用相似甚至更少的 BP 时间。另一方面,这种序列长度扩展能力可以直接转换为批处理大小扩展,以实现更快的训练,因为内存开销与序列长度呈线性关系。
下图显示了 StreamBP 与传统方法和梯度检查点相比的 BP 内存消耗。