本文是由Anshu Mann、Mohd Abbas Zaidi、Akhil Kedia、Jinwoo Ahn、Taehwak Kwon、Kangwook Lee、Haejun Lee、Joohyung Lee(均来自三星研究院首尔研究所)共同完成的研究论文。该论文发表在2025年7月27日至8月1日举行的第63届计算语言学协会年会(ACL)的会议论文集中(第1卷:长论文)。论文题为《稀疏Logit采样:加速LLM中的知识蒸馏》。
该研究隶属于自然语言处理与人工智能领域,具体关注大型语言模型的高效训练技术。研究的核心动机是解决知识蒸馏在LLM预训练阶段应用时的巨大存储成本问题。在知识蒸馏中,通常需要将从大型教师模型(teacher model)输出的完整概率分布(logits)传授给小型学生模型(student model)。然而,对于现代LLM巨大的词表规模(例如数万个token),为海量训练数据(如万亿token)存储完整的教师输出分布所需的存储空间是天文数字(论文中举例,为1万亿token存储LLaMA-3模型的完整分布需要128PB),这使得传统的离线知识蒸馏(即预先计算并缓存教师logits)在预训练阶段变得不可行。
此前,一些稀疏知识蒸馏方法被提出,例如仅缓存教师概率分布中概率最高的前k个token(top-k KD)。然而,这些方法在实践中常导致学生模型性能下降和校准变差。本研究的背景知识在于,作者通过理论证明与实证分析,揭示了top-k方法性能不佳的根本原因:首先,top-k方法是对教师概率分布的有偏估计,它会放大前k个token的概率,导致学生模型过度自信;其次,它完全丢弃了教师分布“长尾”部分的信息,而这些信息对于模型学习至关重要。基于此,本研究旨在提出一种新的稀疏知识蒸馏方法,能够在保证性能接近完整蒸馏的前提下,将需要存储的logit数量降到极低水平,从而显著降低存储开销并加速训练。
本研究的工作流程严谨且详细,包含问题分析、方法提出、实验验证与结果分析等多个环节。
首先,研究团队对现有稀疏知识蒸馏方法(特别是top-k KD)进行了深入的理论分析与初步实验验证。他们通过理论推导证明了在使用KL散度损失时,top-k KD会导致学生模型的梯度更新公式发生改变,使得非top-k token的概率被强制推向零,而top-k token的概率则学习到一个按比例放大的教师目标,这直接导致了学生模型的预测过度自信和校准误差。为了直观展示这一偏差,研究团队构造了服从齐夫分布的合成玩具分布,并进行了可视化(图2a)。同时,他们在一个简单的合成分类任务和CIFAR-100图像分类任务上训练模型,均观察到top-k KD导致模型过度自信,而完整蒸馏和交叉熵损失训练则能保持良好校准(图2b, 2c)。在LLM预训练实验中,他们训练了300M参数的Llama风格学生模型,使用一个3B参数的教师模型。结果(表1)明确显示,仅使用少量(<25个)top-k token时,学生模型的损失甚至比仅用交叉熵损失(CE)训练更差;即使使用到300个token,性能也仅能达到完整蒸馏的77%。并且,模型校准误差随着k值的减小而显著恶化。这验证了top-k方法存在的两个根本问题:有偏估计和尾部信息缺失。
其次,研究团队尝试了几种针对top-k KD缺陷的经验性解决方案,并评估了其效果。这些方法包括:1) 标签平滑:将残余概率(1减去top-k概率和)均匀分配给所有词汇。这种方法虽然改善了校准,但严重损害了性能(表2),因为现实中的token分布远非均匀。2) 虚拟token:创建一个“虚拟token”来吸收非top-k token的总概率。这种方法改善了校准和性能,但仍不及完整蒸馏,说明对尾部进行显式监督是必要的。3) 朴素修正:将残余概率分配给真实目标token(ground truth)。这种方法需要存储多达100个token才能达到接近完整蒸馏的性能(表2)。这些尝试表明,上述方法要么是有偏估计量,要么缺乏足够的尾部监督,无法完美解决问题。
基于对问题的深刻理解和重要性采样理论的启发,研究团队提出了名为“随机采样知识蒸馏”的新方法。该方法的动机是利用重要性采样从教师分布中获取无偏估计。具体工作流程如下:对于每个训练token,给定教师输出的完整概率分布 t_full。传统top-k是截断分布,而新方法是从该分布中随机采样。他们设定一个提案分布 q,形式为 t_full^τ(τ为采样温度)。通过从q中抽取固定轮数(n)的样本(允许重复),每个被采样token i会被赋予一个似然比权重 t_i / q_i。在τ=1的简化情况下,q就等于原分布,似然比为1。最终,某个token i在子采样后的目标概率 t_s[i] 被计算为其在n次采样中出现次数(c_i)除以n(即 c_i / n)。这个分布是高度稀疏的,非零项的数量最多为n,实践中远少于n。随后,使用前向KL散度损失在非零的 t_s 和学生模型的预测分布 p 之间进行计算。当τ=1时,该损失也可以被视作每个被采样token与学生预测之间的交叉熵损失之和。这个子采样后的教师分布 t_s 可以被缓存到磁盘,并在多次实验中重复使用。
在提出的新方法之后,研究团队对其进行了全面的分析。他们通过理论证明了随机采样KD提供了对教师概率分布的无偏估计,并且在期望上保持了与完整KD相同的梯度。在合成任务和LLM预训练中,该方法都实现了几乎完美的校准,与完整KD相当(图2b, 2c, 3a)。他们还通过实验测量了梯度相似性(表3),发现随机采样方法(仅用12个独特token)产生的梯度与完整KD的梯度在角度和范数上都极为接近(余弦相似度0.998),而top-k方法即使在300个token下仍存在显著差异。研究也探讨了不同提案分布(τ值)的影响,发现τ在0.8到1.2之间时性能稳定且接近最优,因此最终选择τ=1以简化流程。此外,研究比较了训练速度与存储开销(表4, 5)。与完整KD相比,新方法的缓存实现速度快了1.7到2.6倍,而存储开销仅比普通CE训练略高约10%。在存储方面,对于1000亿训练token,完整KD需要10PB,top-300需要90TB,而新方法仅需12个独特token,额外存储仅需3.6TB,比top-300减少了25倍。
随后,研究团队在多组实验中对所提方法进行了系统性验证。评估指标包括预训练数据上的语言建模损失、预期校准误差、在教师模型推测解码中的接受率、零样本自然语言理解分数以及指令微调前后的零样本自然语言生成分数。首先,在小规模设置(3B教师→300M学生,训练100亿token)下,随机采样KD仅使用约12个独特token就达到了与完整KD非常接近的性能(表6),在语言建模损失、校准误差和推测解码接受率上表现相当,甚至在零样本NLU分数上略有超越。当将训练扩展到1600亿token(远超过Chinchilla最优值)时,结果依然保持一致(表7)。其次,在大规模公开设置(Llama-3-8B教师→3B学生,在FineWeb-Edu数据集上训练1000亿token)下,新方法(12个token)同样取得了与完整KD相似的损失、校准和推测解码接受率,并且在零样本下游任务和指令跟随性能上表现更优(表8)。使用LLM-as-a-Judge对生成任务进行评估也显示,新方法在多个基准测试上均优于其他方法(表9)。研究还探究了学生模型大小的影响,在100M到3B的不同学生规模上,随机采样KD相比CE带来的下游任务性能提升随着学生模型增大而持续增加(图4),这与先前某些工作中报告的小模型性能下降现象形成对比。此外,研究证明一些正交的改进技术(如结合CE损失、对“难”token使用更高学习率的自适应训练)可以与随机采样KD结合,进一步提升性能,甚至在某些指标上超越基础版完整KD(表10)。最后,研究团队将新方法与先前工作进行了直接比较(表11),结果显示新方法在性能上显著优于Raman等人和Peng等人提出的稀疏蒸馏方法。
通过以上工作流程,研究得到了清晰且强有力的结论。随机采样知识蒸馏方法有效地解决了传统top-KD方法中的有偏估计和尾部信息缺失问题。该方法基于重要性采样,能够提供对教师分布的无偏估计,在期望上保持正确的梯度,同时仅需存储极稀疏的logits(如12个token)。实验表明,该方法在不同模型规模(300M至3B)、不同训练长度和多种评估指标上,都能在保证训练速度(相比完整KD显著加速)和极低存储开销(仅需存储约0.01%的教师logits)的前提下,实现与完整知识蒸馏相竞争的性能表现,并且学生模型具有良好的校准性。
本研究的科学价值在于,它从理论和实践两个层面深入剖析了稀疏知识蒸馏的关键瓶颈,并提出了一种理论完备、高效实用的解决方案。它不仅证明了在预训练阶段进行高效离线知识蒸馏的可行性,还揭示了教师分布尾部信息的重要性以及无偏估计对于学生模型校准的关键影响。其应用价值巨大,使得研究机构和企业能够以可承受的存储和计算成本,利用大型教师模型预先生成的知识来高效训练一系列不同规模的小型学生模型,极大地促进了大型语言模型的压缩、部署和生态发展。该研究也为后续探索更复杂的采样方案或其他形式的稀疏蒸馏提供了坚实的基础。
本研究的亮点突出。重要发现在于首次通过理论证明并实证验证了top-KD方法导致性能下降和校准不佳的根本机制——有偏估计与尾部监督缺失。方法新颖性体现在创造性且简洁地应用重要性采样原理来解决稀疏知识蒸馏问题,提出的“随机采样KD”方法兼具理论优雅性和实践高效性。研究目标的特殊性在于聚焦于LLM预训练这一更具挑战性且存储成本巨大的场景,而非仅局限于微调或后训练阶段。结果的广泛性则通过涵盖从合成实验、小规模验证到大规模公开数据集的系统评估,全面证明了方法的有效性和鲁棒性。最后,与同期独立工作(如Gemma 3技术报告)的结论相互印证,也增强了该方法的可信度与普适性。论文末尾也坦诚地讨论了研究的局限性,如受计算资源所限未在更大模型和更长时间训练上进行实验,以及因缓存限制未探索表征匹配蒸馏等,为未来研究指明了方向。