用于遗忘Transformer的自适应计算剪枝

04月09日发表
04月16日由 Zhixuan LinZhixuan Lin 提交
作者: Zhixuan LinZhixuan Lin, Johan Obando-Ceron, Owen HeXu Owen He, Aaron Courville

摘要

最近提出的遗忘 Transformer (FoX) 将遗忘门合并到 softmax 注意力中,并且与基于标准 RoPE 的 Transformer 相比,表现出始终如一的更好或相当的性能。值得注意的是,FoX 中的许多注意力头倾向于快速遗忘,导致它们在每个时间步的输出主要依赖于局部上下文。基于这一观察,我们提出了 FoX 的自适应计算剪枝 (ACP),这是一种动态剪枝计算的方法,涉及被遗忘门强烈衰减的输入-输出依赖关系。这是通过使用动态设置的剪枝阈值来实现的,该阈值确保剪枝的注意力权重保持可忽略不计。我们将 ACP 应用于使用 FoX 进行的语言模型预训练,并表明它可以始终如一地将 softmax 注意力中的 FLOP 数量减少约 70%,跨不同的模型大小和上下文长度,从而使训练吞吐量大致提高 10% 到 35%。此外,更长的上下文长度会产生更大的计算节省。所有这些速度提升都是在没有任何性能下降的情况下实现的。我们还进行了多项分析,以更深入地了解我们的方法,例如检查剪枝模式和分析不同注意力头之间的 FLOP 节省分布。我们的代码可在 https://github.com/zhixuan-lin/arctic-fox 上找到。
查看 arXiv 页面查看 PDF

评论

Zhixuan LinZhixuan Lin
论文作者
论文提交者

此方法旨在加速先前提出的 遗忘转换器 (FoX),而不会降低任何性能。FoX 向转换器添加了一个遗忘门,由此产生的注意力机制也可以看作是 ALiBi 的数据相关且可学习的版本,如下所示:

Screenshot 2025-04-15 at 22.25.06.png

自适应计算剪枝 (ACP) 的核心思想很简单:我们不需要在我们忘记的事情上浪费计算。具体来说,如果 $D{ij}$ 远低于零(例如,-1000),那么项 $\exp(qi^\top kj + D{ij})$ 在归一化后很可能为零,因此可以剪枝此项中涉及的任何计算。由于矩阵 $D$ 的特殊结构,这可以通过识别 FlashAttention 计算网格上的剪枝边界并仅在剪枝边界的右侧执行计算来完成:

acp-graph-full.png

结果摘要:

  • 在这项工作中,我们专注于预训练,尽管原则上它也可以在推理期间使用(即,预填充和解码)

  • ACP 持续剪枝大约 70% 的注意力 FLOP,从而使训练吞吐量提高大约 10%-35%,具体取决于模型大小和上下文长度。

  • 所有速度提升都是在没有任何性能下降的情况下实现的。这是因为我们动态设置 $D_{ij}$ 的阈值,以确保剪枝的总注意力权重受到小数字的限制(实际上,受到 $e^{-10} < 0.00005$ 的限制)。

代码: https://github.com/zhixuan-lin/arctic-fox。我们未来会有更多结果发布。敬请期待!