学术报告:基于梯度追踪的训练数据影响力评估方法
作者与机构: 本研究的主要作者包括 Garima (谷歌)、Frederick Liu (谷歌)、Satyen Kale (谷歌) 以及 Mukund Sundararajan (谷歌,通讯作者)。 发表信息: 该研究于2020年发表在 *34th Conference on Neural Information Processing Systems (NeurIPS 2020)*。
学术背景: 本研究隶属于机器学习,特别是深度学习模型可解释性与训练数据分析领域。随着深度学习在诸多现实问题中取得成功,模型的性能高度依赖于其训练数据的质量。理解单个训练样本如何影响模型对特定测试样本的预测,对于改善数据集(如识别和修正错误标注样本)、理解模型决策过程以及进行数据审计等至关重要。然而,现有方法(如基于影响函数的方法和基于表示点的方法)在实现复杂性、对模型最优状态的依赖以及可扩展性方面存在局限。因此,本研究旨在提出一种简单、通用且可扩展的方法,来量化训练样本对测试样本预测的影响力。
研究流程详解:
本研究提出了一种名为 TracIn 的新方法,其核心思想是通过追踪梯度下降过程来计算训练样本的影响力。其工作流程主要围绕以下关键步骤和定义展开:
核心理念与理想化定义:
- 研究出发点类比于微积分基本定理,即函数在两点间的变化可以通过其在该区间内路径上的梯度积分来近似。
- 作者首先定义了一个“理想化影响力”:对于一个给定的测试样本
z',一个训练样本 z 对其的理想化影响力,等于在模型的整个训练过程中,每当用到训练样本 z 进行参数更新时,测试样本 z' 的损失值减少量的总和。公式表达为:TracIn_Ideal(z, z') = Σ_{t: z_t=z} [L(w_t, z') - L(w_{t+1}, z')]。其中,t 代表训练迭代步,z_t 是第 t 步使用的训练样本,w_t 是第 t 步开始时的模型参数。 - 该定义具有一个良好性质:所有训练样本对一个测试样本的影响力之和,等于该测试样本损失从训练开始到结束的总体减少量。
一阶近似与实际算法(TracIn):
- 直接计算理想化影响力不切实际,因为它需要存储整个训练过程中每一步的模型状态和样本使用记录。为此,研究提出了一阶近似。
- 利用梯度下降的小步长特性,作者使用一阶泰勒展开来近似单次参数更新导致的损失变化:
L(w_t, z') - L(w_{t+1}, z') ≈ η_t * ∇L(w_t, z') · ∇L(w_t, z_t)。其中,η_t 是第 t 步的学习率,∇L 表示损失函数对模型参数的梯度。 - 基于此,TracIn 影响力被定义为训练样本
z 在所有被使用时对应的上述内积项之和:TracIn(z, z') = Σ_{t: z_t=z} η_t * ∇L(w_t, z') · ∇L(w_t, z)。 - 该方法自然地扩展到小批量(minibatch)训练场景。对于批量中的每个样本,其贡献被平均分配到该批量对测试样本的总影响力上。
面向实际应用的启发式实现(TracInCP):
- 为了使其可扩展至长时间的训练过程,研究提出了基于检查点(checkpoints)的实用启发式算法 TracInCP。
- 标准的训练流程通常会定期保存模型参数(检查点)。TracInCP 假设在两个检查点之间,每个训练样本被恰好访问一次(为近似 Lemma 3.1,但非必需),且学习率恒定。
- 由于无法获得每个样本被访问时的精确参数,TracInCP 使用访问后遇到的第一个检查点的参数进行近似计算。最终,TracInCP 公式为:
TracInCP(z, z') = Σ_{i=1}^{k} η_i * ∇L(w_{t_i}, z) · ∇L(w_{t_i}, z')。其中,k 是检查点数量,w_{t_i} 是第 i 个检查点的参数,η_i 是该检查点期间的学习率。 - 该方法简单且通用,仅需梯度计算、检查点存取和损失函数,不依赖于模型是否达到最优状态,适用于任何使用(随机)梯度下降及其变体训练的模型。
对比实验与评估方法:
- 研究将 TracIn 与两种主流方法进行对比:影响函数 和 表示点选择 方法。研究详细描述了这两种方法的原理、公式和计算复杂性。
- 为了评估不同方法识别“错误标注样本”的有效性,研究采用了自影响力评估法:计算每个训练样本对其自身损失的影响力(即令
z' = z),然后按影响力降序排列训练样本。 - 理论依据是,被错误标注的样本通常是其自身的强“支持者”(proponents,即正影响力),因为它们会降低模型对其(错误)标签的损失。
- 评估指标是:在不同比例的排序前列样本中,识别出的错误标注样本的比例。该比例越高,说明方法越有效。
- 实验在多个数据集上进行:CIFAR-10 (图像分类,使用 ResNet-56)、MNIST (手写数字分类) 以及 ImageNet (大规模图像分类,使用 ResNet-50) 和文本分类任务(DBPedia 数据集)。研究中还涉及了一个房价回归问题(California Housing)。
计算优化与实现技巧:
- 随机投影:为了应对大规模模型(参数众多)带来的高维梯度内积计算开销,研究提出使用随机投影技术。通过将高维梯度投影到低维空间,可以高效地计算其内积的无偏估计,从而显著降低计算和存储成本。
- 全连接层快速计算:对于全连接层,其权重的梯度是一个秩为1的矩阵。研究利用这一特性,推导出计算其梯度内积的快速算法,将复杂度从 O(mn) 降低到 O(m+n),并设计了相应的随机投影方法。
主要结果详解:
错误标注样本识别效果:
- 在 CIFAR-10 和 MNIST 数据集上,TracInCP 在识别错误标注样本的任务上显著优于影响函数和表示点方法。
- 图1 (a) 显示,在 CIFAR-10 上,TracInCP 能够在排名前 20% 的训练数据中识别出超过 80% 的错误标注样本,而其他方法在同一位置只能识别不到 50%。使用 TracIn 识别并修正的数据进行重新训练,能带来比其他方法更大的测试准确率提升(左图)。
- 在 MNIST 数据集上,结果类似(图1b),且近似版的 TracIn(使用每个训练步骤)与启发式的 TracInCP(使用检查点)性能接近。
方法特性验证与分析:
- 一阶近似的有效性:在 MNIST 上,研究者验证了损失的实际变化值与 TracIn 的一阶近似值之间存在高度相关性(Pearson 相关系数 0.978,见图7),证明了一阶近似的准确性。
- 检查点的作用:不同检查点捕捉了训练过程不同阶段的信息。例如,图8 显示,在 CIFAR-10 的早期、中期和后期检查点,识别出的错误标注样本在不同类别上的分布不同,表明模型在不同训练阶段学习不同类别的特征。这表明使用多个检查点比仅用最终模型更重要。研究还发现,选择损失下降较快阶段的检查点比均匀采样的检查点更具信息性(图2)。
- 支持者与反对者的可视化分析:通过可视化图像分类任务中测试样本的强影响力训练样本(支持者和反对者),研究发现 TracIn 倾向于找到与测试样本在像素层面视觉相似的支持者和反对者,而表示点方法有时会找到视觉上不相似的反对者(图3,图5,图6)。这提供了更直观、合理的解释。
在各类任务中的应用展示:
- 回归任务(房价预测):研究者使用加州房价数据集展示了 TracIn 的洞察力。例如,可以为帕洛阿尔托的房屋找到模型层面的“可比物业”,这些支持者来自湾区的其他地区以及萨克拉门托、洛杉矶等城市。高“自影响力”的样本通常来自人口密集区域,暗示模型可能存在“记忆”现象;而低自影响力样本来自稀疏区域,模型更依赖泛化(图10)。
- 文本分类(DBPedia):在 DBPedia 文本分类任务中,TracIn 可以揭示数据中的潜在关联。例如,对于一个政治家(Manuel Azaña)的测试样本,其反对者(opponents)列表中出现了多位艺术家,这揭示了数据中可能存在政治家与艺术家之间的关联模式(表1)。
- 大规模图像分类(ImageNet):通过应用于 ResNet-50 在 ImageNet 上的全连接层,并利用其梯度秩为1的特性进行加速,证明了 TracIn 能够扩展到百万级图像的大规模数据集。图4展示了为不同测试图像找到的支持者和反对者,这些结果有助于理解模型的混淆模式(例如,教堂的反对者中出现城堡,波士顿梗犬的反对者中出现法国斗牛犬)。
结论与价值:
本研究提出的 TracIn 方法是一种用于计算训练样本对特定预测影响力的创新、通用且实用的工具。 * 科学价值:它从全新的角度(通过追踪整个训练过程的梯度变化)来形式化并量化了训练数据的影响力,无需依赖模型的最优性假设,在概念上与现有方法形成互补。 * 应用价值: 1. 数据质量改进:高效识别错误标注样本,指导数据清洗。 2. 模型可解释性:提供针对单个预测的解释,揭示是哪些训练样本(及其属性)促成了特定预测,帮助理解模型行为。 3. 调试与分析:识别可能引入偏差或导致混淆的训练数据模式(如加州房价中的地理分布模式、文本数据中的跨类别关联)。 4. 潜在扩展应用:如用于主动学习,从少量困难样本扩展出更多相似难例,以增强模型鲁棒性。 * 核心主张:TracIn 的优势在于其简单性(仅需梯度、检查点和损失函数)、通用性(适用于任何基于梯度下降训练的模型,与架构、领域、任务无关)和实用性(易于实现且可扩展)。
研究亮点:
- 创新性方法:首次提出通过“梯度下降轨迹”来定义和计算训练数据影响力的系统框架。
- 显著优于基线:在错误标注识别等核心任务上,性能远超已有的影响函数和表示点方法。
- 实用性强:提出的 TracInCP 启发式算法,利用训练中常见的检查点,使得该方法在实际大规模训练场景中变得可行。
- 广泛的验证:在图像分类(CIFAR-10, MNIST, ImageNet)、文本分类(DBPedia)和回归(房价预测)等多种任务上验证了方法的有效性和洞察力。
- 可扩展技术:为解决大规模计算问题,提出了随机投影和针对全连接层的快速梯度内积算法,确保了方法的工程可行性。
其他有价值的补充: 研究在附录中提供了低延迟实现的思路(利用近似最近邻搜索库),讨论了如何智能选择信息量更大的检查点(而非均匀采样),并对影响函数和表示点方法的公式、复杂性及局限性进行了详细的补充说明。研究还承认,如同任何统计方法一样,TracIn 需要结合人类判断来正确应用,例如选择合适的检查点、网络层,并合理解释其输出结果,将其置于更广泛的分析背景中。