Press "Enter" to skip to content

精度更高,速度更快!锚点 DETR:基于 transformer 目标检测的查询设计(AAAI 2022)

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

本文转自 旷视研究院。

 

 

● 简介   ●

 

近年来,以 DETR[1]为代表的基于 transformer 的端到端目标检测算法开始广受大家的关注。这类方法通过一组目标查询来推理物体与图像上下文的关系从而得到最终预测结果,且不需要 NMS 后处理,成为了一种目标检测的新范式。

 

但是,这类方法尚有一些不足之处。

 

首先,DETR 解码器的目标查询是一组可学习的向量。这组向量人类难以解释,没有显式的物理意义。同时,目标查询对应的预测结果的分布也没有明显的规律,这也导致模型较难优化。

 

为了解决上述问题,本文提出了一种基于锚点的查询设计,因此目标查询有了显式的物理意义,且每个查询仅关注对应锚点附近的区域,使得模型更容易优化。

 

此外,本文还提出了一种 attention 结构的变种,可以显着降低显存消耗,且对于检测任务中较难的 cross attention 依旧能保持精度不降。

 

如表 1 所示,最终本文算法比 DETR 精度更高,消耗显存更少,速度更快,且收敛更快(所需训练轮次更少)。

 

 

表1

 

● Attention 回顾 ●

 

首先,我们回顾一下 DETR 中 attention 的形式:

 

这里 Q、K 和 V 分别为查询、键和值,下标 f 和 P 分别表示特征和位置编码向量,标量
为特征的维度。实际上,Q、K 和 V 还会分别经过一个全连接层,这里为了简洁省略了这部分。

 

DETR 的解码器包含两种 attention,一种是 self-attention,另一种是 cross-attention。

 

在 self-attention 中,


一样,

一样。其中
由上一个解码器层的输出得到,第一个解码器层的
被初始化为一个常数向量,如零向量;而
设置为一组可学的向量,为解码器中所有的
共享:

 

在 cross-attention 中,
由之前的 self-attention 的输出得到;而

是编码器的输出特征;
是编码器输出特征对应的位置编码向量,DETR 采用了正余弦函数来作为位置编码函数,我们将该位置编码函数记作
,若编码器特征对应的位置记作
,那幺:
在此解释一下,H, W, C 分别是特征的高、宽和通道数目,而
是预设的目标查询数目。

 

● 查询设计 ●

 

通常我们把解码器中的
认作是目标查询,这是因为它负责分辨不同的物体(解码器中的初始
为零向量没有分辨能力)。

 

如前文所述,DETR 中的目标查询
是一组可学向量,其难以解释且没有显式的物理意义。观察这些目标查询对应的预测结果的分布,如图 1 所示,每个方格中的点表示一个目标查询对应的所有图像预测结果的中心点,可以看到,每个查询都负责非常大的范围,且导致负责的区域有很大的重叠,这种模糊性也使得网络很难优化。

 

 

图 1

 

为了解决这个问题,本文提出基于锚点的查询设计,每个目标查询为锚点坐标的编码,因此具有了显式的物体意义。并且,每个查询仅关注锚点附近的区域,可使得网络模型更易优化。

 

在基于 CNN 的 检测算法中,锚点通常都是特征网格点的坐标。而在本文中,锚点可以更加灵活。可以使用预设的网格位置的锚点,也可以是一组可以随网络学习的位置点。如图 2 所示,我们发现最终学习到的锚点分布与网格点较为相似,都是趋于均匀分布在整个图像上。这可能是因为在整个图像集中,图像的各个位置都会出现物体。

 

 

图 2

 

记锚点为
, 其表示有
个锚点,每个锚点记录点的(x,y)坐标。那幺,基于锚点的目标查询则是:
即目标查询为锚点坐标的编码。 那幺如何选择位置编码函数呢 ? 最自然地,本文选择与键特征共享一样的位置编码函数:
其中,g 为位置编码函数,它可以是前述的
,也可以是其它的形式。在本文中我们对启发式的
额外加入了两个全连接层以更好地调整它。

 

