Press "Enter" to skip to content

用双流网络也能学的又快又好?哈工大&微软提出用于视觉语言理解的蒸馏双编码器模型,在多个多模态…

关注公众号,发现CV技术之美

 

 

 

本篇分享论文 『Distilled Dual-Encoder Model for Vision-Language Understanding』 ,用双流网络也能学的又快又好?哈工大&微软提出用于视觉语言理解的蒸馏双编码器模型,在多个多模态任务上实现又快又好的效果!

 

详细信息如下:

 

 

论文地址:https://arxiv.org/abs/2112.08723

 

代码地址:https://github.com/kugwzk/Distilled-DualEncoder

 

       01       

 

摘要

 

本文提出了一个跨模态注意力蒸馏框架来训练用于视觉语言理解任务的双编码器模型,例如视觉推理和视觉问答。双编码器模型比融合编码器模型具有更快的推理速度,并且能够在推理过程中对图像和文本进行预计算。然而,双编码器模型中使用的浅交互模块不足以处理复杂的视觉语言理解任务。

 

为了学习图像和文本的深度交互,作者提出了跨模态注意力蒸馏,它使用融合编码器模型的图像到文本和文本到图像的注意力分布来指导双编码器的训练模型。此外,作者表明 ,在预训练和微调阶段应用跨模态注意力蒸馏可以实现进一步的改进。实验结果表明,蒸馏后的双编码器模型在视觉推理、视觉entailment和视觉问答任务方面取得了有竞争力的性能,同时比融合编码器模型具有更快的推理速度。

 

       02       

 

Motivation

 

视觉语言(VL)预训练模型学习了大规模图像-文本对的跨模态表示,并且可以直接微调以适应到各种下游 VL 任务,例如视觉语言理解/分类(视觉推理、视觉问答等)和图像文本检索。基于跨模态交互的方法,这些模型可以分为两类。

 

第一类是 融合编码器模型 ,它采用有效但较少高效的Transformer编码器,用于捕获具有跨模态注意力的图像和文本交互。该类别的大多数模型依赖于现成的目标检测器来提取图像区域特征,这进一步阻碍了它们的效率。最近,ViLT放弃了检测器,并使用 Vision Transformer 直接对图像patch进行编码。

 

它在提高效率的同时,在 VL 理解和检索任务上取得了有竞争力的表现。然而,由于需要同时编码图像和文本,基于 Transformer 的跨模态交互仍然是效率瓶颈,限制了其在具有大量图像或文本候选的任务中的应用。

 

第二类作品,包括 CLIP和 ALIGN,采用 双编码器架构 分别编码图像和文本。跨模态交互通过浅层融合模块建模,通常是多层感知器 (MLP) 网络或点积,与融合编码器模型中的 Transformer 编码器相比,它非常轻。此外,分开的编码支持离线计算和缓存图像和文本候选,这可以很好地扩展到大量候选。

 

这些变化在理解和检索任务中降低了更快的推理速度,使模型在现实生活中变得实用。双编码器模型在图像文本检索任务上取得了可喜的性能。双编码器模型在图像文本检索任务上取得了可喜的性能。然而,它们在需要复杂的跨模态推理的视觉语言理解任务上远远落后于融合编码器模型,例如 NLVR2。

 

在这项工作中,作者提出了一个跨模态注意力蒸馏框架来训练双编码器视觉语言模型。蒸馏后的双编码器模型在视觉语言理解任务中实现了具有竞争力的性能,其推理速度比融合编码器模型快得多。

 

除了软标签蒸馏,作者还引入了跨模态注意力蒸馏作为双编码器模型(学生)的细粒度监督,以更好地学习跨模态推理。具体来说,使用来自融合编码器模型(教师)的图像到文本和文本到图像的注意力分布进行蒸馏。

 

本文的蒸馏框架可以应用于预训练和微调阶段。在预训练期间,将蒸馏目标应用于图文对比学习和图文匹配任务。在微调阶段,将微调后的教师模型的特定任务知识转移到学生模型中。

 

作者在视觉语言理解任务和图像文本检索任务上评估本文的模型。实验结果表明,蒸馏的双编码器模型在视觉entailment、视觉推理和视觉问答方面具有竞争力,同时推理速度比融合算法快 3 倍以上。编码器教师模型。

 

此外,本文提出的跨模态注意力蒸馏还提高了检索任务的性能,甚至在图像检索方面优于教师模型。与其他潜在特征相比,跨模态注意力有助于双编码器模型学习更好的跨模态推理能力,在 VL 理解任务中取得显着收益。此外,两级蒸馏的模型比单级蒸馏的模型具有更好的性能。

 

       03       

 

方法

 

 

上图展示了本文的用于训练双编码器模型的跨模态注意力蒸馏框架。作者采用融合编码器模型作为教师,并引入跨模态注意力知识和软标签来训练双编码器学生模型。蒸馏目标适用于预训练和微调阶段,并帮助双编码器模型学习不同模态的交互。

 

