Press "Enter" to skip to content

ECCV 2022|经典算法老当益壮,谷歌提出基于k-means聚类的视觉Transformer

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

原文链接:https://www.techbeat.net/article-info?id=3830

 

作者:seven_

 

 

论文链接:https://arxiv.org/abs/2207.04044

 

代码链接:https://github.com/google-research/deeplab2

 

近来视觉Transformer(ViT)的发展可谓是步入了一个全新的阶段,越来越多的任务和应用已经使用ViT替代CNN来作为视觉特征backbone,就像当年CNN代替SVM等传统方法进行图像分类一样,“长江后浪推前浪”的故事也在AI领域中上演。但是现有ViT结构基本源于自然语言处理领域(NLP)中的处理模式, 其主要由self-attention和cross-attention组成 ,目前这种结构其实忽略了 语言嵌入和图像像素之间的关键差异 ,具体体现在图像空间中 过长的像素序列 ,单一的像素块无法精确地涵盖具体视觉目标的所有信息,因而无法达到像文字embedding一样的交互注意学习效果。

 

基于以上考虑, 本文从视觉像素点的自身特点出发,提出将像素特征与具体query目标的交互注意力建模重新定义为一个聚类过程 ,并结合传统的k-means聚类算法,提出了一种用于图像分割任务的k-means Mask Xformer (kMaX-DeepLab),kMaX-DeepLab不仅性能达到SOTA,同时也展示了一种将传统方法融入到现有模型设计中的新思路。该论文来自霍普金斯大学和谷歌研究院,目前已被计算机视觉顶级会议ECCV2022接收。

 

一、 动机

 

ViT应用在视觉任务中的一大亮点是 以端到端的形式来解决复杂的视觉识别问题 。例如具有开创性的工作DETR[1]提出了第一个基于Transformer结构的端到端目标检测网络。DETR引入了一种新的对象建模范式,即首先通过CNN提取像素特征,随后部署几个Transformer编码器进行特征增强,再通过一系列的自注意力和交互注意模块来对特征进行聚合,最后经过一个前馈网络(FFN)对 query向量进行解码得到回归结果和分类结果 。

 

在DETR之后,MaX-DeepLab[2]的提出也进一步验证了ViT结构在更加复杂的全景分割任务上的有效性。但是本文作者认为直接在视觉任务中使用NLP中的Transformer结构是不合理的,因为其最初是根据语言任务的特性进行设计的,例如机器翻译[3], 输入序列和输出序列的长度是相同且固定的 ,这使得网络可以轻松的学习到 两种文字模态之间的一一对应关系 。但是迁移到视觉任务中,情况就会发生变化,对于检测或者分割任务,目前的具体做法是 使用较少的query向量(一般是128个查询向量)来拟合图像中感兴趣目标的区域 ,但输入图像可能包含有几千个像素。这使得在交互注意力提取过程中, 每个query向量都需要与几千个像素进行交互计算来定位目标 ,这带来了巨大的计算量,也会导致性能下降。

 

基于以上分析,本文作者重新审视了cross-attention与视觉任务特性之间的关系,作者发现 cross-attention的操作过程与传统的k-means聚类具有很强的相似性 ,作者将每个 可学习的目标query向量看做是k-means中的聚类中心 ,并根据k-means的中心迭代策略对交互注意机制进行了重新设计。从另一个角度来看,本文是MaX-DeepLab的升级版本,并且结合了传统k-means算法的聚类思想,虽然设计简单,但是其有效的提升了模型的分割性能。

 

二、本文方法

 

本文提出的kMaX-DeepLab首先考虑应用在全景分割任务中,因为全景分割可以很容易的推广到其他分割任务中。从视觉ViT的角度出发, 全景分割可以理解为对输入的N个目标query向量在图像上预测其对应的N个mask及其语义类别 ,形式化表示为:

 

{ y ^ i } i = 1 N = { ( m ^ i , p ^ i ( c ) ) } i = 1 N \left\{\hat{y}_{i}\right\}_{i=1}^{N}=\left\{\left(\hat{m}_{i}, \hat{p}_{i}(c)\right)\right\}_{i=1}^{N} { y ^ ​ i ​ } i = 1 N ​ = { ( m ^ i ​ , p ^ ​ i ​ ( c ) ) } i = 1 N ​

 

