Press "Enter" to skip to content

中文NER任务实验小结报告——深入模型实现细节

作者:邱震宇( 华泰证券股份有限公司 算法工程师)

 

知乎专栏:我的ai之路

 

如标题所言,本文主要内容为最近一段时间内对中文NER任务的一些实践小结。文章内容包括对一些NER模型方法的介绍,以及各个方法在一个公开数据集上的实验比对。虽然这些模型方法或多或少为人所知,但是还是有一些细节部分在实践中才能注意到。我希望通过对这些细节的关注,加强自己模型实现的能力。

 

备注:本实验暂时聚焦于字符级的中文NER识别,因此使用的数据集也是字符级的。所以对于一些需要词级别的ner模型比如lattice-lstm等不会去涉及。不过,本人会介绍一些其他小技巧将词级别的信息融入到字符级的模型中。

 

任务描述

 

基于字符级的中文数据集,做实体识别任务。实体类型主要有三种:地点,机构,人物。数据采用BIO样式标注,样例如下:

 

海 钓 比 赛 地 点 在 厦 门 与 金 门 之 间 的 海 域 。

 

其中,B-LOC代表地点实体的开头,而I-LOC代表地点实体的剩余部分。O代表非实体。

 

训练集样本一共有12711条,验证集样本一共有1400条,测试集样本一共有2804条。数据集是从github上获得的,以字符级标注。

 

本任务评测关注的是span-level的micro f1分数。评测步骤如下:

 

1、对每个样本,根据不同方法得到预测的实体字符串。假设对于测试集来说,得到一个预测实体集合列表, ,其中, 表示一个样本中提取的所有实体的unique集合(不重复)。

 

2、对每个样本,抽取出真实的实体字符串。同样,对于测试集来说,得到一个真实实体集合列表。

 

3、计算mirco precision:,计算micro recall: ,

 

最后得到mirco f1:

 

下面,我就介绍一下针对该数据集任务,所使用的一些模型和技巧,以及其对应的效果比对。方法分为两大类。一类是非Bert类方法,主要包含经典的LSTM+CRF模型以及在其基础上进行一些优化的实践尝试;另一类则是BERT类方法,主要包含使用bert直接做传统的序列标注任务以及将NER问题转化为阅读理解问题做span位置预测任务,另外还对样本类别不均衡问题做了一些优化。最后,模型所有模型使用的超参数基本相同,没有进行精细调参。

 

非BERT类方法

 

本类方法主要基于的是LSTM+CRF架构,关于CRF的内容在上一篇文章中已经介绍,详见CMU NLP课程总结——条件随机场及Mininum Risk Training。简单提一下就是将文本序列经过char-embedding lookup之后,输入到双向LSTM中做序列编码,之后对输出的编码进行全连接转化,目标神经元维度为标签的个数,最后将输出作为CRF层的发射分数矩阵与CRF层进行结合,同时用一些初始化方法初始化转移分数矩阵,这个矩阵和发射分数矩阵都是可学习的,主要通过神经网络的梯度下降来更新参数。具体如图所示:

在实践中,使用了两层的Bi-lstm,因为对于LSTM来说,叠加深度并不能充分发挥深度网络的作用,有时候反而会降低整体的性能。同时使用的字符向量来自于github.com/Embedding/Ch中的字符+词+ngram的版本。最终,这个baseline模型的f1-mirco-avg为:0.870

 

baseline的改进尝试_1

 

首先,尝试将词向量的信息融入到模型中。目前有很多方法尝试将词和字符信息融合,比较好的方法有lattice-lstm以及lstm-cnn等。前者通过将字符信息融入到词级别的lstm网络中,对lstm进行网络结构的改造,而后者则使用卷积网络先对一个词内的字符进行编码然后再将该编码与词向量做融合(拼接或者相加)。我尝试的方法相对比较简单:对于该中文数据集来说,并不存在分词的信息,因此需要使用一些分词工具(jieba)先进行分词,然后在之前使用的char+word+ngram综合版本的词向量中找到该词的词向量,若没找到则初始化一个。之后,将该词向量与该词所属的所有字符向量进行融合,具体融合方法为:假设一个词有n个字符,则将词向量复制n个,然后分别与n个字符向量直接相加。具体如图所示:

这个方法对于分词错误的鲁棒性相对较好,且非常简单,不需要修改网络结构,不会降低模型本身的训练和预测效率。经过实践,经过该方法优化后的baseline模型的f1-micro-avg为:0.878

 

baseline的改进尝试_2

 