3.1 Model Overview

 

本文的蒸馏框架可以使用不同的融合编码器模型作为教师。在这项工作中,本文采用 ViLT作为教师模型进行实验,因为它简单高效。

 

Input Representations

 

给定一个图像-文本对 (v, t) 作为输入,将图像分割成patch,其中是patch的数量, (H, W) 是输入图像分辨率,(P, P) 是每个patch的分辨率,C 是通道数。

 

输入文本 t 被 WordPiece标记为 M 个子词token的序列,就像在 BERT 中一样。然后,将特殊token和分别添加到图像patch和文本子词token序列中。

 

线性投影图像patch以获得patch嵌入,最终的视觉输入嵌入通过以下方式计算:

 

其中是线性投影,是可学习的 1D 位置嵌入,是视觉类型嵌入,是文本输入嵌入。

 

文本输入嵌入是通过将词嵌入、文本位置嵌入和文本类型嵌入相加得到的:

 

将作为教师和学生模型的视觉和文本输入。

 

Teacher: Fusion-Encoder Model

 

输入表示和concat为,然后将向量馈送到 L 层跨模态 Transformer 编码器以获得上下文表示:

 

 

其中。跨模态 Transformer 编码器通过多头注意力机制融合不同模态的表示。具体来说,对于第 l 层的每个头a,,注意力分布通过以下方式计算:

 

 

其中查询和键是通过分别使用参数线性投影上一层的隐藏状态来获得的。是注意力头大小。最后一层的token的输出向量被馈送到特定于任务的层以获得预测。

 

Student: Dual-Encoder Model

 

双模型通过基于视觉和文本 Transformer 的编码器分别对视觉嵌入 () 和文本嵌入 () 进行编码:

 

 

最后一层的token的输出向量被用作图像和文本的最终表示。作者采用浅层模块 f 来融合这两种表示。对于 VQA 等视觉语言理解任务,模块 f 是一个 MLP 网络。对于图文检索,使用点积函数来获得图文对的相似度分数。

 

3.2 Distillation Objectives

 

Cross-Modal Attention Distillation

 

为了改进双编码器模型以捕获图像和文本的更深层次的交互,作者利用融合编码器模型的跨模态注意力知识来指导双编码器模型的训练。具体来说,作者使用图像到文本和文本到图像的注意力分布来训练双编码器模型。

 

融合编码器教师模型通过多头注意力机制捕获跨模态交。整个注意力分布可以分为两部分。作者使用 N 和 M 来表示图像和文本输入的长度。第一部分是单模态注意力(),它对相同模态的token内的交互进行建模。

 

第二部分是跨模态注意力,包括图像到文本的注意力分布()和文本到图像的注意力分布()。跨模态注意力分布捕获视觉和文本特征向量的交互。

 

由于双编码器的单独编码仅模拟相同模态token的交互,因此作者引入跨模态注意力蒸馏以鼓励双编码器模型模仿融合编码器模型的图像和文本对齐。双编码器模型的交叉模态(图像到文本和文本到图像)注意力分布计算如下:

 

 

其中是 selfattention 模块的视觉查询和键。是文本输入的查询和键。以相同的方式重新计算教师的跨模态注意力分布,而不是直接拆分原始的注意力分布。跨模态注意力蒸馏损失通过以下方式计算:

 

其中是 Kullback-Leibler 散度。本文只迁移了教师模型最后一层的跨模态注意力知识。

 

Soft Label Distillation

 

除了模仿跨模态注意力分布之外,作者还使用教师模型的预测作为软标签来改进学生。软标签损失计算如下:

 

 

其中分别是学生和老师的预测logits

 

3.3 Two-Stage Distillation Framework

 

本文使用提出的知识蒸馏目标在两阶段框架下训练双编码器学生模型,包括预训练蒸馏和微调蒸馏。在这两个阶段,融合编码器模型帮助双编码器模型学习跨模态交互。

 

 

如上表所示,作者根据任务的特点对模型进行不同目标的训练。

 

3.3.1 Pre-Training Distillation

 

在预训练期间,双编码器学生模型在大规模图像-文本对上进行训练,以学习具有图像-文本匹配、图像-文本对比和掩码语言建模任务的通用跨模态表示。预训练的融合编码器模型 ViLT用作教师模型。

 

Image-Text Matching (ITM)

 

图文匹配的目标是预测输入的图文是否匹配。在 ViLT之后,作者用 0.5 的概率替换匹配的图像来构建负对。作者在 ITM 输入对上使用跨模态注意力蒸馏损失和软标签损失来训练双编码器模型。

 

Image-Text Contrastive Learning (ITC)

 

