Press "Enter" to skip to content

使用 Transformers 在你自己的数据集上训练文本分类模型

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

最近实在是有点忙,没啥时间写博客了。趁着周末水一文,把最近用huggingface transformers
训练文本分类模型时遇到的一个小问题说下。

 

背景

 

之前只闻 transformers 超厉害超好用,但是没有实际用过。之前涉及到 bert 类模型都是直接手写或是在别人的基础上修改。但这次由于某些原因,需要快速训练一个简单的文本分类模型。其实这种场景应该挺多的,例如简单的 POC 或是临时测试某些模型。

 

我的需求很简单:用我们自己的
数据集,快速
训练一个文本分类模型,验证想法。

 

我觉得如此简单的一个需求,应该有模板代码。但实际去搜的时候发现,官方文档什幺时候变得这幺多这幺庞大了?还多了个Trainer
API?瞬间让我想起了 Pytorch Lightning 那个坑人的同名 API
。但可能是时间原因,找了一圈没找到适用于自定义数据集的代码,都是用的官方、预定义的数据集。

 

所以弄完后,我决定简单写一个文章,来说下这原本应该极其容易解决的事情。

 

数据

 

假设我们数据的格式如下:

 

0 第一个句子
1 第二个句子
0 第三个句子

 

即每一行都是label sentence
的格式,中间空格分隔。并且我们已将数据集分成了train.txt
val.txt

 

代码

 

加载数据集

 

首先使用datasets
加载数据集:

 

from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': 'data/train_20w.txt', 'test': 'data/val_2w.txt'})

 

加载后的dataset
是一个DatasetDict
对象:

 

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 3
    })
    test: Dataset({
        features: ['text'],
        num_rows: 3
    })
})

 

类似tf.data
,此后我们需要对其进行map
,对每一个句子进行 tokenize、padding、batch、shuffle:

 

def tokenize_function(examples):
    labels = []
    texts = []
    for example in examples['text']:
        split = example.split(' ', maxsplit=1)
        labels.append(int(split[0]))
        texts.append(split[1])
    tokenized = tokenizer(texts, padding='max_length', truncation=True, max_length=32)
    tokenized['labels'] = labels
    return tokenized
tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)

 

根据数据集格式不同,我们可以在tokenize_function
中随意自定义处理过程,以得到 text 和 labels。注意batch_size
max_length
也是在此处指定。处理完我们便得到了可以输入给模型的训练集和测试集。

 

训练

 

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2, cache_dir='data/pretrained')
training_args = TrainingArguments('ckpts', per_device_train_batch_size=256, num_train_epochs=5)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)
trainer.train()

 

你可以根据情况修改训练 batchsizeper_device_train_batch_size

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注