在上述改进模型的基础上,又尝试了在lstm编码的基础上增加CNN的编码层,试图加强整体模型的编码能力。具体做法很简单,即在lstm层的输出之后直接连接CNN层,通过CNN的编码输出再输入到后续网络层中。

经过实践,该优化后的baseline模型能够得到的f1-micro-avg分数为:0.882

 

另外,我还尝试了其他一些优化方法,但是均没有起效,这里就简单提一下,仅供参考:

 

1、对于PAD的标签处理,一般有两种处理方式:一是将PAD作为新的独立标签来预测;二是将PAD的位置都作为O标签处理。经过实践,两种方式效果差不多。

 

2、看到西湖大学一篇做序列标注的文章:Hierarchically-Refined Label Attention Network for Sequence Labeling。这篇文章核心在于不使用CRF,而是通过对标签列表信息做attention捕捉,将标签列表信息融入到标签的预测中。想法还蛮新颖的,但是在仔细研读了论文并同时研读了开源代码后,发觉其实现方式有些问题。代码中对标签列表信息计算attention时还是使用了mask操作,但是个人认为这个意义不大,因为标签列表并不存在先后依赖关系。而实践后也发现,该方法并不能达到很好的效果,所以暂时对此方法存疑,如果有同学使用该方法得到不错的效果,欢迎前来探讨。

 

BERT类方法

 

BERT类方法的baseline采用BERT论文中用BERT做序列标注任务的流程,即使用bert预训练模型做finetune。具体来说就是将序列输入到bert层得到一个序列的整个输出,然后用该输出作为序列的encoder表示。其后的操作有很多种,比如直接接全连接层、接lstm-crf层等,下面分别简单描述一下。

 

BERT+CE_loss

 

该方法直接在BERT层之后接全连接层,并使用交叉熵损失函数来进行学习。具体如图所示:

该模型最终的f1-micro-avg分数为:0.933

 

BERT+lstmcrf

 

该方法在BERT层之后,连接bilstm-crf的子结构,最后使用CRF的log_sum_exp分数作为目标函数。具体如图:

该模型的效果还是不错的,但是由于lstm+CRF本身效率就不高,再加上使用了BERT,可想而知其inference的速度确实比较慢,因此需要根据实际情况来选择。最终该模型得到的f1分数为:0.939

 

除了上述两种尝试外,还基于BERT做了一些其他的优化尝试,但是提升效果都不太明显,下面简单描述一下:

 

1、尝试用更少的标签列表。由上图可知,我使用的bert_baseline将[CLS]和[SEP]以及PAD都作为新的独立标签。我尝试了将这三个特殊符号对应的位置标签都视作O标签,相当于减小了标签的搜索空间,减小了模型预测标签的难度。最后的效果确实是有提升,提升了0.02左右,幅度不太大,可能不具备典型的参考价值。

 

2、对损失函数进行了优化尝试。用 BIO标注方式 做NER任务存在 标签类别分布不均衡 的问题。O标签的数据占据了文本大多数字符,而需要预测的非O标签则相对比较少。因此,我尝试了一些魔改loss来缓解这个不均衡问题,如尝试了focal loss,dice loss等。最后实践效果相对于基本的bert+celoss来说有0.03左右的提升,但是对于BERT+lstmcrf来说提升效果不大。

 

备注:bert的finetune需要注意learning rate的设置,这个在很多博客文章里都提到了,由于存在可能的遗忘灾难问题,因此用bert做transfer learning时要注意不要让bert保留的语言信息被遗忘掉,因此learning rate一般要设小一点。有些比较好的方法是对不同层的网络设置不同的learning rate。越靠近任务层的网络可以设置较大的learning rate,而越是上层的通用的网络层learning rate偏向于较小值。

 

上述都是基于传统BIO标注方式来做NER问题,最后一部分我重点总结一下用阅读理解的方式来做NER问题。该实现参考自以下论文:A Unified MRC Framework for Named Entity Recognition。

 

BERT+MRC

 

该方法的重点其实在于构造数据集,模型本身并不复杂。方法的核心思想为对于要抽取的每一类实体,构造关于该类实体的query,然后将需要抽取实体的原始文本作为passage上下文,预测该类实体在passage中的位置(一般是start_index和end_index),具体来说就是对于一个passage序列,预测其每个字符位置是否是该类实体的起始位置或结束位置,可以看出来此时任务转化成了一个多标签二分类问题,对于一个原始样本来说,假设序列长度为n,那幺就是n个二分类问题。设计baseline时,使用交叉熵作为损失函数,最终的损失函数由两部分组成:start_loss和end_loss。

 

