Press "Enter" to skip to content

【NLP】XLNet详解

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

这次写文本来想把从Transformer-XL到XLNet讲一下,但是自己还没看,先写一下XLNet的部分,Transformer-XL明天补上~

 

1. 背景

 

2018年10月的时候,谷歌放出了称霸GLUE榜单的BERT模型,当时BERT最大的创新就是提出了Masked Language Model作为预训练任务,解决了GPT不能双向编码、ELMo不能深度双向编码的问题。之后从那天起,很多任务都不再需要复杂的网络结构,也不需要大量的标注数据,业界学术界都基于BERT做了很多事情。

 

昨天,也就是2019年6月19日,谷歌又放出了一个模型XLNet,找到并解决了BERT的缺点,刷爆了BERT之前的成绩(当然数据、算力相比去年都增加了很多)。惊醒大家不要总被牵着走,创新总会从某个缝里钻出来。

 

首先介绍两种无监督目标函数:

 

 

    1. AR(autoregressive):自回归,假设序列数据存在线性关系,用

    1. 。以前传统的单向语言模型(ELMo、GPT)都是以AR作为目标。

 

    1. AE(autoencoding):自编码,将输入复制到输出。BERT的MLM就是AE的一种。

 

 

AR是以前常用的方法,但缺点是不能进行双向的编码。因此BERT采用了AE,获取到序列全局的信息。但本文作者指出了BERT采用AE方法带来的两个问题:

 

 

    1. BERT有个不符合真实情况的假设:即被mask掉的token是相互独立的。比如预训练时输入:“自然[Mask][Mask]处理”,目标函数其实是 p(语|自然处理)+p(言|自然处理),而如果使用AR则应该是 p(语|自然)+p(言|自然语)。这样下来BERT得到的概率分布也是基于这个假设的,忽略了这些token之间的联系。

 

    1. BERT在预训练和精调阶段存在差异:因为在预训练阶段大部分输入都包含[Mask],引入了噪声,即使在小部分情况下使用了其他token,但仍与真实数据存在差异。

 

 

以上就是BERT采用AE方法存在的痛点,接下来请看XLNet如何解决这些问题。

 

2.

 

与其说XLNet解决了BERT的问题,不如说它基于AR采用了一种新的方法实现双向编码,因为AR方法不存在上述两个痛点。

 

XLNet的创新点是 Permutation Language Modeling ,如下图:

理论上

对于长度为T的序列 x ,存在T!种排列方法,如果把 重新排列成 ,再采用AR为目标函数,则优化的似然为

 

因为对于不同的排列方式,模型参数是共享的,所以模型最终可以学习到如何聚集所有位置的信息。

操作上

由于计算复杂度的限制,不可能计算所有的序列排列,因此对于每个序列输入只采样一个排列方式。而且在实际训练时,不会打乱序列,而是通过mask矩阵实现permutation。作者特意强调,这样可以保持与finetune输入顺序的一致,不会存在pretrain-finetune差异。

 

2.1 Two-Stream Self-Attention

 

解决了核心问题,接下来就是实现的细节问题了。其实上面打乱顺序后有一个很大的问题,就是在预测第三个x的时候模型预测的是 ,如果把排列方式换成 ,则应该预测 ,但模型是不知道当前要预测的是哪一个,因此输出的值是一样的,即 ,这就不对了。所以说要 加入位置信息 ,即 ,让模型知道目前是预测哪个位置的token。

 

那下一个问题又来了,传统的attention只带有token编码,位置信息都在编码里了,而AR目标是不允许模型看到当前token编码的,因此要把position embedding拆出来。怎幺拆呢?作者就提出了Two-Stream Self-Attention。

 

Query stream:只能看到当前的位置信息,不能看到当前token的编码

 

 

Content stream:传统self-attention,像GPT一样对当前token进行编码

 

 

预训练阶段最终预测只使用query stream,因为content stream已经见过当前token了。在精调阶段使用content stream,又回到了传统的self-attention结构。

 

下面的图起码看3遍~看懂为止,图比我讲的明白。。

另外,因为不像MLM只用预测部分token,还需要计算permutation,XLNet的计算量更大了,因此作者提出了partial prediction进行简化,即只预测后面1/K个token。

 

2.2 Transformer-XL

 

为了学习到更长距离的信息,作者沿用了自己的Transformer-XL。

One Comment

  1. 感谢楼主,小弟这里有几个点不清楚,望楼主开导
    – 在Figure 2(a)中的红线代表着什么?(根据Transformer, https://zhuanlan.zhihu.com/p/48508221 里面解释transfomer的input应该是1个x,然后相乘3个weight)
    – 在Figure 2 (C)中Attention Mask, 里面的红点代表着什么?3 -> 2 -> 4 -> 1 是不是代表应该算出 P(x_1 | x_3, x_2, x_4 )?这个Attention Mask的部分不是很明白,希望楼主能够写点解释

    注释: 楼主好像不小心把Content-stream的方程式和Query Stream的方程式放到一样的了,Content-stream的应该是h(m)之类的

发表评论

邮箱地址不会被公开。 必填项已用*标注