⏶7
FLARE:快速低秩注意力路由引擎
发表
由
Vedant Puri 提交

作者:
Vedant Puri, Aditya Joglekar, Kevin Ferguson, Yu-hsuan Chen, Yongjie Jessica Zhang, Levent Burak Kara

摘要
自注意力机制的二次复杂度限制了其在大型非结构化网格上的应用和可扩展性。我们介绍了快速低秩注意力路由引擎(FLARE),一种线性复杂度的自注意力机制,它通过固定长度的潜在序列进行注意力路由。每个注意力头通过使用可学习的查询令牌将输入序列投影到固定长度的 M (M ≪ N) 个令牌的潜在序列上,从而在 N 个令牌之间执行全局通信。通过将注意力路由到一个瓶颈序列,FLARE 学习了一种低秩的注意力形式,其成本为 O(NM)。FLARE 不仅可以扩展到前所未有的问题规模,而且在各种基准测试中,其精度优于最先进的神经偏微分程 (PDE) 代理模型。我们还发布了一个新的增材制造数据集,以促进进一步的研究。我们的代码可在 https://github.com/vpuri3/FLARE.py 上获取。


FLARE 是一种新颖的 token 混合层,通过利用低秩性来规避自注意力机制的二次成本。FLARE 的论点是,将输入序列投影到更短的潜在序列,然后再反向投影到原始序列长度,等同于构建一个秩至多等于潜在 token 数量的低秩注意力形式(见下图)。
此外,我们认为多个同时进行的低秩投影可以共同捕捉完整的注意力模式。我们的设计为每个 head 分配了不同的潜在 token 切片,从而为每个 head 产生了不同的投影矩阵。这使得每个 head 能够学习独立的注意力关系,开辟了一个关键的扩展和探索方向,其中每个 head 都可以专注于不同的路由模式。
FLARE 完全由标准的融合注意力原语构建,确保了高 GPU 利用率和易于集成到现有 Transformer 架构中。通过用低秩投影和重构替换完整的自注意力,FLARE 在点数量上实现了线性复杂度(见下图)。因此,FLARE 使得在单 GPU 上对具有一百万个点的非结构化网格进行端到端训练成为可能——这是基于 Transformer 的 PDE 代理模型所能达到的最大规模。