Press "Enter" to skip to content

深度学习文本分类模型&代码&技巧(TextCNN/DPCNN/TextRCNN/HAN/BERT)

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

文本分类是NLP的必备入门任务,在搜索、推荐、对话等场景中随处可见,并有情感分析、新闻分类、标签分类等成熟的研究分支和数据集。

 

本文主要介绍深度学习文本分类的常用模型原理、优缺点以及技巧,是「NLP入门指南」的其中一章,之后会不断完善,欢迎提意见。

 

P.S. 文末附NLP学习路线资料、学习群进入的方式~

 

Fasttext

 

论文:https://arxiv.org/abs/1607.01759
代码:https://github.com/facebookresearch/fastText

 

Fasttext是Facebook推出的一个便捷的工具,包含文本分类和词向量训练两个功能。

 

Fasttext的分类实现很简单:把输入转化为词向量,取平均,再经过线性分类器得到类别。输入的词向量可以是预先训练好的,也可以随机初始化,跟着分类任务一起训练。

Fasttext直到现在还被不少人使用,主要有以下优点:

 

 

    1. 模型本身复杂度低,但效果不错,能快速产生任务的baseline

 

    1. Facebook使用C++进行实现,进一步提升了计算效率

 

    1. 采用了char-level的n-gram作为附加特征,比如paper的trigram是 [pap, ape, per],在将输入paper转为向量的同时也会把trigram转为向量一起参与计算。这样一方面解决了长尾词的OOV (out-of-vocabulary)问题,一方面利用n-gram特征提升了表现

 

    1. 当类别过多时,支持采用hierarchical softmax进行分类,提升效率

 

 

对于文本长且对速度要求高的场景,Fasttext是baseline首选。同时用它在无监督语料上训练词向量,进行文本表示也不错。不过想继续提升效果还需要更复杂的模型。

 

TextCNN

 

论文:https://arxiv.org/abs/1408.5882
代码:https://github.com/yoonkim/CNN_sentence

 

TextCNN是Yoon Kim小哥在2014年提出的模型,开创了用CNN编码n-gram特征的先河。

模型结构如图,图像中的卷积都是二维的,而TextCNN则使用「一维卷积」,即 filter_size * embedding_dim ,有一个维度和embedding相等。这样filter_size就能抽取n-gram的信息。以1个样本为例,整体的前向逻辑是:

 

[seq_length, embedding_dim]
seq_length-filter_size+1
1x1

 

在TextCNN的实践中,有很多地方可以优化(参考这篇论文 [1] ):

 

 

    1. Filter尺寸:这个参数决定了抽取n-gram特征的长度,这个参数主要跟数据有关,平均长度在50以内的话,用10以下就可以了,否则可以长一些。在调参时可以先用一个尺寸grid search,找到一个最优尺寸,然后尝试最优尺寸和附近尺寸的组合

 

    1. Filter个数:这个参数会影响最终特征的维度,维度太大的话训练速度就会变慢。这里在100-600之间调参即可

 

    1. CNN的激活函数:可以尝试Identity、ReLU、tanh

 

    1. 正则化:指对CNN参数的正则化,可以使用dropout或L2,但能起的作用很小,可以试下小的dropout率(<0.5),L2限制大一点

 

    1. Pooling方法:根据情况选择mean、max、k-max pooling,大部分时候max表现就很好,因为分类任务对细粒度语义的要求不高,只抓住最大特征就好了

 

    1. Embedding表:中文可以选择char或word级别的输入,也可以两种都用,会提升些效果。如果训练数据充足(10w+),也可以从头训练

 

    1. 蒸馏BERT的logits,利用领域内无监督数据

 

    1. 加深全连接:原论文只使用了一层全连接,而加到3、4层左右效果会更好 [2]

 

 

TextCNN是很适合中短文本场景的强baseline,但不太适合长文本,因为卷积核尺寸通常不会设很大,无法捕获长距离特征。同时max-pooling也存在局限,会丢掉一些有用特征。另外再仔细想的话,TextCNN和传统的n-gram词袋模型本质是一样的,它的好效果很大部分来自于词向量的引入 [3] ,解决了词袋模型的稀疏性问题。

 

