Press "Enter" to skip to content

在行为预测模型中缩短坐标系带来的差距:蒸馏得到高效精确的以场景为中心运动轨迹

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

arXiv上传于2022年6月8日的论文“Narrowing the coordinate-frame gap in behavior prediction models: Distillation for efficient and accurate scene-centric motion forecasting“,关于普林斯顿大学和谷歌Waymo的工作(intern?)。

 

近年来,行为预测模型激增,尤其是在自动驾驶应用中,其中表示运动智体的未来分布对于安全舒适的运动规划至关重要。在这些模型中,选择坐标系来表示输入和输出,大致可分为两类。以智体为中心(Agent-centric)的模型在以智体为中心的坐标系中转换输入并执行推理。这些模型在场景元素之间本质上是平移和旋转不变,在公共排行榜上表现最好,但其规模是智体和场景元素数量的二次方。以场景为中心(Scene-centric)的模型使用固定的坐标系来处理所有智体,具有在所有智体之间共享表征的优势,其计算与智体数量成线性比例。然而,这些模型必须学习场景元素之间的平移和旋转不变性,并且通常表现不如以智体为中心的模型。

 

这项工作开发了概率运动预测模型之间的 知识蒸馏(knowledge distillation) 技术,并应用这些技术来缩小以智体为中心模型和以场景为中心模型之间的性能差距。在公共Argoverse基准上以场景为中心的模型性能提高了13.2%,在Waymo开放数据集上提高了7.8%,在大型Waymo内部数据集上提高了9.4%。这些改进的以场景为中心模型在公共排行榜上排名很高,在繁忙场景中,其效率是以智体为中心老师模型的15倍。

 

预测真实驾驶场景中多个车辆、自行车和行人的未来行为对自动驾驶车辆的安全舒适运动规划来说是一项困难的但必不可少的任务。此任务通常称为“运动预测”或“行为预测”。这是一个挑战,原因有很多。(1) 世界状态是多类构成的,包括静态和动态道路网络元素以及动态智体状态的观测。(2) 预测结果在很大程度上取决于多智体交互。(3) 由于潜在的智体意图,对多个未来的输出分布具有内在的不确定性和高度的多模态性。 如何表示输入世界状态、交互和输出分布 都是开放的问题和活跃的研究领域。

 

在过去几年中,解决这些建模挑战的行为预测系统激增。最有趣的设计选择之一,是表示输入和输出数据的坐标系。有两种不同的共同选择,其各有优缺点。

 

以智体为中心 的模型,在以智体为中心的坐标中表示输入和内部状态,并在此框架中执行推理。道路元素(如车道、人行横道)的坐标和其他智体的状态以相对于智体的姿态来描述,因此该表示对智体的全局位置和方向具有固有的不变性。这可以被视为一种特征预处理形式,允许模型专门针对于一个智体的视角,实践中在公共基准上取得最先进的性能。

 

然而,当对场景中的多个智体进行建模时,一个关键的缺点变得很明显:每个智体都是独立建模的,因此计算通常对智体数量是线性的,而对交互进行建模是二次的——对n个智体和m个道路元素的场景,计算规模为O(n(n+m))。对于公共基准而言,这不是一个问题,因为公共基准要求一次对不到十个智体进行建模,但对于由数百个智体组成的繁忙现实城市环境而言,这是一个计算瓶颈。

 

另一方面, 以场景为中心 的模型在一个共享的固定坐标系中为所有智体主要世界的状态进行编码。在此框架中所运行的模型通常是“自顶向下”或“BEV”表示,将世界离散为空间网格单元,并应用CNN主干对场景进行编码——尽管也存在以非光栅的场景为中心方法。经过这样的处理后,这些模型的预测头做全局到局部的转换,对智体中的轨迹进行解码。该方法的一个显着优点是,计算主要是空间网格分辨率和视野、而不是智体数的函数——大小为H×W单元的空间网格其规模为O(HW+n),其中第一项用CNN处理,并在实际设置中影响第二项量化。

 

