⏶3
用于遗忘Transformer的自适应计算剪枝
04月09日发表
04月16日由
Zhixuan Lin 提交
作者:
Zhixuan Lin, Johan Obando-Ceron,
Xu Owen He, Aaron Courville
摘要
最近提出的遗忘 Transformer (FoX) 将遗忘门合并到 softmax 注意力中,并且与基于标准 RoPE 的 Transformer 相比,表现出始终如一的更好或相当的性能。值得注意的是,FoX 中的许多注意力头倾向于快速遗忘,导致它们在每个时间步的输出主要依赖于局部上下文。基于这一观察,我们提出了 FoX 的自适应计算剪枝 (ACP),这是一种动态剪枝计算的方法,涉及被遗忘门强烈衰减的输入-输出依赖关系。这是通过使用动态设置的剪枝阈值来实现的,该阈值确保剪枝的注意力权重保持可忽略不计。我们将 ACP 应用于使用 FoX 进行的语言模型预训练,并表明它可以始终如一地将 softmax 注意力中的 FLOP 数量减少约 70%,跨不同的模型大小和上下文长度,从而使训练吞吐量大致提高 10% 到 35%。此外,更长的上下文长度会产生更大的计算节省。所有这些速度提升都是在没有任何性能下降的情况下实现的。我们还进行了多项分析,以更深入地了解我们的方法,例如检查剪枝模式和分析不同注意力头之间的 FLOP 节省分布。我们的代码可在 https://github.com/zhixuan-lin/arctic-fox 上找到。
此方法旨在加速先前提出的 遗忘转换器 (FoX),而不会降低任何性能。FoX 向转换器添加了一个遗忘门,由此产生的注意力机制也可以看作是 ALiBi 的数据相关且可学习的版本,如下所示:
自适应计算剪枝 (ACP) 的核心思想很简单:我们不需要在我们忘记的事情上浪费计算。具体来说,如果 $D{ij}$ 远低于零(例如,-1000),那么项 $\exp(qi^\top kj + D{ij})$ 在归一化后很可能为零,因此可以剪枝此项中涉及的任何计算。由于矩阵 $D$ 的特殊结构,这可以通过识别 FlashAttention 计算网格上的剪枝边界并仅在剪枝边界的右侧执行计算来完成:
结果摘要:
在这项工作中,我们专注于预训练,尽管原则上它也可以在推理期间使用(即,预填充和解码)
ACP 持续剪枝大约 70% 的注意力 FLOP,从而使训练吞吐量提高大约 10%-35%,具体取决于模型大小和上下文长度。
所有速度提升都是在没有任何性能下降的情况下实现的。这是因为我们动态设置 $D_{ij}$ 的阈值,以确保剪枝的总注意力权重受到小数字的限制(实际上,受到 $e^{-10} < 0.00005$ 的限制)。
代码: https://github.com/zhixuan-lin/arctic-fox。我们未来会有更多结果发布。敬请期待!