DPCNN

 

论文:https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf
代码:https://github.com/649453932/Chinese-Text-Classification-Pytorch

 

上面介绍TextCNN有太浅和长距离依赖的问题,那直接多怼几层CNN是否可以呢?感兴趣的同学可以试试,就会发现事情没想象的那幺简单。直到2017年,腾讯才提出了把TextCNN做到更深的DPCNN模型:

上图中的ShallowCNN指TextCNN。DPCNN的核心改进如下:

 

 

    1. 在Region embedding时不采用CNN那样加权卷积的做法,而是 对n个词进行pooling后再加个1×1的卷积 ,因为实验下来效果差不多,且作者认为前者的表示能力更强,容易过拟合

 

    1. 使用1/2池化层,用size=3 stride=2的卷积核,直接 让模型可编码的sequence长度翻倍 (自己在纸上画一下就get啦)

 

    1. 残差链接,参考ResNet,减缓梯度弥散问题

 

 

凭借以上一些精妙的改进,DPCNN相比TextCNN有1-2个百分点的提升。

 

TextRCNN

 

论文:https://dl.acm.org/doi/10.5555/2886521.2886636
代码:https://github.com/649453932/Chinese-Text-Classification-Pytorch

 

除了DPCNN那样增加感受野的方式,RNN也可以缓解长距离依赖的问题。下面介绍一篇经典TextRCNN。

模型的前向过程是:

 

 

    1. 得到单词 i 的表示

    1. 通过RNN得到左右双向的表示

    1. 将表示拼接得到

    1. ,再经过变换得到

    1. 对多个

    1. 进行 max-pooling,得到句子表示 ,在做最终的分类

 

 

这里的convolutional是指max-pooling。通过加入RNN,比纯CNN提升了1-2个百分点。

 

TextBiLSTM+Attention

 

论文:https://www.aclweb.org/anthology/P16-2034.pdf
代码:https://github.com/649453932/Chinese-Text-Classification-Pytorch

 

从前面介绍的几种方法,可以自然地得到文本分类的框架,就是 先基于上下文对token编码,然后pooling出句子表示再分类 。在最终池化时,max-pooling通常表现更好,因为文本分类经常是主题上的分类,从句子中一两个主要的词就可以得到结论,其他大多是噪声,对分类没有意义。而到更细粒度的分析时,max-pooling可能又把有用的特征去掉了,这时便可以用attention进行句子表示的融合:

BiLSTM就不解释了,要注意的是,计算attention score时会先进行变换:

 

 

其中 是context vector,随机初始化并随着训练更新。最后得到句子表示 ,再进行分类。

 

这个加attention的套路用到CNN编码器之后代替pooling也是可以的,从实验结果来看attention的加入可以提高2个点。如果是情感分析这种由句子整体决定分类结果的任务首选RNN。

 

HAN

 

论文:https://www.aclweb.org/anthology/N16-1174.pdf
代码:https://github.com/richliao/textClassifier

 

上文都是句子级别的分类,虽然用到长文本、篇章级也是可以的,但速度精度都会下降,于是有研究者提出了层次注意力分类框架,即Hierarchical Attention。先对每个句子用 BiGRU+Att 编码得到句向量,再对句向量用 BiGRU+Att 得到doc级别的表示进行分类:

