Press "Enter" to skip to content

超越bert,最新预训练模型 ELECTRA 论文阅读笔记

最近看到一些ELECTRA模型的消息,感觉这个模型的设计很新颖,于是赶紧找来了原文来看一下,看完趁着现在还有空赶紧记下来。

 

论文的地址 https://openreview.net/pdf?id=r1xMH1BtvB

 

目前论文还在ICLR2020的双盲审阶段,据说,作者为斯坦福SAIL实验室Manning组。

 

文章贡献:

 

文章提出了一种新的文本预训练模型,相比于之前的预训练模型(xlnet,bert等),该模型用更少的计算资源消耗和更少的参数在GLUE上取得了超越xlnet,roberta的成绩。

 

文章细节:

 

文章开始,作者就先比较了ELECTRA模型和其他预训练模型在GLUE任务上的效果和训练资源的消耗。可以看到本文提出的ELECTRA模型在训练步长更少的前提下得到了比其他预训练模型更好的效果。

作者提出的预训练模型如下图所示,包括了一个Generator和一个Discriminator,在预训练模型的训练中,需要训练两个神经网络。一个生成器G,一个分类器D,对于生成网络

G,对于给定的输入 x ,生成网络在t位置按照下述方式生成输出token的xt。其中 e 是token对应的embedding。

 

 

而对于给定位置 t 的token,鉴别器需要去鉴别这个单词是否经过替换。

 

 

生成器模型G会扮演一个MLM模型,对于给定的token输入x=[x1,x2,x3..xn],它会选择一些位置,并且将token替换成[mask],然后生成器会开始学习预测这个masked的元素。并且生成自己的预测结果,并且生成句子。

 

而鉴别器模型D会对于生成句子中的每一个token进行鉴别,判断其是否是经过G替换的token,与GAN网络不同,在这个G,D网络中G使用最大似然的方式训练,而不是训练去用于对抗D网络,由于在文本任务中D的loss不便于传给G,所以D和G的训练是分开的,联合的训练loss如下:

 

 

除了上述的基本模型以外,作者在训练中还增加了下述的方法:

 

embedding层共享:作者认为在G模型生成任务中,G模型会对token的embedding层进行学习,从而得到更好的embedding,而D模型不会更新其embedding层,所以作者将两个模型的embedding层进行了共享。

 

使用小型生成器G:作者认为如果生成器G和鉴别器D如果使用同样的大小,那幺训练的效率会下降,于是作者在仅改变层数大小的情况下做了实验,如下图所示。

实验发现,当G网络的层数在D网络的1/4到1/2的时候效果最好。作者认为太强的生成器会让鉴别器不能高效地学习。

 

训练算法:除了文中上述的两种方法,作者还实验了另外两种训练算法:

 

1、Adversarial Contrastive Estimation:借用强化学习的思想,将被替换token的交叉熵作为生成器的reward,然后进行梯度下降,用GAN的思想进行训练。

 

2、Two-stage training:先训练生成器,然后用生成器的权重初始化判别器,再训练判别器,得到的实验结果如下所示:

实验结果:作者先给出了ELECTRA-Small和ELECTRA-base的实验结果,在等参数级别和训练量的情况下比较了BERT,ELMo,GPT模型的效果。可以看到在等参数,等计算量的情况下,ELECTRA模型有极大的效果提升。

接着作者给出了ELECTRA-Large的实验结果,从实验结果中可以看出ELECTRA-Large模型在等参数量,但是仅使用Roberta1/4计算量的前提下获得了超过Roberta的效果。

之后作者为了探究效率提升的原因,进行了对比实验,比较不同方案的效果,作者对以下训练方案进行了实验:

 

1、 ELECTRA 15%:让D只计算15%token上的损失。

 

2、 Replace MLM:输入不用mask替换,而是直接用其他生成器。

 

3、 All-Tokens MLM: 依然使用替换的方法,但是目标函数变为预测所有的token。

实验结果表明,mask的使用确实会降低一定的性能,尽管BERT在mask token上已经使用了一些trick,预测所有的token效果会好于预测部分的token。

 

最后作者比较了BERT和ELECTRA在不同hidden state size和训练计算量下的表现,结果如下:

实验结果表明,在相同的hidden state size和训练计算量的情况下,ELECTRA模型相比于BERT有很好的性能提升。

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注