这种方法的缺点是:(1)将世界状态离散为光栅格式时会丢失信息,(2)难以建模与CNN的远程交互,以及(3)模型必须学习旋转/平移不变性,或者在解码时学习对每个代理执行全局到局部的转换。

 

简言之, 以智体为中心 的模型优于公共排行榜上的以场景为中心的模型,上述缺点可能是可以解释的。然而, 以场景为中心 的模型是引人注目的,因为相对于场景中智体数,这是一种可以分摊计算的次线性规模;尤其适用于密集的城市环境。

 

知识蒸馏 是计算机视觉和自然语言处理(NLP)等领域中一种流行且有效的机器学习技术,用于将知识从一个大模型即“教师模型”转移到一个小模型即“学生模型”。最初为分类任务提出的 知识迁移 (knowledge transfer)机制,用教师模型的预测(“软标签”)取代了训练数据的真值(“硬标签”)。直觉是,与原始数据相比,这些软标签包含了一个信息更丰富的平滑目标空间,学生模型可以从中学习。

 

目前蒸馏概念已超出分类范围,扩展到序列预测任务,如神经机器翻译。然而,蒸馏从未应用于行为预测/运动预测领域。虽然行为预测可以被视为一个序列问题,但关键的区别在于,预测的未来分布期望能够准确地覆盖整个结果空间,而典型的NLP任务旨在生成单个真实输出。因此,在教师和学生之间迁移运动预测的知识是一个开放的问题。

 

此外,对于运动预测,如果未来被表示为包含意图模式的轨迹分布,轨迹和模式多样性至关重要。那幺,一个关键的挑战是蒸馏可能有害于多样性;这在NLP领域中已经进行了研究。

 

现在定义预测问题如下。设x为场景中所有智体的观察值(过去轨迹)和其他上下文信息(如车道语义和红绿灯状态),t为离散时间步长,st为智体在时间t的状态。未来轨迹s=[s1,…,st]是智体在时间t之前的状态序列。假设模型预测K个轨迹,其中,每条轨迹是一系列预测状态sk=[sk1,…,skT]。

 

对于以智体为中心和以场景为中心的方法,都考虑了一类模型,其输出是预测轨迹周围的高斯分布:

 

建立预测轨迹的概率分布模型,可以解释为每个预测轨迹的“置信度”。将其和上面的分布结合,得到混合高斯模型如下:

 

可以简化假设,即给定世界状态的历史,时间步长是条件独立的,从而用有效的前馈神经网络。K的典型数字为K=10输出轨迹。

 

用智体为中心的坐标系模型( ACM )作为教师。以智体为中心的模型从每个智体的角度对世界进行编码、处理和推理。这种表示需要将所有场景信息从全局坐标系转换为智体坐标系。因此,用智体为中心的方法,推理时间和内存需求会随着智体数量的增加而增加。

 

ACM架构用以下四种类型的输入:道路图和红绿灯信息、运动历史(即智体状态历史)和智体交互。对于道路图信息,ACM使用多段折线通过MLP(多层感知器)对3D高清地图中的道路元素进行编码。对于红绿灯信息,ACM用单独的LSTM作为编码器。对于运动历史,ACM使用LSTM来处理一系列过去的观测,隐状态的最后一次迭代用作历史嵌入。

 

对于智体交互,用LSTM在以智体为中心的框架对附近车辆的运动历史进行编码,并通过最大池化聚合所有附近车辆的信息,实现单个交互。这是一种全连接的相邻车辆交互建模的简单形式;比如GNNs、和/或注意机制或最大池化。最后,将这四种编码串联在一起,以便智体为中心的坐标系中为每个智体创建嵌入。基于MLP的解码器将最终嵌入转换为GMM。

 

对于学生,用以场景为中心的坐标系模型( SCM )。在SCM体系结构中,输入数据在一个全局坐标系中表示,在所有智体之间共享。如上所述,该方法的优点之一是可以将场景作为一个整体进行处理,从而产生的有效推理对智体数不变。

 

