
Hugging Face Series: Multi-Class Text Classification

The difference between "multi-class text classification" and "multi-label text classification" is that in the former each sample has exactly one class, while in the latter each sample may carry several labels at once.
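
To make the difference concrete, here is a minimal sketch (standalone toy numbers, not part of this tutorial's pipeline): a multi-class target is a single class index trained with CrossEntropyLoss, whereas a multi-label target is a multi-hot vector trained with BCEWithLogitsLoss.

import torch

logits = torch.randn(2, 4)  # 2 samples, 4 candidate classes

# Multi-class: each sample has exactly one class, given as an integer index
multiclass_targets = torch.tensor([1, 3])
ce_loss = torch.nn.CrossEntropyLoss()(logits, multiclass_targets)

# Multi-label: each sample may have several classes, given as a 0/1 vector
multilabel_targets = torch.tensor([[0., 1., 0., 0.],
                                   [1., 0., 0., 1.]])
bce_loss = torch.nn.BCEWithLogitsLoss()(logits, multilabel_targets)
print(ce_loss.item(), bce_loss.item())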

 

import pandas as pd
import torch
import transformers
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer

 

from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'

 

The dataset used is the News Aggregator dataset from archive.ics.uci.edu/ml/datasets…

 

Take the newsCorpora.csv file from it and use only the TITLE and CATEGORY columns.

 

df = pd.read_csv('./NewsAggregatorDataset/newsCorpora.csv', sep='\t', names=['ID','TITLE', 'URL', 'PUBLISHER', 'CATEGORY', 'STORY', 'HOSTNAME', 'TIMESTAMP'])
df = df[["TITLE", "CATEGORY"]]
df.head()

 

# Map the single-letter category codes to readable labels
my_dict = {
    'e':'Entertainment',
    'b':'Business',
    't':'Science',
    'm':'Health'
}
df['CATEGORY'] = df['CATEGORY'].apply(lambda x: my_dict[x])
df.head()

 

# Assign each category an integer id in order of first appearance
encode_dict = {}
def encode_cat(x):
    if x not in encode_dict:
        encode_dict[x] = len(encode_dict)
    return encode_dict[x]
df['ENCODE_CAT'] = df['CATEGORY'].apply(encode_cat)
df.head()

 

Define the hyperparameters. The tokenizer here comes from distilbert-base-cased (note that the model built later loads distilbert-base-uncased weights; ideally the tokenizer and the model should come from the same checkpoint).

 

MAX_LEN = 512
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2
EPOCHS = 1
LEARNING_RATE = 1e-05
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

 

Define the dataset with the torch Dataset class.

 

Note: DistilBERT does not take token type ids as input.
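
As a quick sanity check of that note (a minimal sketch, assuming the distilbert-base-cased checkpoint can be downloaded or is already cached), DistilBertModel's forward only takes input_ids and attention_mask; unlike BERT there is no token_type_ids argument:

from transformers import DistilBertModel, DistilBertTokenizer

check_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
check_model = DistilBertModel.from_pretrained('distilbert-base-cased')
enc = check_tokenizer.encode_plus("US open: Stocks fall after Fed official",
                                  add_special_tokens=True, return_tensors='pt')
# DistilBERT has no segment embeddings, so only these two inputs are passed
out = check_model(input_ids=enc['input_ids'], attention_mask=enc['attention_mask'])
print(out[0].shape)  # torch.Size([1, sequence_length, 768]), the last hidden states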

 

class Triage(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __getitem__(self, index):
        title = str(self.data.TITLE[index])
        title = " ".join(title.split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'targets': torch.tensor(self.data.ENCODE_CAT[index], dtype=torch.long)
        } 
    
    def __len__(self):
        return self.len

 

Split the dataset. To get a result quickly, only 2,000 rows are used for training.

 

train_size = 0.8
train_dataset=df.sample(frac=train_size,random_state=200)
test_dataset=df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))
# To train quickly, keep only 2,000 training rows (and 100 test rows) and run through the whole pipeline
train_dataset = train_dataset[:2000]
test_dataset = test_dataset[:100]
training_set = Triage(train_dataset, tokenizer, MAX_LEN)
testing_set = Triage(test_dataset, tokenizer, MAX_LEN)

 

Build data iterators with the torch DataLoader.

 

train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }
test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }
training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

 

Build the network

 

huggingface.co/distilbert-…
Download the files required for distilbert-base-uncased and put them in a folder named "distilbert-base-uncased".
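
For reference, from_pretrained() also accepts a local directory path, so the downloaded folder can be pointed to directly. A small sketch follows; the "./distilbert-base-uncased" path and its contents (typically config.json, pytorch_model.bin and vocab.txt) are assumptions about the local setup:

from transformers import DistilBertModel

# Hypothetical local folder holding the downloaded checkpoint files
local_bert = DistilBertModel.from_pretrained("./distilbert-base-uncased")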

 

class DistillBERTClass(torch.nn.Module):
    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 4)
    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]  # last hidden states, shape (batch_size, max_len, 768)
        pooler = hidden_state[:, 0]  # hidden state of the first token ([CLS]), shape (batch_size, 768)
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output
model = DistillBERTClass()
model.to(device)

 

Set up the loss function and the optimizer

 

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

 

def calculate_accu(big_idx, targets):
    # Count how many predictions in the batch match the gold labels
    n_correct = (big_idx == targets).sum().item()
    return n_correct

 

Training loop

 

from tqdm import tqdm
def train(epoch):
    model.train()
    for step, data in tqdm(enumerate(training_loader, 0)):
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.long)
        outputs = model(ids, mask)
        loss = loss_function(outputs, targets)
        if step % 100 == 0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return

 

for epoch in range(EPOCHS):
    train(epoch)

 

Validate the model

 

def valid(model, testing_loader):
    model.eval()
    n_correct = 0
    tr_loss = 0
    nb_tr_steps = 0
    nb_tr_examples = 0
    with torch.no_grad():
        for _, data in enumerate(testing_loader, 0):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.long)
            outputs = model(ids, mask)
            loss = loss_function(outputs, targets)
            tr_loss += loss.item()
            big_val, big_idx = torch.max(outputs.data, dim=1)
            n_correct += calculate_accu(big_idx, targets)
            nb_tr_steps += 1
            nb_tr_examples+=targets.size(0)
    epoch_loss = tr_loss/nb_tr_steps
    epoch_accu = (n_correct*100)/nb_tr_examples
    print(f"Validation Loss Epoch: {epoch_loss}")
    print(f"Validation Accuracy Epoch: {epoch_accu}")
    
    return epoch_accu
acc = valid(model, testing_loader)
print("Accuracy on test data = %0.2f%%" % acc)
# Validation Loss Epoch: 1.2351238656044006
# Validation Accuracy Epoch: 58.0
# Accuracy on test data = 58.00%

 

Save the model and the tokenizer

 

import os

output_model_file = 'models/pytorch_distilbert_news.bin'
output_vocab_dir = 'models/vocab_distilbert/'
os.makedirs(output_vocab_dir, exist_ok=True)  # make sure the target folders exist before saving
torch.save(model.state_dict(), output_model_file)
tokenizer.save_pretrained(output_vocab_dir)

 

Load the model and the tokenizer

 

load_model = DistillBERTClass()
load_model.load_state_dict(torch.load(output_model_file, map_location=device))
load_model.to(device)

 

load_tokenizer = DistilBertTokenizer.from_pretrained(output_vocab_dir)

 

Check that the tokenizer loaded correctly

 

ori_test_inputs = tokenizer.encode_plus(
    "US open: Stocks fall after Fed official",
    None,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    return_token_type_ids=True,
    truncation=True)
load_test_inputs = load_tokenizer.encode_plus(
    "US open: Stocks fall after Fed official",
    None,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    return_token_type_ids=True,
    truncation=True)
load_test_inputs == ori_test_inputs

 

Check that the model loaded correctly

 

load_model.eval()
test_one_ids = torch.tensor([load_test_inputs['input_ids']], dtype=torch.long).to(device)
test_one_mask = torch.tensor([load_test_inputs['attention_mask']], dtype=torch.long).to(device)
res2 = load_model(test_one_ids, test_one_mask)
res1 = model(test_one_ids, test_one_mask)
res2 == res1  # identical weights and inputs should give identical logits
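
Finally, a small usage sketch (not part of the original flow; it assumes the encode_dict built earlier is still in scope): map the predicted class index back to a readable category name.

# Invert the label-to-id mapping that was built when ENCODE_CAT was created
id_to_label = {v: k for k, v in encode_dict.items()}
with torch.no_grad():
    logits = load_model(test_one_ids, test_one_mask)
pred_id = torch.argmax(logits, dim=1).item()
print(id_to_label[pred_id])  # e.g. 'Business' for this stock-market headline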
