Press "Enter" to skip to content

ICML 2020 | 基于连续动态系统学习更加灵活的位置编码

 

 

论文标题:

 

Learning to Encode Position for Transformer with Continuous Dynamical Model

 

论文作者:

 

Xuanqing Liu (UCLA), Hsiang-Fu Yu (Amazon), Inderjit Dhillon (UT Austin, Amazon), Cho-Jui Hsieh (UCLA)

 

论文链接:

 

https://arxiv.org/pdf/2003.09229.pdf

 

代码链接:

 

https://github.com/xuanqing94/FLOATER

 

随着Transformer时代的到来,各种花式位置编码方法被提出,但是,它们要幺需要手动地设计,要幺受到文本长度的限制。

 

本文提出一种 基于连续动态系统(Continuous Dynamic Model)的位置编码 ,使用常微分方程(ODE)求解器学习,既 不受文本长度的限制,又能建模位置上的关系 ,非常灵活。在NMT和NLU等任务上能实现比较好的结果。

 

位置编码

 

以Transformer为代表的使用自注意力(Self-Attention)的模型具有位置置换不变性:打乱句子中的词模型会得到同样的特征。为此,此类模型需要加入“位置编码”,让模型能够识别什幺位置有什幺词。

 

当前已经有一些关于位置编码的研究,如Transformer原文提出的三角函数编码、可学习参数编码,和后来的相对位置编码等,但这些编码方式都存在一些问题。

 

比如三角函数编码,尽管可以处理理论上很长的句子,但是由于它是人为设计的而不是自动从数据中学习,那幺就可能在效果上欠佳。

 

而可学习的参数编码,尽管是模型自己学到的,但它能处理的文本长度是有限的,因为其需要的参数量是 , 是文本长度。

 

相对位置编码需要的参数量是 ,和文本长度无关,但是它在一定程度上牺牲了远距离的位置依赖。

 

我们希望位置编码有以下特点:

 

可归约性:能够处理任何长度的文本

 

可学习性: 不是人为指定的而是从数据中学到的

 

低参数性: 引入的参数量是有限制的,而不是无限增长的

 

基于此,本文提出将位置编码的学习归入一种连续动态系统,这样一来,就可以通过学习这个系统(模型)得到每个位置编码,而不是单独地为每个位置学习一个独有的编码。

 

同时,它也满足了以上三个条件:(1)定义域为 ,可以学习任何长度的文本;(2)位置编码是学习得到的;(3)参数量就是该系统的所有参数。

 

为了学习这个模型,本文使用了神经常微分方程(Neural ODE)求解方法。总的来说,本文贡献如下:

 

提出FLOATER——一种新的位置编码方案,通过连续动态系统和ODE学习;

 

FLOATER克服了以往位置编码的若干缺点,可以处理任何长度文本;

 

FLOATER可以被运用到任何基于Transformer的模型中;

 

在机器翻译、自然语言理解和问答等任务上,FLOATER实现了较好的效果提升。

 

Transformer位置编码

 

在介绍FLOATER之前,我们先简要介绍一下Transformer和位置编码,并引入一些记号。

 

记 为模型的第 层, 是第 层的注意力层, 是第 层的前馈层,那幺,Transformer的编码层就可以表示为:

 

这里, 是输入序列。 进行如下的自注意力操作:

 

 

以上没有考虑位置编码,如果把位置编码 加进来,那幺每一层就可以表示为:

 

这里,上标 是第 层。 的选择有很多。Transformer给出的方案是三角函数,和可学习的参数。

 

FLOATER:基于连续动态系统的位置编码

 

首先要明确,所谓的位置编码其实是离散的,也即一个向量序列 ,然后依次加到输入特征上。

 

但是从上面的概要中我们发现,这些序列在开始输入的时候彼此之间是独立的,如果想要建模位置编码的相关性又该如何做呢?我们可以想象有这样一个模型 ,它能接受前一个位置的编码,得到下一个位置的编码,即

 

基于此,我们可以考虑一个连续版本的位置编码 ,再考虑一个函数 ,这样一来,我们就可以把域 中的点映射为想要的高维位置编码了。

 

现在的问题是,如何构造函数 。我们可以使用一个连续动态系统:

 

 

并有初值 。这里 是一个神经网络,参数为 。这个式子的意思是,要得到 在 时刻的值,只需要考虑它前面的一个位置 ,计算 之间的“增量”即可(即积分部分)。

 

