利用可训练稀疏注意力实现更快的视频扩散

发表
Zhang PeiyuanZhang Peiyuan 提交
作者: Zhang PeiyuanPeiyuan Zhang, haofeng huangHaofeng Huang, Yongqi ChenYongqi Chen, Will Lin, Zhengzhong LiuZhengzhong Liu, Ion Stoica, Eric XingEric P. Xing, Hao Zhang

摘要

扩展视频扩散 Transformer (DiTs) 受限于其二次三维注意力,即使绝大多数注意力权重集中在少数位置。我们将这一观察转化为 VSA,一种可训练、硬件高效的稀疏注意力机制,它在训练和推理阶段取代了全注意力。在 VSA 中,一个轻量级的粗粒度阶段将 tokens 汇集到 tile 中并识别高权重的关键 tokens;一个细粒度阶段仅在这些 tile 内部计算 token 级别的注意力,遵循块计算布局以确保硬件效率。这产生了一个单一的可微分核,可以端到端训练,无需事后性能分析,并保持 FlashAttention3 MFU 的 85%。我们通过预训练 60M 至 1.4B 参数的 DiTs 模型,进行了一系列广泛的消融研究和缩放法则实验。VSA 达到了一个帕累托最优解,在扩散损失没有下降的情况下,将训练 FLOPS 减少了 2.53 倍。对开源的 Wan-2.1 模型进行改进,将注意力计算时间加快了 6 倍,并将端到端生成时间从 31 秒降至 18 秒,同时保持相当的质量。这些结果证明了可训练稀疏注意力是全注意力的一种实用替代方案,也是进一步扩展视频扩散模型的关键推动力。
查看 arXiv 页面查看 PDF

评论

Zhang PeiyuanZhang Peiyuan
论文作者
论文提交者

扩展视频扩散Transformer(DiTs)受到其二次3D注意力的限制,尽管大部分注意力集中在少数位置上。我们将这一观察转化为 VSA,一种可训练、硬件高效的稀疏注意力,它在训练和推理时都取代了全注意力。在 VSA 中,轻量级粗粒度阶段将 token 池化到 tile 中并识别高权重 \emph{关键 token};细粒度阶段仅在遵循块计算布局的 tile 内计算 token 级别的注意力,以确保硬件效率。这产生了一个可微分的单个内核,可以端到端训练,无需事后性能分析,并保持 FlashAttention3 85% 的 MFU。我们通过预训练参数范围从 60M 到 1.4B 的 DiTs 模型,进行了大量的消融研究和扩展定律实验。VSA 在不降低扩散损失的情况下,达到了帕累托最优,训练 FLOPS 降低了 2.53 倍。改造开源 Wan-2.1 模型后,注意力时间加速 6 倍,端到端生成时间从 31 秒缩短至 18 秒,同时保持质量相当。这些结果确立了可训练稀疏注意力作为全注意力的实用替代方案,以及进一步扩展视频扩散模型的关键促成因素。

Yuxiong WuYuxiong Wu

有计划发布代码吗?

Yongqi ChenYongqi Chen
论文作者

我们将在 https://github.com/hao-ai-lab/FastVideo 即将发布代码!