更进一步考虑,有时一个位置可能会出现多个物体。显然,若一个锚点仅能预测一个物体的话,那幺该位置的其它物体则需要其它位置的锚点来一同预测。这导致每个锚点负责的区域扩大,增加了其位置模糊性。为了解决这个问题,本文对每个锚点加入多种模式,使其可以有多个预测。

 

回顾 DETR,其中初始的查询特征为
, 对于
个目标查询来说,每个都只有一种模式
,其中   表示目标查询的索引。

 

因此,本文为每个目标查询设置多种模式
,其中
为模式的数目,是一个较小的值,如
=3。具体而言,本文使用一组可学向量
作为目标查询的多种模式。考虑移动不变性,我们希望这些模式与位置无关,因此让各个锚点共享多种模式。如此,我们便可得到增广的初始查询特征
和查询位置编码

 

观察改进后的目标查询对应的预测结果的分布,如图 3 所示,其中最后一行为锚点,前三行是对应锚点的三种模式的预测,可以看到,基于锚点的查询将关注锚点附近的区域,查询对应的预测框中心点都分布在锚点周围。此时查询不需要预测离对应锚点很远的物体,因此其具体特定的语义,从而模型将更容易优化。

 

 

图 3

 

图 4 展示了各个查询模式对应预测的分布,可以发现模式与物体大小存在一定关系,例如大物体几乎都出现在模式(a)中,模式(b)则关注小物体,模式(c)介于两者之间。另外我们还可以发现,所有的模式都会预测小物体,这是因为小物体更容易出现一个位置多个物体的情况。

 

 

图  4

 

● Attention 变种 ●

 

目前许多的 attention 变种,如 Deformable DETR[2]、Efficient Attention[3]等,都可以大幅度降低 transformer 占用的显存。然而,也许是由于 DETR 类方法中 transformer 解码器的 cross attention 较难,若使用同样的特征,这些方法将会导致一定程度上的精度降低。

 

本文提出了一种行列特征解耦的 attention 变种(Row-Column Decoupled Attention, RCDA),将键特征解耦为列特征和行特征,再依次进行列 attention 和行 attention。该方法不仅可以降低显存消耗,还可以得到和原先的标准 attention 相似或者更高的精度。

 

首先,对于键特征
,先将其解耦为行特征
和列特征
,本文采用的解耦方式为分别沿着列和行做均值。

 

接下来,则可以分别计算查询对于行、列键特征的注意力图:

 

其中,

 

最后则依据行列注意力图,对值特征依次沿着行、列进行加权和。不失一般的,我们假设 W≤H,可如下式先沿着列加权,再沿着行加权(若 W>H,则可先沿着行加权,使中间结果的显存占用小一些):

 

 

其中
行列解耦的 attention 变种的原理上文便介绍完了,现在我们再来讨论一下它为什幺可以节省显存。

 

在之前的表述中,我们不失一般的假设 Attention 头的数目为 1 以更加简洁,现在我们设其为 M。在标准的 Attention 中,注意力图
为主要的显存占用瓶颈,而在行列解耦的attention中,行列注意力图

的显存远小于标准 attention 中的注意力图。

 

由于特征的通道数目 C 通常大于 M,RCDA 的中间结果 Z 的显存占用要大于行列注意力图,因此我们主要比较 RCDA 的中间结果
与 标准Attention 中注意力图
之间的关系。显然,随着图像特征分辨率的增大(H 与 W 增大),标准  attention  的显存占用增长得更快。

 

行列解耦  attention  较标准  attention  可以节省显存的倍数为:

 

在默认的设置中,M=8,C=256,因此当特征长边 H 大于 32 时,RCDA 可以节省显存。在目标检测任务中,特征边长 32 是 C5 特征的一个典型值,因此使用 C5 特征显存占用相差不大,使用更大的 C4 特征显存可省 2 倍,依次类推。

 

● 总体流程 ●

 