其中 p ^ i ( c ) \hat{p}_{i}(c) p ^ ​ i ​ ( c ) 表示对应mask的语义类别预测置信度,其中包括”things”类(事物类,可计数)、”stuff”类(背景类,不可计数)和”Void”类(空类)。

 

N个query向量最后经过一个Transformer解码器来聚合对应区域中的像素特征并得到最终的预测结果,解码器由一系列自注意力块和交互注意力块构成,本文方法主要对交互注意力块进行了改进,提出了k-means cross-attention机制,下面首先介绍一下cross-attention与k-means聚类之间的联系。

 

2.1 Cross-Attention与k-means聚类之间的关系

 

对于cross-attention来说,其主要目的是通过聚合经过编码器得到的像素特征来更新可学习的目标query向量,以便得到更有意义的预测结果,这一过程可以表示为:

 

C ^ = C + s o f t m a x H W ( Q c × ( K p ) T ) × V p \hat{\mathbf{C}}=\mathbf{C}+\underset{H W}{softmax}\left(\mathbf{Q}^{c} \times\left(\mathbf{K}^{p}\right)^{\mathrm{T}}\right) \times \mathbf{V}^{p} H W so f t ma x ​ ( Q c × ( K p ) T ) × V p

 

其中 C ∈ R N × D \mathbf{C} \in \mathbb{R}^{N \times D} R N × D 表示N个目标query, C ^ \hat{\mathbf{C}} C ^ 表示更新得到的目标query向量,使用 H W HW H W 表示softmax函数在空间维度上的作用范围。 这表示在对目标query进行更新时,softmax函数需要在整个输入图像上进行计算 ,这通常涉及到几千个像素,因此学习到的注意力图可能会经过多次训练迭代,尤其在训练的早期阶段,网络优化的负担会更重, 因为注意力图一般是随机初始化得到的,每个对象query很难在几千个像素中迅速找到最接近最合适的区域 ,这一点是与之前的机器翻译任务原理大相庭径的。

 

当作者开始考虑k-means聚类方法时,发现两者的整体流程基本上是类似的,都 遵循动态迭代的形式 ,只是cross-attention更新的目标是 对象的query向量 ,而k-means更新的是 聚类中心 ,其工作原理如下:

 

A = a r g m a x N ( C × P T ) C ^ = A × P \begin{array}{l} \mathbf{A}=\underset{N}{argmax}\left(\mathbf{C} \times \mathbf{P}^{\mathrm{T}}\right) \\ \hat{\mathbf{C}}=\mathbf{A} \times \mathbf{P} \end{array} A = N a r g ma x ​ ( C × P T ) C ^ = A × P ​

 

其中, C , P , A C,P,A C , P , A 分别表示聚类中心,像素特征和生成的聚类结果,通过比较cross-attention和k-means的表达式,我们可以发现 k-means聚类中心的更新方式是直接赋值,而不是cross-attention的残差更新方式 ,此外,k-means在生成注意力图(通过argmax函数计算更新特征的权重)时,其 计算范围是沿着聚类中心的维度 ( N ) (N) ( N ) 进行的,而不是cross-attention中的整个空间维度 ( H W ) (HW) ( H W ) 。

 

从聚类的角度来分析图像分割任务的话, 图像分割相当于将每个像素分配不同的聚类簇中,其中每一个簇对应一类预测的mask 。cross-attention的处理方式也是希望将像素分配给不同的对象query中,但是其计算规模太过庞大,也对网络的可分性能力提出了更高的要求。本文作者认为使用k-means中的argmax方式更为合理,进而提出了k-means Mask Transformer结构来进行分割,下文的实验也证明,k-means Transformer可以加速训练收敛,同时也具有更好的性能。

 

2.2 整体框架

 