SCM使用三种类型的输入。将道路信息表示为用语义属性增强的点,智体信息表示是从每个智体的朝向框中采样的点,交通灯信息也表示为用语义属性增强的点。SCM使用PointPillars编码器和2D卷积主干对所有这些输入点进行编码。最终的逐智体嵌入提取,是从特征地图中裁剪一个patch,然后映射到场景中智体的当前位置。注意,即使最终实现了逐智体嵌入,所有的上游处理都是针对整个场景一次完成。与ACM一样,基于MLP的解码器将最终的逐智体嵌入转换为GMM。

 

如图提供了SCM和ACM之间的推理速度比较:随着场景中智体数的增加,推理速度差异逐渐增大,这表明ACM的扩展性不好。

 

尽管SCM的推理速度很快,但观察到,总体表现不如ACM。在公共排行榜上看到这一趋势,其中以智体为中心的模式往往占主导地位。为了做到两全其美(快速的推理速度+良好的预测精度),用从执行较慢但预测准确的老师(ACM)那里蒸馏知识来提高执行较快但预测不太准确的学生(SCM)。

 

对于ACM和SCM,都进行训练,最大化记录的驾驶轨迹似然:

 

损失函数中的第一项拟合每个第k条预测轨迹的可能性(使最接近真实预测轨迹成为最可能轨迹),第二项只是标准的GMM似然拟合做时间序列扩展。这样训练网络的最大优势是,避免了执行EM过程的需要,避免了直接拟合GMM似然的困难性。

 

如图所示是知识蒸馏的框架:左侧,这个教师和以智体为中心的模型被重复且独立地应用于场景中的每个智体,所有模型输入和输出都在每个智体自车为中心的坐标系中表示。右侧,这个学生和以场景为中心的模型一次应用于整个场景,而不需要每个智体重复计算。虽然速度更快,但以场景为中心的公式往往不太准确,因为必须理解和建模在智体为中心的方法中存在的智体不变性。为了利用以场景为中心的方法的计算效率,并从以智体为中心方法的准确性中获益,提出了一种知识蒸馏方法,用以智体为中心的或一个教师模型的预测轨迹来训练以场景为中心的或一个学生模型。

轨迹集蒸馏

在这种蒸馏方法中,训练学生模型以匹配老师模型的完整轨迹集输出。回想一下,模型的完整输出表示是一个GMM;忽略协方差,取每个分量的模态,就得到轨迹集。该轨迹集的权重由神经网络输出π给出。

 

蒸馏损失分为两部分:第一部分,用教师的预测轨迹(全部K条)作为多条伪真实轨迹来训练学生;这里希望第k个教师轨迹在学生学习的对应第k个模态分布是最大似然的;第二部分,施加交叉熵损失,鼓励学生的轨迹模态分布与教师的模态分布相匹配。

 

最后的损失函数是:

轨迹样本蒸馏

作为用多个轨迹作为伪真值的替代方法,从教师的分布中采样单个轨迹,作为学生分布的真值。这个称为 智体真值标签 。

 

然后,直接在此智体真值(而不是真值的标签)上进行优化。从数学上来说表示如下

 

根据预期,在无限样本上,此损失等于要求学生的加权轨迹集与教师的轨迹集相匹配。虽然非常简单,但与Lbase(θ)相同,不鼓励教师和学生的完全GMM分布相匹配。

 

最后一个损失的定义直接鼓励学生的全部GMM输出与教师的GMM匹配。与轨迹集提取一样,强制教师和学生之间的第k条轨迹(对于所有k)对应,以避免学生解空间中的排列模糊性。为匹配分布,用学生和教师的离散模态分布(π和∏)之间的交叉熵损失,以及轨迹序列中每个高斯分布(学生为Nt,教师为Nt)的KL发散度:

 

实验结果如下:WOMD = Waymo Open Motion Dataset

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注