Press "Enter" to skip to content

91.3%!首个将Transformer解码器应用于多标签图像分类的方法Query2Label

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

作者丨镜子

 

 

Query2Label: A Simple Transformer Way to Multi-Label Classification

论文地址: https://arxiv.org/pdf/2107.10834v1.pdf

 

开源地址: https://github.com/SlongLiu/query2labels

 

多标签图像分类(Multi-label image classification)一直以来都是很重要的研究领域,可以从一张图片中获得比单标签分类更丰富的信息,在如图像检索、人像分组、医学图像识别、场景理解等多个领域都有广泛应用。在我们熟悉的ImageNet数据集分类任务上,2021年timm团队提出的训练策略可以将ResNet50模型的分类准确率提升到Top1-80%以上,其中也用到了多标签分类相关的技术(BCELoss+Mixup+CutMix)。

 

1. 简介

 

多标签分类

 

为了照顾到一部分没有相关知识的小伙伴,我先对多标签图像分类做一个简单的回顾。

 

首先,多标签分类(Multi-label classification)是相对于单标签分类(Single-label classification)而言的。一张图片,对应一个标签,这就是单标签分类,我们熟悉的ImageNet就是一个单标签标注的数据集,我们常说的ImageNet上的分类准确率也指的是模型为一张图片预测一个标签的准确率。而多标签分类时,一张图片,对应多个标签,且标签数量不确定,一般我们将存在的标签记为1,不存在的标记为0,当标签一共有N类时,我们会输出一条N维的由0/1组成的向量,每一项对应了一个标签的存在与否。

 

对于多标签分类任务,我们最常用的方式是使用Sigmoid+BCELoss。

 

在单标签分类时,由于模型只需要预测一个结果,我们希望让模型在正确标签上的置信度越高越好,所以采用了SoftMax+CrossEntropy的组合,SoftMax的性质保证了输出结果在0~1之间,且最大的那一项尽量趋近于1,其余项尽量趋近于0,最后通过Argmax得到置信度最高那一项所在的位置;而多标签分类时,我们把全连接层输出的特征通过Sigmoid让每一项都在0~1之间,再通过二分类交叉熵进行优化。

 

问题

 

相较于单标签分类,多标签分类的问题主要在于两点:

 

 

如何处理标签数量差异带来的类别不均衡问题

 

如何区分不同兴趣区域的特征

 

 

问题1在于,我们的batch size是有限的,每张图片对应的标签数也是不确定的,因而当标签类别很多的时候(即图片带有的标签数远小于总标签类别数),会导致这个batch中大部分标签是0,只有很少一部分是1,也就是正负样本不均衡问题。

 

问题2在于,标签对象分布在图片的不同位置上,大小也不一定,我们很难针对性地提取特征,如果我们按照单标签图片特征的提取方式,直接对整张图片提一个特征,会导致有些目标的特征被稀释,比如在图片中较小、较不显着的、画质较差的目标。

 

随着多标签分类研究的发展,目前所提出的方法主要可以归纳为三个方向:

 

 

改进损失函数

 

对标签相关性进行建模

 

定位兴趣区域

 

 

改进损失函数

 

通过改进损失函数来缓解正负样本不均衡问题是常见的做法,在目标检测领域,Kaiming提出的Focal Loss就是在BCELoss的基础上进行修改,通过减少高置信度样本的权重,使得模型在训练时更专注于难样本的学习,在正负样本不均衡的数据中,这种方法可以让模型减少对负样本的过拟合,专注于学习数量较少的正样本。

 

但由于Focal Loss是通过使用同一个参数gamma来调节学习权重,其形式会导致模型在降低简单负样本权重的时候,也会同样减少简单正样本的贡献,换句话说,Focal Loss的本质还是对难易样本的区别对待,对于正负样本不均衡问题并不是百分百适配。因此在ASL(Asymmetric Loss)工作中,对正负样本的权重调节参数gamma进行了解耦,在减弱负样本权重的同时,能保留正样本的贡献能力。

 

建模标签相关性

 

由于一张图片对应多个标签这样的性质,有研究者提出,标签与标签之间是存在相关性的,有些标签大概率会一起出现(比如乒乓球拍和乒乓球,雨伞和人等等),这种先验知识可以被利用起来提升预测的准确率。

 

