⏶70
SageAttention3:微缩放 FP4 注意力用于推理以及对 8 比特训练的探索
发表
由
Jintao Zhang 提交

作者:
Jintao Zhang, Jia Wei,
Pengle Zhang, Xiaoming Xu,
Haofeng Huang,
Haoxu Wang, Kai Jiang, Jun Zhu,
Jianfei Chen



摘要
由于其二次时间复杂度,注意力的效率至关重要。我们通过两项关键贡献来提升注意力的效率:首先,我们利用 Blackwell GPU 中新的 FP4 Tensor Cores 来加速注意力计算。我们的实现在 RTX5090 上达到了 1038 TOPS,比 RTX5090 上最快的 FlashAttention 加速了 5 倍。实验表明,我们的 FP4 注意力可以以即插即用的方式加速各种模型的推理。其次,我们开创性地将低比特注意力应用于训练任务。现有的低比特注意力工作,如 FlashAttention3 和 SageAttention,仅专注于推理。然而,训练大型模型的效率也同样重要。为了探索低比特注意力是否能有效地应用于训练任务,我们设计了一种针对前向和后向传播的准确高效的 8 比特注意力。实验表明,8 比特注意力在微调任务中实现了无损性能,但在预训练任务中收敛速度较慢。代码将在 <a href="https://github.com/thu-ml/SageAttention">https://github.com/thu-ml/SageAttention</a> 提供。

SageAttention3:用于推理的微缩放 FP4 Attention,提速5倍;以及用于训练的8位Attention。代码将在 https://github.com/thu-ml/SageAttention 提供。