算法的总体流程如图 5 所示,首先通过 CNN 网络提取图像特征,然后再经过transformer 编码器通过 self attention 处理图像特征,输出的图像特征将作为解码器的键和值特征。解码器的查询为前文所述基于锚点的多模式查询,在解码器中,各个查询分别根据注意力图聚合感兴趣的图像特征,最后输出最终的预测结果。预测框的中心点预测相对锚点的偏移量,而框的大小则预测其相对图像的大小。编码器和解码器中的 attention 可以采用标准的 attention,也可以采用本文所述的行列解耦 attention。对于 attention 中各特征的位置编码,则依据其位置使用共享的位置编码函数得到。

 

 

图 5

 

● 实验分析 ●

 

如表 2 所示,我们比较了本文算法与其它一些算法的性能比较,默认的骨干网络为 ResNet50。可以看到本文算法可以到达较好的性能,且继承了 DETR 无需手工设计锚框、无需 NMS 后处理,且不涉及随机内存访问的优秀性质。

 

 

表 2

 

不涉及随机内存访问(RAM-free)的性质可以减小硬件的访存代价,在实用中对硬件更加友好。

 

举个例子,假设有个人(专用计算芯片)力气很大(算力很强),他可以轻松地把一叠共 1000 张纸搬到指定的地方(计算处理某个张量)。而假如让他取出其中的第 123 张和第 234 张纸搬到指定的地方,需要搬的纸虽然少了很多(计算量大幅度降低),但是由于需要找到这些指定的纸(随机内存访问),可能会更加费时(访存代价增加)。

 

通常来说,两阶段的检测算法由于感兴趣的区域(RoI)的坐标对硬件来说随机的,提取感兴趣区域的特征会涉及到随机内存访问。而 Deformable DETR 也涉及到提取特定坐标的特征的情况,因此也非 RAM-free。

 

如表 3 所示,我们还分析了上述所提各个模块的效果。首先,我们可以看到所提的查询设计,即将锚点(anchors)编码为解码器查询以及为锚点加入多种模式(patterns),可以将性能从 39.3 提升至 44.2,这个显着的提升表明了本文的查询设计较原 DETR 查询设计的优越性。我们还可以看到,将标准的Attention 替换为 RCDA 性能近乎一致,这表明 RCDA 可以无损地降低显存占用。还有一点比较有趣的现象,若我们为原 DETR 的查询加入多种模式,其性能没有明显的变化,我们认为这是因为 DETR 的查询与位置没有明显关系,不能获得解决“一个位置多个物体”问题的收益。

 

 

表 3

 

如 表 4 所示,我们比较了使用不同数目的锚点(anchor points)和模式(patterns)。 100 个锚点数目过少性能较低,而 900 个锚点性能与 300 个锚点相差仅 0.3,因 此我们默认使用 300 个锚点。 可以看到,为每个锚点设置多种模式,性能会有明显的提升。 另外,当预测结果的数目一致时,即保持锚点数目乘以模式数目的值不变时,多种模式的性能也比一种模式效果更好,这说明了多种模式的提升并非是因为预测的数目增加,而是本质更好。

 

 

表 4

 

如表 5 所示,我们比较了 attention 变种的效果。可以看到,Efficient Attention 虽然可以大幅度降低显存,但是由于 cross attention 较难,效果有明显下降。而本文的 RCDA 将显存占用从 10.5G 降低至 4.4G,而精度却没有明显变化。

 

 

表 5

 

● 总结 ●

 

本文提出了一个基于 transformer 的检测算法,其实现简单,且比 DETR 精度更高,消耗显存更少,速度更快,且收敛更快。

 

● 参考文献 ●

 

[1] Carion N, Massa F, Synnaeve G, et al. End-to-end object detection with transformers[C]//European conference on computer vision. Springer, Cham, 2020: 213-229.

 

[2] Zhu X, Su W, Lu L, et al. Deformable detr: Deformable transformers for end-to-end object detection[J]. arXiv preprint arXiv:2010.04159, 2020.

 

[3] Shen Z, Zhang M, Zhao H, et al. Efficient attention: Attention with linear complexities[C]//Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2021: 3531-3539.

 

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。