Press "Enter" to skip to content

理解深度学习泛化的新视角

理解泛化是深度学习中尚未解决的基本问题之一。为什幺在有限的训练数据集上优化模型会在保留的测试集上获得良好的性能?这个问题在机器学习中得到了广泛的研究,其悠久的历史可以追溯到 50 多年前。现在有许多数学 工具可以帮助研究人员理解某些模型中的泛化。不幸的是,这些现有理论中的大多数在应用于现代深度网络时都失败了——它们在现实环境中既空洞又不可预测。这种理论与实践之间的差距最大的是过度参数化模型,理论上有能力过度拟合其训练集,但在实践中通常不会。

 

在ICLR 2021接受的“ The Deep Bootstrap Framework: Good Online Learners are Good Offline Generalizers ”中,我们提出了一个新框架,通过将泛化与在线优化领域联系起来来解决这个问题。在典型的设置中,模型在有限的样本集上进行训练,这些样本可重复用于多个时期。但是在在线优化中,模型可以访问无限 样本流,并且可以在处理此流时迭代更新。在这项工作中,我们发现在无限数据上快速训练的模型与在有限数据上训练时可以很好地泛化的模型相同。这种联系为实践中的设计选择带来了新的视角,并为从理论角度理解泛化奠定了路线图。

 

Deep Bootstrap 框架 Deep Bootstrap 框架

 

的主要思想是将训练数据有限的现实世界与数据无限的“理想世界”进行比较。我们将这些定义为:

 

真实世界 (N, T):在来自分布的N 个训练样本上训练模型,对于T 小批量随机梯度下降 (SGD) 步骤,像往常一样在多个时期重复使用相同的N 个样本。这对应于在经验损失(训练数据损失)上运行 SGD,并且是监督学习中的标准训练程序。

 

理想世界 (T):为T步训练相同的模型,但在每个 SGD 步中使用分布中的新鲜样本。也就是说,我们运行完全相同的训练代码(相同的优化器、学习率、批量大小等),但在每个 epoch 中采样一个新的训练集而不是重复使用样本。在这个理想的世界设置中,具有有效无限的“训练集”,训练误差和测试误差之间没有区别。

 

 

先验地,人们可能期望现实世界和理想世界可能彼此无关,因为在现实世界中,模型从分布中看到有限数量的示例,而在理想世界中,模型看到的是整个分布。但在实践中,我们发现真实模型和理想模型实际上有相似的测试误差。

 

为了量化这一观察结果,我们通过创建一个新的数据集来模拟一个理想的世界环境,我们称之为CIFAR-5m。我们在CIFAR-10上训练了一个生成模型,然后我们用它生成了大约 600 万张图像。选择数据集的规模是为了确保从模型的角度来看它是“几乎无限”的,这样模型就不会重新采样相同的数据。也就是说,在理想情况下,模型会看到一组全新的样本。

 

 

下图展示了几种模型的测试误差,比较了它们在真实世界设置(即重用数据)和理想世界(“新鲜”数据)中对 CIFAR-5m 数据进行训练时的性能。蓝色实线显示了现实世界中的ResNet模型,使用标准 CIFAR-10 超参数对 50K 样本进行了 100 轮训练。蓝色虚线显示了理想世界中的相应模型,单次通过 500 万个样本进行训练。令人惊讶的是,这些世界有非常相似的测试错误——模型在某种意义上“不在乎”它看到的是重复使用的样本还是新鲜的样本。

 

 

这也适用于其他架构,例如多层感知器(红色)、视觉转换器(绿色),以及架构、优化器、数据分布和样本大小的许多其他设置。这些实验提出了泛化的新视角:快速优化(在无限数据上)、泛化良好(在有限数据上)的模型。例如,ResNet 模型在有限数据上的泛化能力优于 MLP 模型,但这是“因为”即使在无限数据上,它的优化速度也更快。

 

从优化行为理解泛化

 

关键观察是真实世界和理想世界模型在所有时间步长的测试误差中保持接近,直到真实世界收敛(< 1% 训练误差)。因此,人们可以通过研究模型在理想世界中的相应行为来研究现实世界中的模型。

 

这意味着模型的泛化可以从其在两个框架下的优化性能来理解:

 

1.在线优化:理想世界测试错误减少的速度有多快 2.离线优化:真实世界的训练误差收敛速度有多快

 

因此,为了研究泛化,我们可以等效地研究上述两个术语,这在概念上可以更简单,因为它们只涉及优化问题。基于这一观察,好的模型和训练程序是那些 (1) 在理想世界中快速优化和 (2) 在现实世界中优化不太快的模型和训练程序。

 

深度学习中的所有设计选择都可以通过它们对这两个术语的影响来查看。例如,像一些进展回旋,跳跃的连接,并预先-训练的帮助主要是通过加速理想世界的优化,而像其他进步正规化和数据增强的帮助主要是由减速真实世界的优化。

 

应用 Deep Bootstrap 框架

 

研究人员可以使用 Deep Bootstrap 框架来研究和指导深度学习中的设计选择。原则是:每当做出影响现实世界中泛化(架构、学习率等)的更改时,应考虑其对(1)测试错误的理想世界优化(越快越好)和(2) 现实世界中训练误差的优化(越慢越好)。

 

例如,在实践中经常使用预训练来帮助小数据机制中模型的泛化。然而,预训练有帮助的原因仍然知之甚少。可以使用 Deep Bootstrap 框架通过查看预训练对上述 (1) 和 (2) 项的影响来研究这一点。我们发现预训练的主要作用是改进理想世界优化(1)——预训练将网络变成了在线优化的“快速学习者”。因此,预训练模型的改进泛化几乎完全被它们在理想世界中的改进优化所捕获。下图显示了在CIFAR-10上训练的Vision-Transformers (ViT),比较在ImageNet上从头开始训练与预训练。

 

 

还可以使用此框架研究数据增强。理想世界中的数据增强对应于对每个新鲜样本进行一次增强,而不是对同一样本进行多次增强。这个框架意味着好的数据增强是那些(1)不会显着损害理想世界优化(即,增强样本看起来不太“不分布”)或(2)抑制现实世界优化速度(所以真实世界需要更长的时间来适应它的火车集)。

 

数据增强的主要好处是通过第二项,延长现实世界的优化时间。至于第一项,一些激进的数据增强(mixup / cutout)实际上会损害理想世界,但这种影响与第二项相比相形见绌。

 

结束语

 

Deep Bootstrap 框架为深度学习中的泛化和经验现象提供了一个新视角。我们很高兴看到它在未来应用于理解深度学习的其他方面。特别有趣的是,泛化可以通过纯粹的优化考虑来表征,这与理论上的许多流行方法形成对比。至关重要的是,我们同时考虑了在线和离线优化,这两个方面单独不足,但共同决定了泛化。

 

Deep Bootstrap 框架还可以阐明为什幺深度学习对许多设计选择相当稳健:多种架构、损失函数、优化器、归一化和激活 函数可以很好地泛化。该框架提出了一个统一原则:本质上,任何在在线优化设置中运行良好的选择也将在离线设置中很好地泛化。

 

最后,现代神经网络可以是参数化过度(例如,在小 数据 任务上训练的大型网络)或参数化不足(例如,OpenAI 的 GPT-3、谷歌的 T5或Facebook 的 ResNeXt WSL)。Deep Bootstrap 框架意味着在线优化是在这两种机制中取得成功的关键因素。

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注