kMaX-DeepLab主要由三个组件构成,如下图所示, 其中包含像素编码器、增强像素编码器和kMaX解码器 。像素编码器可以使用任意的CNN或者ViT backbone来提取视觉特征,增强像素编码器负责将得到的特征图进行上采样恢复到输入图像的高分辨率,同时根据transformer编码器计算自注意力特征,最后kMaX解码器从k-means聚类的角度将目标query向量(或者理解为聚类中心)转换为mask嵌入向量。

 

 

其中核心模块为本文提出的基于k-means的kMaX解码器,作者仅仅将原始transformer解码器中的cross-attention更换为本文设计的k-means cross-attention,详细构成如下图所示,红色框表示k-means cross-attention的操作细节,根据上述分析,作者将空间维度上的argmax替换成为k-means中的聚类中心维度argmax操作,就可以得到一个kMaX解码器,其表达式也变为:

 

C ^ = C + a r g m a x N ( Q c × ( K p ) T ) × V p . \hat{\mathbf{C}}=\mathbf{C}+\underset{N}{argmax}\left(\mathbf{Q}^{c} \times\left(\mathbf{K}^{p}\right)^{\mathrm{T}}\right) \times \mathbf{V}^{p} . N a r g ma x ​ ( Q c × ( K p ) T ) × V p .

 

 

在模型的具体实现中,本文作者给出了两个kMaX-DeepLab的具体实现版本,如下图所示,分别是以ResNet-50和MaX-S作为视觉backbone。 不过其实现的过程都可以分为两个路径,分别是像素路径和聚类路径 ,其中像素路径由一个像素编码器和增强的像素编码器构成,像素编码器可以直接使用预训练的网络,而增强的像素编码器由一系列的注意力块构成。**对于聚类路径,作者设置了六个kMaX解码器,其分别用来处理不同分辨率的特征图。**最后两个路径通过FFN进行连接输出最终的预测mask和预测类别。

 

 

三、实验效果

 

本文分别在COCO和Cityscapes数据集上进行了实验,这两个数据集均为大规模的全景分割数据库,其中COCO面向于人类日常生活场景,Cityscapes更侧重于城市道路驾驶场景。kMaX-DeepLab在COCO数据集上的实验效果如下表所示:

 

 

从表中可以看出,作者对不同backbone的模型结构进行了对比,其中包括ResNet-50、MaX-S、Swin Transformer以及ConvNeXt等具有代表性的backbone网络。可以发现,在以最轻量的ResNet-50上,kMaX-DeepLab可以达到53.0%的PQ值,这已经超过了目前其他SOTA方法,甚至超过了其他具有更大参数量backbone的方法。

 

此外作者还进行了可视化实验,下图展示了kMaX-DeepLab的六个kMaX解码器中的像素聚类效果,可以观察到, 在前三个解码器中,已经生成了合理的聚类结果,而在后三个解码器中主要完成的是细节学习。

 

 

在下图展示的另外一个示例中,我们可以看到kMaX-DeepLab也可以以 局部到整体的方式 来捕获图像中的感兴趣目标,图中的大象和人是通过六个解码器逐渐聚类得到的。

 

 

四、总结

 

本文对视觉Transformer的内部运行机制进行了探索,分析了现有结构在图像识别任务上的弊端,并提出从聚类的角度重新思考像素特征与目标query之间的关系,结合k-means聚类提出了一种端到端的全景分割模型,称为k-means Mask Transformer(kMaX-DeepLab)。kMaX-DeepLab使用k-means解码器来替换原有Transformer模型中的多头交互注意力块来简化模型,同时也提升了模型的分割效果。本文也从侧面印证了传统经典算法的思想在今天仍然适用,稍加改造和借鉴完全可以提升现代模型的综合性能。

 

[1] Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., Zagoruyko, S.: End-to-end object detection with transformers. In: ECCV (2020)

 

[2] Wang, H., Zhu, Y., Adam, H., Yuille, A., Chen, L.C.: Max-deeplab: End-to-end panoptic segmentation with mask transformers. In: CVPR (2021)

 

[3] Sutskever, I., Vinyals, O., Le, Q.V.: Sequence to sequence learning with neural networks. In: NeurIPS (2014)

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注