通过组合实现分布外泛化:基于Transformer中归纳头的研究

大语言模型中的分布外泛化与组合机制研究

论文背景

近年来,大语言模型(Large Language Models, LLMs)如 GPT-4 在处理新颖任务时表现出惊人的创造力,通常只需少数示例即可解决问题。这些任务要求模型在不同于训练数据的分布上进行泛化,即所谓的“分布外泛化”(Out-of-Distribution Generalization, OOD Generalization)。尽管 LLMs 取得了巨大成功,但其如何实现分布外泛化仍是一个未解之谜。本文旨在通过研究 LLMs 在隐藏规则生成的任务中的表现,特别是通过聚焦于一种称为“归纳头”(Induction Heads, IHs)的组件,揭示分布外泛化与组合机制之间的关系。

本文的研究主要针对 LLMs 在符号推理等任务中的表现,探究其如何在不进行微调的情况下推断出输入提示背后的隐藏规则。通过对训练动态的实证研究,作者发现 LLMs 通过组合两个自注意力层来学习规则,从而实现分布外泛化。此外,作者提出了“共同桥接表示假设”(Common Bridge Representation Hypothesis),认为嵌入(或特征)空间中的共享潜在子空间通过对齐早期和后期层,充当了组合的桥梁。

论文来源

本文由 Jiajun Song、Zhuoyan Xu 和 Yiqiao Zhong 共同撰写,来自北京通用人工智能研究院和威斯康星大学麦迪逊分校。论文于 2025 年 2 月 7 日发表在 PNAS(Proceedings of the National Academy of Sciences) 上,标题为“Out-of-Distribution Generalization via Composition: A Lens through Induction Heads in Transformers”。

研究流程与结果

研究流程

  1. 合成任务实验
    作者首先在一个合成任务上进行了实验,即“复制序列”任务。给定一个包含重复模式的序列(如 [a], [b], [c]),模型需要在接收 [a], [b] 后预测下一个令牌为 [c]。实验使用了两层 Transformer 模型,训练过程中采用了标准的自注意力机制和残差连接。

  2. 训练动态分析
    在训练过程中,作者观察到了两个阶段:弱学习阶段和规则学习阶段。在弱学习阶段,模型仅学习到了输入序列的简单统计特征,无法在分布外数据上泛化。而在规则学习阶段,模型学会了复制规则,并在分布内外数据上都表现良好。

  3. 归纳头的作用
    通过分析训练动态,作者发现归纳头(IHs)在分布外泛化中起到了关键作用。IHs 是一种自注意力头,能够在输入序列中识别重复模式,并预测下一个令牌。实验表明,模型通过组合两个自注意力层,分别处理位置信息和令牌信息,从而实现分布外泛化。

  4. 共同桥接表示假设
    作者进一步提出了共同桥接表示假设,认为多层多头模型中的潜在子空间充当了组合的桥梁。通过对齐早期层和后期层的子空间,模型能够在分布外数据上实现泛化。

  5. 大规模语言模型实验
    为了验证上述假设,作者在多种预训练 LLMs 上进行了广泛的实验,包括 LLaMA、Mistral 和 Falcon 等模型。实验结果表明,IHs 在符号推理、数学推理等任务中起到了关键作用,特别是在分布外数据上。

研究结果

  1. 合成任务结果
    在合成任务中,两层 Transformer 模型表现出了分布外泛化的能力,而单层模型仅实现了弱学习。实验数据表明,模型在规则学习阶段的泛化能力显著提升,特别是在处理较长的重复序列时。

  2. IHs 的实验结果
    在不同任务的实验中,移除 IHs 显著降低了模型在分布外数据上的表现。例如,在符号推理任务中,移除 IHs 后模型的准确率从接近 90% 下降到 50% 以下。

  3. 共同桥接表示假设的验证
    实验结果表明,IHs 和前期注意力头(Previous-Token Heads, PTHs)共享一个潜在子空间。通过对齐这些子空间,模型能够在分布外数据上实现泛化。这一假设通过权重矩阵的投影实验得到了进一步验证。

结论与意义

本文的主要结论是,LLMs 通过组合机制实现分布外泛化,而 IHs 和 PTHs 在组合过程中起到了关键作用。共同桥接表示假设为理解 LLMs 如何学习规则并在新颖任务中实现泛化提供了新的视角。

科学价值

  1. 揭示了泛化机制
    本文通过实证研究揭示了 LLMs 如何在分布外数据上实现泛化,填补了这一领域的研究空白。

  2. 提出了新假设
    共同桥接表示假设为理解 LLMs 的组合机制提供了新的理论框架,有助于进一步研究模型的内部结构。

  3. 应用价值
    本文的研究成果可以为改进 LLMs 的训练方法和模型设计提供指导,特别是如何提升模型在新颖任务中的表现。

研究亮点

  1. 新颖的研究视角
    本文通过聚焦于 IHs,揭示了 LLMs 分布外泛化的内部机制,这是此前研究较少涉及的领域。

  2. 广泛的实验验证
    本文不仅在合成任务上进行了实验,还在多种大规模 LLMs 上进行了广泛的验证,增强了结论的普适性。

  3. 理论创新
    共同桥接表示假设为理解 LLMs 如何通过组合机制实现泛化提供了新的理论视角,具有重要的学术价值。

其他有价值的信息

本文的代码和数据已在 GitHub 上开源,链接为:https://github.com/jiajunsong629/ood-generalization-via-composition。这为其他研究者复现和扩展本文的研究提供了便利。

总结

本文通过对 LLMs 中分布外泛化机制的深入研究,揭示了组合机制在模型学习规则和实现泛化中的关键作用。这不仅深化了我们对 LLMs 内部结构的理解,也为未来的模型设计和优化提供了重要的理论支持。