Press "Enter" to skip to content

DeepMind提出关系RNN:构建关系推理模块,强化学习利器

基于记忆的神经网络通过利用长时间记忆信息的能力来建模时序数据。然而,目前还不清楚它们是否有能力利用它们记得的信息进行复杂的关系推理。

在这篇论文中,DeepMind和伦敦大学学院的研究人员首先证实一种直觉想法,即标准的记忆架构在一些涉及关系推理的任务上很困难。然后,研究者通过使用一个新的记忆模块——Relational Memory Core(RMC)——来改进这种缺陷,该模块采用multi-head dot product attention来允许记忆交互。

最后,研究者在一系列任务上测试RMC,这些任务可以从跨序列信息的更强大的关系推理中获益,并且在RL领域(例如Mini PacMan)、程序评估和语言建模中显示出巨大的受益,在WikiText-103、Project Gutenberg和GigaWord数据集上获得state-of-the-art的结果。

cb7beb3b7d2390118f234dedca608480181e7c18

关系记忆核心RMC

人类使用复杂的记忆系统来访问和推理重要的信息,不管这些信息最初是什么时候被感知到的。在神经网络研究中,许多成功的序列数据建模方法也使用了记忆系统(memory systems),例如LSTM和记忆增强的神经网络(memory-augmented neural networks)。通过增强记忆容量、随时间的有限计算成本以及处理梯度消失的能力,这些网络学会了跨时间关联事件,以便熟练地存储和检索信息。

在这里,我们建议在考虑存储和检索的同时考虑记忆交互,这是卓有成效的。虽然目前的模型可以学习划分和关联分布式的、矢量化的记忆,但它们并不明显地倾向于这样做。我们假设,这样的偏见可以让一个模型更好地理解记忆是如何关联的,因此可以让它更好地进行关系推理。

首先,我们通过开发一个演示任务来强调顺序信息的关系推理,证明当前的模型在这个领域中存在困难。使用新的关系记忆核心( Relational Memory Core,RMC),利用multi-head dot product attention让记忆彼此交互,我们解决并分析了这个问题。然后,我们将RMC应用到一系列任务中,这些任务可能会从更显式的memory-memory 交互中获益,因此,可能会增加随时间推移的的关系推理能力:在Wikitext-103、Project Gutenberg和GigaWord数据集中,部分观察到的强化学习任务、程序评估和语言建模。

关系推理(Relational reasoning)

我们认为关系推理是理解实体连接的方式的过程,并利用这种理解来实现更高阶的目标。例如,考虑对各种树与公园长椅之间的距离进行排序:将实体(树和长椅)之间的关系(距离)进行比较,以得到解决方案;如果我们单独考虑每个实体的属性(位置),则无法得到解决方案。

由于我们通常可以很流畅地定义什么构成“实体”(entity)或“关系”(relation),因此我们可以想象一系列的神经网络诱导的偏见,可以用关系推理的语言表达出来。例如,可以用卷积核来计算一个感受野内的实体(像素)的关系(线性组合)。

在时域(temporal domain)中,关系推理可以包含在不同时间点比较和对比信息的能力。这里,注意力机制隐式地执行某种形式的关系推理;如果先前的隐藏状态被解释为entity,那么使用注意力来计算实体的加权和有助于消除RNN中存在的局部性偏差。

由于我们当前的架构解决复杂的时序任务,因此它们必须具备一些时间关系推理的能力。然而,目前还不清楚他们的归纳偏差是否受到限制,以及这些限制是否可以暴露在要求特定类型的时间关系推理的任务中。

模型

我们的指导设计原则是提供一个架构的主干,在这个基础上,模型可以学习如何划分信息,以及如何计算划分的信息之间的交互。为了实现这一点,我们从LSTM、 memory-augmented神经网络和non-local网络(特别是Transformer seq2seq模型)组装构建块。与记忆增强架构相似,我们考虑一组固定的memory slots;但是,我们允许使用注意里机制在memory slots之间进行交互。与之前的工作相反,我们在单个时间步上在记忆之间应用注意力,而不是跨过在先前的观察中计算出来的所有先前的表征。

18c37843c986ec73cf2cf108102cca7e762dc71d

我们在一组监督学习和强化学习任务中测试RMC。值得注意的是Nᵗʰ Farthest的任务和语言建模。在前者中,解决方案需要显式的关系推理,因为模型必须对向量之间的距离关系进行排序,而不是对向量本身排序。后者在大量自然数据上测试模型,并允许我们将性能与经过良好调优的模型进行比较。

实验

这里简要介绍应用RMC的实验任务,具体每个任务的详细信息以及模型的超参数设置等请阅读原论文。

说明性监督任务

Nᵗʰ Farthest

第N个最远的任务是为了强调跨时间的关系推理能力。输入是随机抽样的向量序列,目标是对形式问题的回答:“距离向量m的第n个最远的向量是什么?”,其中向量的值、它们的ID、n和m都是每个序列随机抽样的。我们强调模型必须对向量之间的距离关系进行排序,而不是对向量本身。

程序评估

Learning to Execute(LTE)数据集由图灵完整的伪代码编程语言中的算法片段组成,可分为三类:添加、控制和完整程序。输入是表示这些代码片段的字母数字词汇表上的字符序列,目标是一个数字字符序列,它是给定编程输入的执行输出。考虑到这些片断涉及变量的符号操作,我们认为它可能会影响模型的关系推理能力;由于符号运算符可以被解释为在操作数上定义一个关系,成功的学习可以反映对这个关系的理解。为了评估经典序列任务的模型性能,我们还对记忆任务进行了评估,在这些任务中,输出只是输入的一种排列形式,而不是来自一组操作指令的评估。

强化学习

Mini Pacman with viewport

我们遵循文献[23]中的Mini Pacman的表述。简而言之, agent在被ghosts追赶时在迷宫中导航以收集食物。我们用一个视图(viewport)来实现这个任务:围绕agent的5×5窗口,包含感知输入。因此,任务是部分可观察的。agent必须预测记忆中ghosts的动态,并据此计划导航,同时也要根据被拾取的食物的记忆信息。 该任务要求在记忆空间中进行关系推理。

语言建模

最后,我们调查了基于词汇的语言建模任务。

结果

2e1c14b3068bd9251be1735e8fd013110e6bab61

图3:模型分析

每行描述了特定序列的每个时间步的注意力矩阵。下面的文本阐明了序列的特定任务,该序列被编码并作为输入提供给模型。我们用红色标记任务中引用的矢量。

409922d568bbcd83d8742de8f1847b49cdee0506

表1:测试程序评估和记忆任务的每个字符的准确性。

8f64406ae7f02bb0e7ac7ae286123b5b79bca19a

表2:WikiText-103、Project Gutenberg和GigaWord v5数据集上的验证和测试困惑度

总的来说,我们的结果显示,记忆交互的显式建模还提高强化学习任务,以及程序评估、比较推理和语言建模的性能,这表明在递归神经网络中加入关系推理能力的价值。

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注