相比于传统的序列标注做法,使用MRC方式做NER好处在于引入了query这个先验知识,比如对于LOC类别,我们构造这样的query:找出国家,城市,山川等抽象或具体的地点。模型通过attention机制,对于query中的国家,城市,山川词汇学习到了地点的关注信息,然后反哺到passage中的实体信息捕捉中。

 

当然,原始论文中提到该方法能够很好得解决的实体嵌套问题,因为该方法还设计了一个额外的loss:span_loss。核心思想是对于模型学习到的所有实体的start和end位置,构造首尾实体匹配任务,即判断某个start位置是否与某个end位置匹配为一个实体,是则预测为1,否则预测为0,相当于转化为一个二分类问题,正样本就是真实实体的匹配,负样本是非实体的位置匹配。

 

在实现时,需要对分类前的logits做argmax得到模型预测的start和end位置,然后根据训练样本中真实实体的start和end位置来计算交叉熵loss。但是argmax函数是不可微的,因此个人感觉无法直接在训练中使用,还需要对其做一些软化近似。由于官方没有给出这块的开源代码,且本数据集并不存在实体嵌套问题,因此我暂时没有实现这个loss。

 

构造query

 

假设目前我们要预测的实体类别个数为m,则我们需要构造m个不同的query。每个样本相当于扩充了m倍,得到m个新的样本。关于如何构造query,这个就需要人工来设计了。对于通用的location,person,organization类别,可以使用一些简单的query:

 

"ORG":"找出公司,商业机构,社会组织等组织机构",
"LOC":"找出国家,城市,山川等抽象或具体的地点",
"PER":"找出真实和虚构的人名"

 

构造训练数据

 

用BERT来做阅读理解任务首先需要构造相应格式的训练数据。目前假设我们对每个原始样本,构造了m个query。接下来就是将每个query和样本passage进行拼接,得到m个不同的bert输入数据,如下所示:

 

由上述描述可知经过构造后的一条数据样本是预测一个类别的实体的位置信息。

 

另外,在构造数据的时候有很多细节需要注意:

 

1、 序列长度问题 。由于bert最长只能接收512长度的序列,因此很多情况下都需要截断。在之前的传统BERTbaseline方法中,我就直接进行了截断,并把超过长度的部分直接丢弃。在本方法尝试时,我选择了保留了超出长度的部分,直接将其作为新的样本来处理,并且为了防止实体词被截断,我还使用了一些规则,让同一个实体的字符尽量保留在一个样本中。当然,对于测试数据来说,默认是没有实体标签的,因此就直接按照长度直接截断,保留超出长度序列作为新的样本。

 

2、 对序列进行mask 。由于我们搜索实体的范围仅局限与上下文passage,不包括query。但是bert处理的是query+passage整体序列。在最后计算loss的时候,我们需要将query部分(以及cls,sep、pad等特殊字符)mask掉,使其在计算loss时被忽略。

 

模型结构

 

模型结构反而是比较简单的,可以参考bert论文中的做法,具体如图:

经过实践,该模型对任务评测结果的提升还是很可观的,最终的f1-micro-avg为:0.955。

 

对于该模型,我主要的优化点在于其loss的设计上。对于该模型来说,同样存在样本分布不均衡的问题,即标签为0的数据是占据大多数的,且相对于BIO序列标注任务来说,这个方法的不均衡问题更严重。因此我尝试了focal loss,dice loss等方法,其中dice loss参考自论文:Dice Loss for Data-imbalanced NLP Tasks。但是对于dice loss来说,经过实践并没有得到预期的效果,只能等待官方的开源的了。(吐槽一下最近碰到很多无法复现所述效果的论文,只能说自己在论文鉴别方面还需要加强。)最后使用了focal loss,对任务有一定的提升,大概提升了0.03。

 

实验结果汇总

 

 

小结

 

本文主要就中文NER任务做了一系列的实践小结,包括非BERT类的传统方法以及BERT类的方法,并且关注了一些论文中不会提及的一些实现细节问题,希望本文的一些内容能够对其他业界的同行有一些参考作用。最后贴一下项目的开源地址

 

https://github.com/qiufengyuyi/sequence_tagging

 

qiufengyuyi/sequence_tagging github.com

 

目前该项目的结构还不是很完善,需要进行一些重构。但是bert+mrc的运行还是没有问题的,欢迎拍砖。

 

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注