Press "Enter" to skip to content

如何跳出魔改模型?

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

©PaperWeekly 原创 · 作者|燕皖

 

单位|渊亭科技

 

研究方向|计算机视觉、CNN

 

刚刚和小伙伴参加完 kaggle 的 Global Wheat Detection 比赛获得了 Private Leaderboard 第七的名次,首先,在这次比赛中我们发现在 Public Leaderboard 所得到成绩和 Private Leaderboard 所得到的成绩有很大的差异,其次, 我们还发现了一些除魔改模型之外对涨点有效的方法。 这是我们成绩排名截图。下面就具体看看这两种方法。

 

 

 

Data argument

 

在训练神经网络时,我们常常会遇到的一个只有小几百数据,然而,神经网络模型都需要至少成千上万的图片数据。因此, 为了获得更多的数据,我们只要对现有的数据集进行微小的改变。

 

比如翻转(flips)、平移(translations)、旋转(rotations)等等。而我们要介绍的是 MixMatch,可以看做是半监督学习下的 mixup 扩增。

 

 

论文标题: MixMatch: A Holistic Approach to Semi-Supervised Learning

 

论文链接: https://arxiv.org/pdf/1905.02249.pdf

 

代码链接: https://github.com/google-research/mixmatch

 

对于许多半监督学习方法,往往都是增加了一个损失项,这个损失项是在未标记的数据上计算的,以促进模型更好地泛化到训练集之外的数据中。一般地,这个损失项可分为三类:

 

熵最小化 ——它鼓励模型对未标记的数据输出有信心的预测;

 

一致性正则化 ——当模型的输入受到扰动时,它鼓励模型产生相同的输出分布;

 

泛型正则化 ——这有助于模型很好地泛化,避免对训练数据的过度拟合。

 

MixMatch 整合了前面提到的一些 ideas 。对于给定一个已经标签的 batch X 和同样大小未标签的 batch U,先生成一批经过 Mixup 处理的增强标签数据 X’ 和一批伪标签的 U’,然后分别计算带标签数据和未标签数据的损失项。具体地流程如下:

 

将有标签数据 X 和无标签数据U混合在一起形成一个混合数据 W,然后有标签数据 X 和 W 中的前 X 个进行 mixup 后,得到的数据作为有标签数据 X’ ,同样,无标签数据和 W 中的后 U个进行 mixup 后,得到的数据作为无标签数据 U’。

 

损失函数:对于有标签的数据,使用交叉熵;“guess”标签的数据使用 MSE;然后将两者加权组合。如下:

 

 

MixMatch 就是将无监督和有监督的数据分开进行 mixup 增强,然后无监督的 loss 使用的是 MSE。在比赛中,我们发现如果有监督和无监督一起进行 mixup,性能会下降,而分开进行 mixup 增强,则会进一步提升。

 

 

Semi-Supervised Learning

 

尽管 SSL 取得显着进展,但 SSL 方法主要应用于图像分类,今天介绍 一种用于目标检测的 SSL,称为 STAC。

 

 

论文标题: A Simple Semi-Supervised Learning Framework for Object Detection

 

论文链接: https://arxiv.org/pdf/2005.04757.pdf

 

代码链接: https://github.com/google-research/ssl_detection/

 

这篇文章利用了 Self-training和 Augmentation driven Consistency regularization,所以称为 STAC。具体训练步骤如下:

 

在可用的标签图像上训练教师模型。

 

生成未标记图像的伪标签(即边界框和他们的类别标签)。

 

将  strong data augmentations  应用于未标记的图像,并进行伪标签的转换。

 

计算无监督损失和有监督损失以训练检测器。

 

 

现在就看 SSL 的另一个关键点——未标记数据的无监督的损失函数:

 

 

其中,ls 是有监督的损失函数,lu 是无监督的损失函数, A 是应用于未标记图像的强数据增强,p 和 s 是类别,t 和 q 是边框坐标。

 

将 data augmentations 应用于半监督学习的方法在很早就有文献提出,其背后的思想是 Consistency Regularization,即使对未标记的示例进行了增强,分类器也应该输出相同的类分布。

 

具体地,一致性正则化强制未标记的样本 x 应该与增强后的样本 Augment(x) 保持一致,其中 Augment 是一个随机数据增强函数,例如:随机空间平移或添加噪声。而本文实验发现 λu ∈ [1 , 2] 的时候效果最好。说明了半监督和有监督的重要性是不一样的。

 

 

 

Global Wheat Detection

 

这里还是先介绍一下小麦头检测的比赛的内容:

 

比赛链接:

 

https://www.kaggle.com/c/global-wheat-detection/overview/code-requirements

 

比赛背景: 主要是准确估计算出不同品种的小麦头的密度和大小,从而帮助农民评估自己的农作物

 