在过去有研究者使用图卷积网络(GCN)来专门建模这种标签相关性,而在Transformer出现后,其自注意力机制天生就具有相关性建模能力。

 

然而需要注意的是,这种方法的有效性是存在争议的,尤其是当数据规模不够大的时候,这种统计得到的共现关系就可能是虚假的。

 

定位兴趣区域

 

在早期的工作中,大家很自然地想到,可以通过裁剪等方式来将多标签问题简化为多个单标签问题,定位和裁剪方式也五花八门,如BBox、响应区域等,但这些方法的定位准确度不够高,不可避免地会引入背景信息。

 

2. 方法

 

随着Transformer模型在视觉领域的成功,由于其所具有的各种优秀的性质,本文作者将其应用到了多标签图像分类任务中,提出了Query2Label方法,使用Transformer解码器来查询每个标签的存在性,由于其框架简单且性能强劲,在多个公开数据集上取得了SOTA,在比较具有代表性的MS-COCO数据集上,2020年的SOTA方法mAP为88.4%,而本文取得了91.3% 。

 

本文的贡献在于:

 

 

本文是第一个使用Transformer解码器结构在分类任务中的工作。

 

实验显示了Transformer解码器中的交叉注意力模块可以自适应地提取目标特征,配合多头注意力机制能进一步学习目标的不同视角、不同部位,从而来带了更好的性能。

 

在多个公开数据集上实验证明了本文方法的有效性和优越性。

 

 

Query2Label

 

首先框架上,Query2Label是一个两阶段框架,第一阶段将图片通过一个骨干网络提取特征图,第二阶段将图片特征和标签特征一起送入Transformer解码器中,图片特征作为key和value,标签特征作为query,将Transformer输出的query特征经过自适应特征池化和线性投影后预测标签存在性。

 

对于第一阶段,骨干网络作为一个特征提取器是可以自由替换的,可以使用CNN-based网络,也可以使用ViT等Transformer-based。

 

对于第二阶段,自适应池化和线性投影都是很常见的操作,一般是通过GlobalAvgPool+FC来实现。

 

Query updating

 

 

不同于大部分ViT-like模型使用Transformer编码器模块,本文使用的是解码器结构,每一层解码器模块中包含了一个自注意力模块,一个交叉注意力模块,和一个带位置编码的前馈网络。

 

通过初始化可学习的参数来学习每个标签的特征向量,在计算自注意力模块时,query,key,value三个值都是标签特征,而在交叉注意力模块中key和value时图片特征,而query是标签特征。

 

解码器结构对于多标签分类任务有很多好处,首先是自注意力模块全部计算标签特征,能学到标签之间的相关性,而交叉注意力模块使每一个标签的特征能自适应地与图片特征匹配。而独立为每个标签建立可学习参数这一做法,使得每个标签特征语义十分明确。

 

最终标签特征经过Transformer层输出的特征向量,直接通过线性投影即可得到对应的logits,进行Sigmoid+LossFunction监督。

 

Loss Function

 

损失函数部分,本文采用了一个简化版的ASL,如前文改进损失函数中所述,在Focal Loss上对正负样本采用不同的调节权重,最终取得了比BCELoss和Focal Loss更好的效果。

 

实验

 

ASL作为上一个SOTA,采用的Backbone是TResNetL,为了跟ASL的结果进行公平对比,本文也基于TResNetL进行了实验。由于Backbone可以随意替换,本文又实验了其他更强的Backbone,在MS-COCO数据集上结果如下:

 

可以看到经过了ImageNet22k预训练的模型可以取得更高的性能,但同等条件下横向对比,Query2Label方法超越了ASL等其他方法。

 

如果对更多的实验结果感兴趣可以自行查阅原文,在这里就不一一贴出了。

 

不同尺寸目标

 

由于不同目标在图中的尺寸不同,本文进行了更详细的实验对比,将小于32×32的目标视为小目标,小于96×96之间的为中目标,大于的为大目标,与Baseline相比在所有尺度上均有优势。

 

但是我注意到这里对比的是基于TResNetL的Baseline而非ASL,可能是由于相较于ASL的优势不那幺明显。

 

可视化

 

通过对交叉注意力图进行可视化,我们可以看到不同标签特征可以很好地捕捉到对应目标。

 

而跟Baseline的对比可以发现,Query2Label方法的注意力区域更加集中、更加准确,引入了更少的无关背景。

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。