作者通过batch内负采样引入对比损失,以优化视觉和文本表示的共享空间。给定一个batch的 N 个图像-文本对,可以获得 N 个匹配对和个负对。图像-文本对比学习旨在从所有可能的配对中预测匹配的配对。

 

融合编码器模型需要对每一对进行联合编码以获得软标签,这导致了二次时间复杂度。因此,作者只考虑在 N 个匹配对上计算的跨模态注意力分布。

 

Masked Language Modeling (MLM)

 

Masked Language Modeling的目标是从所有其他未mask的token中恢复mask token。作者使用 BERT 中 15% 的mask概率。为了提高训练速度,作者使用ground truth标签来训练 MLM 任务的模型。

 

3.3.2 Fine-Tuning Distillation

 

在微调过程中,作者使用微调后的 ViLT 作为教师模型,并对下游任务数据进行跨模态注意力蒸馏。

 

Vision-Language Understanding

 

对于视觉语言理解任务,例如视觉推理和 VQA,作者使用跨模态注意力蒸馏和软标签损失来微调学生模型。

 

Image-Text Retrieval

 

对于检索任务,作者在教师模型和ground truth标签的交叉模态注意力分布的监督下训练学生,以进行有效的训练。

 

       04       

 

实验

 

 

上表展示了本文方法中所用到的一些数据集。

 

 

上表展示了三个任务的微调结果。与以前的双编码器模型(如 CLIP)相比,本文的模型在三个视觉语言理解任务中取得了更好的性能,将平均得分从 57.83 提高到 73.85。从上表可以看出,在预训练和微调阶段执行蒸馏都对双编码器模型做出了积极贡献。与 ViLT 初始化的双编码器模型的直接微调相比,在微调期间使用跨模态注意力蒸馏带来了显着的改进。

 

 

除了视觉语言理解任务外,作者还在图像文本检索任务上评估了本文的方法。本文的双编码器学生模型经过跨模态注意力蒸馏和对比损失的训练。上表报告了在 Flickr30K 上微调的模型的结果。

 

本文的双编码器模型以更快的推理速度实现了具有竞争力的性能。该模型在图像检索方面甚至优于融合编码器教师模型 (ViLT)。此外,实验结果表明,跨模态注意力蒸馏也改进了检索任务的模型。

 

 

作者评估了本文的双编码器模型和 ViLT 在视觉语言理解任务上的推理延迟。这两个模型都在具有相同超参数的单个 P100 GPU 上进行评估。由于双编码器架构,作者的模型可以缓存图像表示以减少冗余计算。不同任务的平均推理时间和缓存时间如上表所示。

 

本文的双编码器模型在三个任务中实现了更快的推理速度。预计算图像表示进一步提高了推理速度,这对于现实生活中的大量图像和文本非常有效。

 

 

作者研究了蒸馏中使用的不同知识的影响。在微调期间对具有不同蒸馏损失的视觉语言理解任务进行了实验。双编码器学生模型由 ViLT 直接初始化。上表说明了跨任务的结果。

 

首先,可以发现使用软标签蒸馏比真实标签获得更好的性能。然而,使用软标签训练的模型在 NLVR2 任务上的准确率仍然相对较低。作者进一步结合了融合编码器模型的中间表示,以提高双编码器模型的性能。本文使用隐藏状态和不同的注意力分布进行比较。

 

在三个任务中,使用注意力分布比隐藏状态带来更多的改进。作者进一步探讨了注意力分布的哪一部分更为关键,包括跨模态注意力和单模态注意力。模仿教师的跨模态注意力分布比单模态部分取得了更多的改进,这验证了跨模态交互对于视觉语言理解任务更为重要。

 

作者还发现,仅使用跨模态注意力分布比使用整个注意力分布(跨模态 + 单模态)表现更好。

 

 

作者在教师和学生的最后一层执行所提出的知识蒸馏方法。为了验证仅在最后一层提取的有效性,将其与逐层策略进行比较。结果如上表所示。最后一层蒸馏策略在 NLVR2 和 SNLI-VE 任务上获得了更好的性能。此外,仅使用最后一层的注意力知识需要较少的计算。因此,仅使用最后一层是执行本文的跨模态注意力蒸馏的更实用的方法。

 

       05       

 

总结

 

在这项工作中,作者引入了一个跨模态注意力蒸馏框架来提高双编码器模型在视觉语言理解任务上的性能。采用融合编码器模型的跨模态注意力知识,包括图像到文本和文本到图像的注意力分布,来指导双编码器模型的训练。

 

实验结果表明,蒸馏后的双编码器模型在 NLVR2、SNLI-VE 和 VQA 上实现了具有竞争力的性能,同时具有比融合编码器模型快得多的推理速度。

 

参考资料

 

[1]https://arxiv.org/abs/2112.08723

[2]https://github.com/kugwzk/Distilled-DualEncoder

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注