比赛要求: 检测并框出图片中的小麦头,评估方式是 MAP,MAP,主要是权衡 precision 和 recall 的一个指标。截止时间 8 月 4 号,提交要求不能联网并且 CPU Notebook <= 9 hours run-time,GPU Notebook <= 6 hours run-time

 

数据集: 训练集为 3434 张小麦图片,在 Public Leaderboard 上计算成绩的测试集占总的测试集的 62%,而在最终计算 Private Leaderboard 成绩的测试集为占中的测试集的 38%。

 

3.1 赛题的难度

 

小麦外观会因成熟度,颜色,基因型和头部方向而异,因此对模型的泛化能力要求比较高。

 

一张图片中小麦头数量很多密度很大,因此常常出现小麦头重叠的情况。

 

训练数据少,并且部分图片模糊和大小不一致。

 

3.2 解决思路

 

在训练阶段通过对图像进行增强来数据扩增训练出多个模型,然后在测试集上进行半监督学习。最后,在检测时利用 TTA(Test time augmentation)增加检测的准确性,并利用 wbf 融合多个模型的结果。

 

3.2.1 训练数据扩增

 

由于在训练的数据量较小(容易过拟合),而并且测试集的分布比较分散,对模型泛化能力要求比较高,因此采取对图像增强的方式对训练集进行扩增,采取的图像扩增的方式有图像的缩放、随机水平翻转和垂直翻转、多个图像的拼接、色彩空间 hsv 增强,通过这个方式训练集扩增了 5 倍以此缓解训练数据量小的问题。

 

3.2.2 半监督训练

 

伪标签对成绩的提升有很大的帮助,最初在 Public Leaderboard 上没加入伪标签技术成绩:0.7522 , 加入伪标签技术后成绩为:0.7720  ,增加了 0.0198,排名提升了一百多名效果可以说是相当的明显了。

 

具体地,我们对图像的增强策略包括 Vertical Flip,HorizontalFlip,Rotate90,180,270,Multi-Scale 0.83 and 1.2 ,cutout,mixup,然后利用在训练集训练好的模型对未标记的测试集图片进行伪标签制作。

 

最开始,我们也仅仅是增加这些 argument,能够达到  0.7720 ,进一步使用 MixMatch 和 STAC 的方法后,分别能够达到 0.7734 和 0.7751。

 

 

Acc

Baseline 0.7593
Rotate90,180,270 0.7640
Vertical/Horizontal Flip 0.7682
Multi-Scale 0.83 and 1.2 0.7720
MixMatch 0.7734
STAC(λu=1.4) 0.7751

 

3.3.3 模型检测过程

 

在检测的过程中使用了 TTA(Test time augmentation),对原始图像进行旋转(90°,180°,270°)、垂直水平翻折、图像缩放(放大 1.2 倍,缩小 0.87 倍),然后对 TTA 后的图像进行检测,最终将所得到的 box 进行 nms。

 

采用 TTA(测试时增强),可以对一幅小麦图像做多种变换,创造出多个不同版本,对多个版本数据进行计算最后得到平均输出作为最终结果,提高了结果的稳定性和精准度。

 

3.3 Private Leaderboard

 

在这次比赛中最终提交的两个方案中,方案一也就是上面使用的方案取得了 Private Leaderboard 第七的成绩,方案二:增加了根据验证集计算成绩自动选择最好的阈值,对于伪标签的训练 epoch 增加到 15,而减少了半监督训练中的 Argument(只剩下了旋转)。

 

方案一在 Public Leaderboard 表现一般的方案成绩为 0.7721 排在 55 名,但是却在 Private Leaderboard 排在了第七名。方案二在 Public Leaderboard 上成绩还不错的方案 0.7751 在排名在 23,但是在 Private Leaderboard上37% 的测试集我的成绩却为 0.6954 排在了 300 多名。

 

 

写在最后

 

由于本次比赛的数据集较小,很容易导致过拟合的现象。比赛结束的时候发现 Public leaderboard 成绩还不错,但是当 Private Leaderboard 出来后排名一落千丈,相比较而言,数据量大了的比赛绝大部分人排名都没有变化,少数有 1~2 名的浮动在。

 

在这次比赛里的方案二由于 Public Leaderboard 上测试集占  62%,测试集样本较多,因此增加伪标签的训练使得它在 Public Leaderboard 上的成绩增加很多,但是方案二发生了过拟合使得在 Private Leaderboard 上的成绩下降就很明显。

 

因此,深度学习网络训练到什幺时候停止?在关注训练集数量、质量以及分布等等因素的同时,更应该测试集(实际场景)的情况。否则常常会出现悲惨结局。另外,除了魔改模型,数据增强和半监督都是跳出魔改模型的好方法,能够使得模型获得更多的泛化能力。

 

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注