Press "Enter" to skip to content

TracIn:一种估计训练数据影响的简单方法

机器学习 (ML) 模型训练数据的质量可能对其性能产生重大影响。数据质量的一种衡量标准是影响的概念,即给定训练示例影响模型及其预测性能的程度。虽然影响力对于机器学习研究人员来说是一个众所周知的概念,但深度学习模型背后的复杂性,加上其不断增长的规模、特征和数据集,使得影响力的量化变得困难。

 

最近提出了一些方法来量化影响。有些依赖于在删除一个或多个数据点的情况下重新训练时准确度的变化,有些使用已建立的统计方法,例如估计扰动输入点影响的影响函数或将预测分解为训练示例重要性加权组合的表示方法. 还有其他方法需要使用额外的估计器,例如使用强化学习的数据评估。尽管这些方法在理论上是合理的,但它们在产品中的使用受到了大规模运行它们所需的资源或它们给培训带来的额外负担的限制。

 

在NeurIPS 2020 上作为焦点论文发表的“通过跟踪梯度下降估计训练数据影响”中,我们提出了TracIn,这是一种简单的可扩展方法来应对这一挑战。TracIn 背后的想法很简单——跟踪训练过程以在访问单个训练示例时捕获预测的变化。TracIn 可有效地从各种数据集中查找错误标记的示例和异常值,并且通过为每个训练示例分配影响分数,可用于根据训练示例(而不是特征)解释预测。

 

TracIn

 

深度学习算法的基本思想通常使用称为随机梯度下降(SGD)的算法或其变体进行训练。SGD 通过对数据进行多次传递并对模型参数进行修改,从而在每次传递中局部减少损失(即模型的目标)。下图中的图像分类任务演示了一个示例,其中模型的任务是预测左侧测试图像的主题(“西葫芦”)。随着模型在训练中取得进展,它会接触到影响测试图像损失的各种训练示例,其中损失 是预测分数和实际标签的函数——西葫芦的预测分数越高,损失越低。

 

 

假设在训练时已知测试样例,并且训练过程一次访问每个训练样例。在训练期间,访问特定的训练示例会更改模型的参数,然后该更改将修改测试示例的预测/损失。如果可以在整个过程中跟踪训练示例,那幺测试示例的损失或预测的变化可以归因于所讨论的训练示例,其中训练示例的影响将是对训练示例的访问的累积归因.

 

有两种类型的相关训练示例。那些减少损失的,如上面的西葫芦的形象,被称为支持者,而那些增加损失的,如安全带的形象,被称为反对者。在上面的示例中,标记为“太阳镜”的图像也是一个支持者,因为它在图像中带有安全带,但被标记为“太阳镜”,从而促使模型更好地区分西葫芦和安全带。

 

在实践中,测试示例在训练时是未知的,这个限制可以通过使用学习算法输出的检查点作为训练过程的草图来克服。另一个挑战是学习算法通常一次访问多个点,而不是单独访问,这需要一种方法来解开每个训练示例的相对贡献。这可以通过应用逐点损失梯度来完成。这两种策略一起捕获了 TracIn 方法,该方法可以简化为测试和训练示例的损失梯度的点积的简单形式,由学习率加权,并跨检查点求和。

 

 

或者,可以改为检查对预测分数的影响,如果测试示例没有标签,这将很有用。这种形式只需要用预测梯度替换测试示例中的损失梯度。

 

计算顶级影响示例

 

我们首先通过计算一些训练数据的损失梯度向量和特定分类的测试示例(变色龙的图像)来说明 TracIn 的效用,然后利用标准的k-最近邻库来检索顶部支持者和反对者。顶级对手说明变色龙的混血能力!为了进行比较,我们还展示了k 个最近邻,以及来自倒数第二层的嵌入。支持者是不仅相似而且属于同一类的图像,反对者是相似的图像但属于不同的类。请注意,没有明确规定支持者或反对者是否属于同一类。

 

 

聚类

 

由 TracIn 给出的将测试样例的损失简单分解为训练样例的影响也表明,任何基于梯度下降的神经模型的损失(或预测)都可以表示为梯度空间中相似性的总和。最近的工作表明,这种函数形式类似于内核的函数形式,这意味着这里描述的这种梯度相似性可以应用于其他相似性任务,如聚类。

 

在这种情况下,TracIn可以是作为所用的相似功能一个内聚类算法。为了限制相似性度量以便将其转换为距离度量(1 – 相似性),我们将梯度向量归一化为具有单位范数。下面,我们在西葫芦图像上应用 TracIn 聚类以获得更精细的聚类。

 

 

识别具有自我影响的异常值

 

最后,我们还可以使用 TracIn 来识别表现出高度自我影响的异常值,即训练点对其自身预测的影响。当示例被错误标记或罕见时会发生这种情况,这两种情况都会使模型难以对示例进行泛化。以下是一些自我影响力高的例子。

 

 

 

应用程序

 

除了使用 SGD(或相关变体)进行训练外,没有其他要求,TracIn 与任务无关,适用于各种模型。例如,我们使用 TracIn 研究了深度学习模型的训练数据,该模型用于解析对 Google 智能助理的查询,即“将我的闹钟设置为早上 7 点”类型的查询。我们很感兴趣地看到查询“禁用我的警报”的最大对手是“禁用我的计时器”,同时设备上的警报也处于活动状态。这表明助理用户经常互换“计时器”和“闹钟”这两个词。TracIn 帮助我们解释了 Assistant 数据。

 

更多的例子可以在论文中找到,包括一个结构化数据的回归任务和一些文本分类任务。

 

结论

 

TracIn 是一种简单、易于实施、可扩展的方法,用于计算训练数据示例对单个预测的影响或查找稀有和错误标记的训练示例。对于该方法的实现参考,您可以在论文中链接的 github 中找到指向图像代码示例的链接。

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注