Press "Enter" to skip to content

GroupSoftmax:利用COCO和CCTSDB训练83类检测器

通常,CV算法工程师在利用YOLO,Faster RCNN,CenterNet等一系列detection算法对公司的业务数据一顿猛train,看到自己的模型用在了业务上,陶醉在CEO对自己的工作很满意、即将升职加薪的幻觉之中时,往往迎来的是CEO的灵魂拷问。 考虑如下问题:

 

 

CEO总是逼我,说客户关心某一个新的类别 :对于一个n类检测数据集和一个m类检测数据集,想要得到一个n+m类检测模型。实际生产环境中,已经标注完成的数据集,因为业务需要增加k个新的检测类别。但不希望对已有数据重新标注,而是只标注接下来新的数据。

 

CEO总是逼我,说客户希望把某一类别分成两种情况对待 :上一个问题再进一步推广,因为业务需要,数据的标注标准发生改变,比如将过去某个类别拆分为多个新的类别。但不希望对已有数据重新标注,而是只标注接下来新的数据。

 

CEO总是逼我,说标注成本太高,公司要倒闭 :对于一个长尾分布的数据集,标注成本往往是巨大的,我们希望收集那些少数类别样本,但是在标注时遇到了多数类别样本,也要花很大的代价进行标注。是否有办法能够对于那些简单的样本类别,选择不进行标注,只标注有价值的少数类别目标。

 

 

为了讨好CEO,我们设计了GroupSoftmax交叉熵损失函数,能够有效解决上述3个问题。如下图所示,GroupSoftmax交叉熵损失函数允许类别 和类别 发生合并,形成一个新的组合类别  ,当训练样本  的真实标签为组合类别 时,也能够计算出类别 和类别 的对应梯度,完成网络权重更新。 理论上,GroupSoftmax交叉熵损失函数能够兼容任意数量、任意标注标准的多个数据集联合训练。

 

 

softmax交叉熵损失函数和groupsoftmax交叉熵损失函数的梯度对比图,在groupsoftmax中,类别k和类别j能够进行组合,形成一个新的类别g,以此计算相应的梯度

 

我们利用了80类检测数据集COCO和3类检测数据集CCTSDB联合训练,基于Faster RCNN算法(SyncBN),联合训练得到了一个83类检测器,在coco_minival2014测试集上,GroupSoftmax交叉熵损失函数和原始的Softmax交叉熵损失函数训练效果相比,mAP由原来的38.6上升到了39.3,也就是说我们利用了一个与COCO无关的CCTSDB数据集,将检测指标提高了0.7个点,还同时能够完成更多的类别检测任务,这算是比较理想的。此外,我们还训练了一个trident*模型,6个epoch在coco_minival2014测试集上的mAP为44.0,由此可见GroupSoftmax交叉熵损失函数是切实有效的。理论上而言,利用GroupSoftmax交叉熵损失函数,可以无限添加不同标注标准的数据集,进行联合训练。

GroupSoftmax和Softmax对比训练试验,在不降低甚至提高检测效果的同时,将检测类别数量由80增加到83

利用训练好的Faster RCNN模型,在一张给定的图片上,即检测到了COCO数据集中的类别(车辆),也检测到了CCTSDB数据集中的类别(交通标志)

我们基于SimpleDet检测框架,实现了mxnet版本的GroupSoftmax交叉熵损失函数,源码地址为: https://github.com/chengzhengxin/groupsoftmax-simpledet ,欢迎试用。下面详述GroupSoftmax交叉熵损失函数的工作原理。

 

GroupSoftmax交叉熵损失函数的推导

 

翻开任意一篇介绍softmax交叉熵损失函数的文章,都能看到,损失 对激活值  的梯度为:

 

 

一般地,我们采用交叉熵损失函数处理分类问题,使用式(1)中得到的梯度,已经能够满足识别分类等算法任务的训练。 但是在真实情况中,我们有时候无法确定类别 给出对应的   ,因为不同的数据集之间的分类标准不同,导致类别定义之间的差异性。 比如在数据集A中,类别   为自行车,在数据集B中,类别 为电动车,在数据集B中,类别   为非机动车。 也即数据集A中的 和 ,在数据集B中合并成为了一个新的类别 ,此时Softmax交叉熵损失函数受限,无法支持正常训练。 为此提出了GroupSoftMax交叉熵损失函数。

 

GroupSoftmax交叉熵损失函数的定义为如下,为群组 的组合概率的交叉熵:

 

 

式(2)中, 表示一个群组类别(多个类别的组合),其组合概率 可以表示为:

 

 

如上文提到的,在数据集B中,类别 为非机动车,该类别即为一个群组类别,其由数据集A中的两个类别组成,分别为数据集A中的   自行车和 电动车。考虑式(3)中的情况,当 时,也即目标类别 属于当前群组类别 时,有:

 

 

同理,考虑式(3)中的情况,当   时,也即目标类别 不属于当前群组类别   时,有:

 

 

由式(2)、式(4)、式(5),可以得到GroupSoftMax交叉熵损失函数对激活值   的梯度为 :

 

 

式(6)中, 表示训练时真实类别群组标签,从式(6)中可以看出,如果数据集B中的类别标签为非机动车时,此时电动车类别的梯度为:

 

 

可以看到,对比式(1)和式(6),得出的结论非常的make sense,对于一个群组类别中的子类别而言,其对应的梯度为群组类别的梯度乘以相应的权重,权重取值为当前子类别的预测概率  与群组类别的预测概率  的比值,其中群组类别的预测概率  等于多个子类别的预测概率之和。从式(1)和式(6)可以看出,当群组类别 中只包含 单独一个类别时,GroupSoftmax损失函数退化为Softmax损失函数,也即可以认为GroupSoftmax损失函数是Softmax损失函数的一种推广,一种更复杂也更加灵活的表达,可以自由的发生类别合并。

 

工程实现需要注意的细节

 

1、对于某一个数据集中的未进行标注的类别,可以理解为和背景一起作为新的群组类别。

 

2、在two-stage检测算法中,用于提取proposal的RPN网络通常是2分类网络,因为只用于区分前景和背景,但是对于某些类别未标注的数据集,是无法正确区分前景和背景的。此时需要将RPN网络修改为多分类。比如COCO+CCTSDB联合训练时,COCO中是一种前景,CCTSDB中是另外一种前景,所以此时的RPN应该修改为3分类,如下图所示:

COCO数据集对应rpnv1,第2种前景未进行标注,此时跟背景组成一个组合类别。CCTSDB数据集对应rpnv2,第1种前景未进行标注,此时跟背景组成一个组合类别。

3、COCO(80)+CCTSDB(3)联合训练时,最终的分类任务为1+83类,对于某个数据集中未标注的类别,比如COCO中未标注的3类,可以和背景类组成一个组合类别。如下图所示:

COCO标注了80类,后面3类未进行标注,所以后3类group信息为0,与背景类组成一个组合类别。CSTSDB标注了3类,前面80类未进行标注,所以前80类group信息为0,与背景类组成一个组合类别。

4、编写CUDA代码时,计算群组类别的概率  时,需要加上一个微小量  ,避免分母为0带来计算出错的情况。

 

GroupSoftmax的CUDA代码请参考:

 

https://github.com/chengzhengxin/groupsoftmax-simpledet/blob/master/operator_cxx/contrib/group_softmax_output.cu

作者:程萝卜

链接:https://zhuanlan.zhihu.com/p/73162940

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注