GaLore:通过梯度低秩投影实现内存高效的LLM训练

03月06日发表
04月12日由 AKAK 提交
作者: Jiawei ZhaoJiawei Zhao, Zhenyu ZhangZhenyu Zhang, Beidi ChenBeidi Chen, Zhangyang Wang, Anima AnandkumarAnima Anandkumar, Yuandong TianYuandong Tian

摘要

训练大型语言模型 (LLM) 带来了巨大的内存挑战,这主要是由于权重和优化器状态的规模不断增长。常见的内存减少方法(例如低秩自适应 (LoRA))在每层中向冻结的预训练权重添加可训练的低秩矩阵,从而减少可训练参数和优化器状态。然而,这种方法通常在预训练和微调阶段都表现不佳,因为它们将参数搜索限制在低秩子空间并改变了训练动态,并且进一步地,可能需要全秩热启动。在这项工作中,我们提出了梯度低秩投影 (GaLore),这是一种训练策略,允许全参数学习,但比常见的低秩自适应方法(如 LoRA)更节省内存。我们的方法在优化器状态中最多可减少 65.5% 的内存使用量,同时保持了在 LLaMA 1B 和 7B 架构上使用 C4 数据集进行预训练(最多 19.7B 个 token)以及在 GLUE 任务上微调 RoBERTa 的效率和性能。我们的 8 位 GaLore 进一步将优化器内存减少高达 82.5%,并将总训练内存减少 63.3%,与 BF16 基线相比。值得注意的是,我们首次证明了在具有 24GB 内存的消费级 GPU(例如,NVIDIA RTX 4090)上预训练 7B 模型的可行性,而无需模型并行、检查点或卸载策略。

评论

Mobin ChowdhuryMobin Chowdhury

我们需要官方 github 代码 pls 和 hf 集成... 多么酷的项目

Maria TangMaria Tang

我也希望看到尽可能多的源代码,非常感谢

Alok ShuklaAlok Shukla
此评论已隐藏。
Lee GaoLee Gao

一些想法/点子,我不知道它们是否有道理:

  1. r 可以是一个超参数,是否可以将其设置为奇异值的阈值?甚至可以使用随机矩阵理论来找到信号噪声的频谱阈值?

  2. T 可以是一个超参数,是否可以测量 Pt^T Gt Q_t 有多“对角线”?我相信论文中的直觉是,我们希望定期“刷新”与全秩支持相对应的主方向,以防它们随时间漂移。据我理解,投影梯度最初只是奇异值的对角线,并且它会随着时间的推移而偏离该结构(我做了一个很大的假设,即这种漂移是渐进的,并且与 P,Q 作为梯度更新的主方向仍然有多好成反比)。似乎您可以某种程度上量化该漂移,并使用它来驱动 P,Q 是否仍然是梯度更新的良好主方向。

Lee GaoLee Gao

对于图 1。

6evoFc9NWrrRCjH.png

您能否也包括 8bit-Adam + 每层权重更新但在梯度更新中不进行秩缩减会对内存使用产生什么影响?根据 Lomo 论文 / https://arxiv.org/abs/2306.09782,似乎它也会显着减少内存使用的浅绿色部分,因为梯度在每层立即被消耗+丢弃?

Yuandong TianYuandong Tian
论文作者

感谢您的评论!我们在这里有第三方评估:https://github.com/jiaweizzhao/GaLore/issues/6。仅 GaLore(不进行每层权重更新)就具有与每层权重更新相当的内存减少量。它们是正交技术。通过将它们组合在一起,您可以在 24G 内存(例如 4090)内运行 7B 预训练。

Lee GaoLee Gao

谢谢!这会是表 8 中的数字吗?

南栖南栖

非常强大的技术。

Derek ThomasDerek Thomas

令人难以置信的论文!我很高兴看到它随着时间的推移如何发展。我已经成为 LoRA 小更新足迹的粉丝,尤其是在服务方面。但对于某些用例,我可以看到希望获得更高的性能。

我也很想看到:

  • 跨各种任务/指标的下游任务性能

  • 常见用例的内存场景。我从 GaLore 与 LoRA 或其他方法中获得了多少好处,或者它们是否都非常相似。

Julien BLANCHONJulien BLANCHON
GaLore:使用内存高效的梯度投影彻底改变 LLM 训练

https://cdn-uploads.huggingface.co/production/uploads/6186ddf6a7717cb375090c01/jYTt_MsOyiFW-R1qi1dTE.mp4

链接 🔗:

👉 订阅: https://www.youtube.com/@Arxflix

👉 Twitter: https://x.com/arxflix

👉 LMNT (合作伙伴): https://lmnt.com/

作者:Arxflix

9t4iCUHx_400x400-1.jpg

Pedro SandovalPedro Sandoval

希望澄清图 1:这里的批量大小、序列长度和词汇表大小是多少?因为我预计激活会占用更多空间...

  • 根据图 1 的标题,批量大小似乎为 256

  • 根据脚注 1,序列长度似乎为 2048

  • 根据 config from repo,词汇表大小为 32000

  • 根据脚注 2,使用 bf16,因此每个浮点数为 2 个字节

因此,只有模型的 logits 应该占用 256 * 2048 * 32000 * 2 字节或 31.25 GB。图 1 中哪里需要这个内存?

谢谢!

Soeren Moeller ChristensenSoeren Moeller Christensen

似乎可以将 Lora 和 galore(以及它们的量化对应物 qlora 和 qgalore)结合起来,通过在 lora 矩阵 A 和 B 的梯度上使用 galore 来进一步减少内存占用。有人尝试过对此进行实验吗?我在论文中找不到相关信息,因为他们主要将他们的工作视为 lora 的完全替代方案。