对角批处理解锁循环记忆 Transformer 在长上下文中的并行性

发表
Yury KuratovYury Kuratov 提交
作者: Sivtsov DanilDanil Sivtsov, Ivan Rodkin, Gleb Kuzmin, Yury KuratovYuri Kuratov, Ivan Oseledets

摘要

Transformer模型由于其二次时间复杂度和线性内存复杂度,在长上下文推理方面面临挑战。循环记忆Transformer(RMT)通过将渐近成本降低到线性时间和常数内存使用量,提供了一种解决方案。然而,它们的内存更新机制导致顺序执行,从而造成性能瓶颈。我们引入了对角批处理(Diagonal Batching),这是一种调度方案,它在RMT中实现了跨段并行化,同时保持了精确的循环。这种方法消除了顺序约束,即使对于单个长上下文输入,无需复杂的批处理和流水线技术也能实现高效的GPU推理。因为该技术纯粹是一种运行时计算重排序,所以现有的RMT模型无需重新训练即可采用它。应用于LLaMA-1B ARMT模型,对角批处理在131,072-token序列上比标准的全注意力LLaMA-1B实现了3.3倍的加速,比顺序RMT实现加速了1.8倍。通过消除顺序瓶颈,对角批处理降低了推理成本和延迟,从而使RMT成为现实世界长上下文应用的实用解决方案。
查看 arXiv 页面查看 PDF

评论

Yury KuratovYury Kuratov
论文作者
论文提交者

GitHub:https://github.com/svtdanny/diagonal-batching

Sivtsov DanilSivtsov Danil
论文作者

Sivtsov DanilSivtsov Danil
论文作者

左图:

并行RMT推广了一系列具有层级记忆的模型。每个层维护自己的记忆状态,并将其横向传递给下一个段中的相同层。这消除了层间记忆流,但仍需要在每个层内按顺序处理段,从而创建了层级循环。

右图:

对角批处理将层(行)和段(列)的2D网格重新排列成独立的“对角线”(同色块)。这允许在一个对角线上的所有操作(最多N_Layers)在GPU上并行执行,从而消除了顺序瓶颈,同时保留了所有层级循环。