用于语言模型的模型链学习

发表
Kaitao SongKaitao Song 提交
作者: Kaitao SongKaitao Song, Xiaohua Wang, Xu TanXu Tan, Huiqiang JiangHuiqiang Jiang, ZhangChengruidong Zhang, Yongliang ShenYongliang Shen, CenCen LU, Zihao LiZihao Li, Zifan Song, SHAN CAIHUACaihua Shan, Yansen WangYansen Wang, Kan Ren, zheng xiaoqingXiaoqing Zheng, Tao Qin, Yuqing Yang, Dongsheng Li, Lili Qiu

摘要

在本文中,我们提出了一种新颖的学习范式,称为 Chain-of-Model (CoM),它将因果关系以链式结构融入到每一层的隐藏状态中,从而在模型训练中引入了极大的扩展效率,并在部署中提供了推理灵活性。我们引入了 Chain-of-Representation (CoR) 的概念,它将每一层的隐藏状态表示为隐藏维度级别的多个子表示(即链)的组合。在每一层中,输出表示中的每个链只能看到输入表示中其之前的所有链。因此,基于 CoM 框架构建的模型可以通过基于先前的模型(即链)增加链的数量来逐步扩展模型大小,并通过使用不同的链数量提供不同大小的多个子模型用于弹性推理。基于这一原则,我们设计了 Chain-of-Language-Model (CoLM),它将 CoM 的思想融入到 Transformer 架构的每一层中。基于 CoLM,我们通过引入 KV 共享机制进一步推出了 CoLM-Air,该机制在第一个链中计算所有键和值,然后跨所有链共享。这种设计展示了额外的可扩展性,例如支持无缝语言模型切换、预填充加速等。实验结果表明,我们的 CoLM 系列可以实现与标准 Transformer 相当的性能,同时提供了更大的灵活性,例如通过渐进式扩展提高训练效率,以及提供多种不同大小的模型用于弹性推理,为构建语言模型开辟了一条新途径。我们的代码将来会在 https://github.com/microsoft/CoLM 发布。
查看 arXiv 页面查看 PDF

评论

Kaitao SongKaitao Song
论文作者
论文提交者

在本文中,我们提出了一种新颖的学习范式,称为模型链(Chain-of-Model,CoM),它以链式风格将因果关系融入到每一层的隐藏状态中,从而在模型训练中带来了巨大的缩放效率,并在部署中提供了灵活的推理能力。我们引入了表示链(Chain-of-Representation,CoR)的概念,它将每一层的隐藏状态表示为隐藏维度层面上多个子表示(即链)的组合。在每一层中,输出表示的每个链只能看到输入表示中它之前的所有链。因此,基于 CoM 框架构建的模型可以通过基于先前的模型(即链)增加链的数量来逐步扩大模型规模,并通过使用不同数量的链来提供多种不同大小的子模型以进行弹性推理。基于这一原则,我们设计了语言模型链(Chain-of-Language-Model,CoLM),它将 CoM 的思想融入到 Transformer 架构的每一层中。在 CoLM 的基础上,我们进一步通过引入 KV 共享机制提出了 CoLM-Air,该机制在第一个链中计算所有键和值,然后将其共享给所有链。这种设计展示了额外的可扩展性,例如支持无缝的语言模型切换、预填充加速等。实验结果表明,我们的 CoLM 系列模型可以达到与标准 Transformer 相当的性能,同时提供了更大的灵活性,例如通过渐进式缩放来提高训练效率,并提供多种不同大小的模型以进行弹性推理,为构建语言模型开辟了新的途径。我们的代码未来将在以下地址发布:https://github.com/microsoft/CoLM。

YJYJ

移动学习音频概述:https://youtu.be/YO0Cxeclywg

ChatGPT Image May 20, 2025, 10_16_25 AM.png

FBLFBL

如何连接两个不同隐藏层尺寸?不同模型通常有不同的维度大小……连接到同一个,我不太确定我们正在看什么

Kaitao SongKaitao Song
论文作者
论文提交者

实际上,你可以选择不同的链大小。正如第3.1节所述,我们引入超参数 C = {c1, c2, ..., cn} 来确定每个链的大小,因此每个链的大小由 ci / sum(C) * D 计算得出。在我们的实验中,由于资源限制,我们选择相同的链大小是为了稍微提高训练效率。使用相同的链允许我们在某些算子(例如,归一化)中使用一些组操作,从而略微提高训练效率。此外,你可以通过附录中的算法1(简单的PyTorch实现)来确定不同的大小。但由于更多的数据访问和 all-reduce 操作,它会比较慢。因此,我们使用 Triton 设计了一个块大小稀疏核。因此,我们期望每个链的大小应该是块大小的倍数,其中 block_size 是2的幂(至少64或更大)。

Kaitao SongKaitao Song
论文作者
论文提交者

正如表10和表11所示,我们也尝试了 chain = {8, 8, 16} 和 {8, 24},这意味着链大小分别为 {512, 512, 1024} 和 {512, 1536}。

Vadim KataevVadim Kataev

您尝试过展开 W 并使用密集 GEMM 吗?这会浪费一些空间,但可能会更快。

Kaitao SongKaitao Song
论文作者
论文提交者

> 您尝试过展开 W 并使用密集 GEMM 吗?这会浪费一些空间,但可能会更快。

感谢您的建议。我们已经考虑过,但令人尴尬的情况是我们的 GPU 只有 40GB A100,我不得不应用许多内存高效技术。但我们的 Triton 实现与附录中表 15 所示的标准 MLP 相比也稍快一些。但我承认它可能还会更好。