方法很符合直觉,不过实验结果来看比起avg、max池化只高了不到1个点(狗头,真要是很大的doc分类,好好清洗下,fasttext其实也能顶的(捂脸。

 

BERT

 

BERT的原理代码就不用放了叭~

 

BERT分类的优化可以尝试:

 

 

    1. 多试试不同的预训练模型,比如RoBERT、WWM、ALBERT

 

    1. 除了 [CLS] 外还可以用 avg、max 池化做句表示,甚至可以把不同层组合起来

 

    1. 在领域数据上增量预训练

 

    1. 集成蒸馏,训多个大模型集成起来后蒸馏到一个上

 

    1. 先用多任务训,再迁移到自己的任务

 

 

其他模型

 

除了上述常用模型之外,还有Capsule Network [4] 、TextGCN [5] 等红极一时的模型,因为涉及的背景知识较多,本文就暂不介绍了(嘻嘻)。

 

虽然实际的落地应用中比较少见,但在机器学习比赛中还是可以用的。Capsule Network被证明在多标签迁移的任务上性能远超CNN和LSTM [6] ,但这方面的研究在18年以后就很少了。TextGCN则可以学到更多的global信息,用在半监督场景中,但碰到较长的需要序列信息的文本表现就会差些 [7] 。

 

技巧

 

模型说得差不多了,下面介绍一些自己的数据处理血泪经验,如有不同意见欢迎讨论~

 

数据集构建

 

首先是 标签体系的构建 ,拿到任务时自己先试标一两百条,看有多少是难确定(思考1s以上)的,如果占比太多,那这个任务的定义就有问题。可能是标签体系不清晰,或者是要分的类目太难了,这时候就要找项目owner去反馈而不是继续往下做。

 

其次是 训练评估集的构建 ,可以构建两个评估集,一个是贴合真实数据分布的线上评估集,反映线上效果,另一个是用规则去重后均匀采样的随机评估集,反映模型的真实能力。训练集则尽可能和评估集分布一致,有时候我们会去相近的领域拿现成的有标注训练数据,这时就要注意调整分布,比如句子长度、标点、干净程度等,尽可能做到自己分不出这个句子是本任务的还是从别人那里借来的。

 

最后是 数据清洗 :

 

 

    1. 去掉文本强pattern:比如做新闻主题分类,一些爬下来的数据中带有的XX报道、XX编辑高频字段就没有用,可以对语料的片段或词进行统计,把很高频的无用元素去掉。还有一些会明显影响模型的判断,比如之前我在判断句子是否为无意义的闲聊时,发现加个句号就会让样本由正转负,因为训练预料中的闲聊很少带句号(跟大家的打字习惯有关),于是去掉这个pattern就好了不少

 

    1. 纠正标注错误:这个我真的屡试不爽,生生把自己从一个算法变成了标注人员。简单的说就是把训练集和评估集拼起来,用该数据集训练模型两三个epoch(防止过拟合),再去预测这个数据集,把模型判错的拿出来按 abs(label-prob) 排序,少的话就自己看,多的话就反馈给标注人员,把数据质量搞上去了提升好几个点都是可能的

 

 

长文本

 

任务简单的话(比如新闻分类),直接用fasttext就可以达到不错的效果。

 

想要用BERT的话,最简单的方法是粗暴截断,比如只取句首+句尾、句首+tfidf筛几个词出来;或者每句都预测,最后对结果综合。

 

另外还有一些魔改的模型可以尝试,比如XLNet、Reformer、Longformer。

 

如果是离线任务且来得及的话还是建议跑全部,让我们相信模型的编码能力。

 

少样本

 

自从用了BERT之后,很少受到数据不均衡或者过少的困扰,先无脑训一版。

 

如果样本在几百条,可以先把分类问题转化成匹配问题,或者用这种思想再去标一些高置信度的数据,或者用自监督、半监督的方法。

 

鲁棒性

 

在实际的应用中,鲁棒性是个很重要的问题,否则在面对badcase时会很尴尬,怎幺明明那样就分对了,加一个字就错了呢?

 

这里可以直接使用一些粗暴的数据增强,加停用词加标点、删词、同义词替换等,如果效果下降就把增强后的训练数据洗一下。

 

当然也可以用对抗学习、对比学习这样的高阶技巧来提升,一般可以提1个点左右,但不一定能避免上面那种尴尬的情况。

 

总结

 

文本分类是工业界最常用的任务,同时也是大多数NLPer入门做的第一个任务,我当年就是啥都不会,从训练到部署地实践了文本分类后就顺畅了。上文给出了不少模型,但实际任务中常用的也就那几个,下面是快速选型的建议:

实际上,落地时主要还是和数据的博弈。数据决定模型的上限,大多数人工标注的准确率达到95%以上就很好了,而文本分类通常会对准确率的要求更高一些,与其苦苦调参想fancy的结构,不如好好看看badcase,做一些数据增强提升模型鲁棒性更实用。

 

Be First to Comment

发表评论

邮箱地址不会被公开。 必填项已用*标注