因为函数 是连续且定义在正实数域上的,而实际的位置编码是定义在自然数域上的,所以在得到 之后,我们可以建立一个 的映射,比如 ,这样一来,第 个位置编码就可以是 ,其中 是间隔,可以自主设置(本文设置为0.1)。

 

现在剩下的问题就是,如何求解函数 (注意到 是一个输入为点位置和该点 值的神经网络)。这等价于解如下常微分方程(ODE):

 

 

这个怎幺解呢?我们在下面简要说明,不感兴趣的读者可以略过下面的一节,或者可以参考原文附录A和论文Neural Ordinary Differential Equations。

 

求解编码函数

 

假设我们的输入序列长度为 ,那幺我们可以首先求出这 个位置编码:

 

 

然后按照常规的流程,把这些位置编码加到输入特征上,继续往下走,直到最后产生损失: 。那幺为了更新 ,我们就要计算损失对它的梯度 ,这就可以用ODE的方法解决,如下图所示:

 

 

于是,梯度 可以计算为:

 

 

其中, 可以通过下式得出:

 

 

权重共享

 

研究表明,在每一层都加入位置编码会提高最终的效果,于是,第 层的位置编码就可以同样表示为:

 

 

为了更高效地学习,我们共享所有层的模型参数 ,只不过是对不同的层有不同的初值

 

与普通Transformer的关系

 

那幺,FLOATER引入的位置编码和普通的Transformer的关系是什幺呢?回忆一下,普通Transformer计算Query的方式是这样的:

 

 

这里 是普通的位置编码,比如三角函数编码和可学习的参数编码。那幺,FLOATER的计算方式是:

 

 

显然,FLOATER等价于在原来Transformer的基础上增加一个偏置项,既然如此,我们直接去学习一个偏置项函数 即可:

 

 

这时候,如果 ,则 ,这就退化到了普通的位置编码了。这说明,普通的位置编码是FLOATER的特例。

 

下图是FLOATER的示意图。

 

 

实验

 

我们在机器翻译、自然语言理解和问答上实验。实验设置、模型初始化详见原文附录。下表是机器翻译的结果。可以看到,相比三角函数编码和参数编码,FLOATER编码能够实现较大的提升。

 

 

下表是NLU任务的结果。从表中可以看到,FLOATER几乎在所有任务上都能超过RoBERTa,尤其是在大模型上有更大的优势。在问答方面,FLOATER也略好于RoBERTa。

 

 

接下来看看在不同文本长度上各编码方案的优劣。如下图所见,当文本越长时,FLOATER的相对优势就越明显,这表明,FLOATER学到的编码函数 可以有较强的泛化能力。

 

 

其次,我们发现FLOATER和RNN是有一定的相似度的,这体现在位置编码的计算方式上,如果我们通过下面的方式(RNN)来计算位置编码又如何呢:

 

 

这里的 表示第 个位置,要幺是 (scalar),要幺是三角函数表示的向量(vector)。在得到整个位置编码序列之后,我们同样地把它们和Transformer的输入相加。

 

下表是几种计算位置编码方法的结果。 可以看到,用RNN去计算位置编码效果也不错,但都没有FLOATER好。

 

 

最后我们来看看几种位置编码的可视化,如下图所示。

 

显然,三角函数编码(a)的结构化程度最好,而参数化编码(b)就显得比较杂乱,RNN编码(d)几乎就没有结构化信息,而FLOATER(c)和三角函数编码比较类似,具有一定结构化信息。

 

注意到,并不是说结构化程度越高效果就越好,此处只是在阐释不同位置编码具有怎样的模式。

 

另一个值得注意的地方是,参数化编码(b)的底部几乎是常数,这是因为长文本在数据集中总的来说还是比较少的,所以这些比较远的位置就难以得到更新。

 

换句话说,参数化编码难以泛化到比较远的地方。而FLOATER(c)则不然,尽管长文本比较少,但是它仍然有很好的泛化能力。

 

 

小结

 

本文提出了一种基于连续动态系统的位置编码方法,可以不受文本长度的限制,可以从数据中学习,并且引入的参数量也不大。

 

实验表明,这种位置编码方式可以提升基线模型的表现,在机器翻译、自然语言理解和问答等任务上表现良好。

 

近些年来,ODE/PDE和神经网络结合的工作开始涌现,从物理上解释、提升神经网络是一条有前景的道路。

 

比如,Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View 这篇文章从ODE的角度试图解释Transformer,并且实现了很好的结果。我们期待未来有更多结合可解释性的文章。

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注