⏶35
利用可训练稀疏注意力实现更快的视频扩散
发表
由
Zhang Peiyuan 提交

作者:
Peiyuan Zhang,
Haofeng Huang,
Yongqi Chen, Will Lin,
Zhengzhong Liu, Ion Stoica,
Eric P. Xing, Hao Zhang





摘要
扩展视频扩散 Transformer (DiTs) 受限于其二次三维注意力,即使绝大多数注意力权重集中在少数位置。我们将这一观察转化为 VSA,一种可训练、硬件高效的稀疏注意力机制,它在训练和推理阶段取代了全注意力。在 VSA 中,一个轻量级的粗粒度阶段将 tokens 汇集到 tile 中并识别高权重的关键 tokens;一个细粒度阶段仅在这些 tile 内部计算 token 级别的注意力,遵循块计算布局以确保硬件效率。这产生了一个单一的可微分核,可以端到端训练,无需事后性能分析,并保持 FlashAttention3 MFU 的 85%。我们通过预训练 60M 至 1.4B 参数的 DiTs 模型,进行了一系列广泛的消融研究和缩放法则实验。VSA 达到了一个帕累托最优解,在扩散损失没有下降的情况下,将训练 FLOPS 减少了 2.53 倍。对开源的 Wan-2.1 模型进行改进,将注意力计算时间加快了 6 倍,并将端到端生成时间从 31 秒降至 18 秒,同时保持相当的质量。这些结果证明了可训练稀疏注意力是全注意力的一种实用替代方案,也是进一步扩展视频扩散模型的关键推动力。
扩展视频扩散Transformer(DiTs)受到其二次3D注意力的限制,尽管大部分注意力集中在少数位置上。我们将这一观察转化为 VSA,一种可训练、硬件高效的稀疏注意力,它在训练和推理时都取代了全注意力。在 VSA 中,轻量级粗粒度阶段将 token 池化到 tile 中并识别高权重 \emph{关键 token};细粒度阶段仅在遵循块计算布局的 tile 内计算 token 级别的注意力,以确保硬件效率。这产生了一个可微分的单个内核,可以端到端训练,无需事后性能分析,并保持 FlashAttention3 85% 的 MFU。我们通过预训练参数范围从 60M 到 1.4B 的 DiTs 模型,进行了大量的消融研究和扩展定律实验。VSA 在不降低扩散损失的情况下,达到了帕累托最优,训练 FLOPS 降低了 2.53 倍。改造开源 Wan-2.1 模型后,注意力时间加速 6 倍,端到端生成时间从 31 秒缩短至 18 秒,同时保持质量相当。这些结果确立了可训练稀疏注意力作为全注意力的实用替代方案,以及进一步扩展视频扩散模型的关键促成因素。