The difference between "multi-class text classification" and "multi-label text classification" is that in the former each sample belongs to exactly one class, while in the latter a sample may carry several labels at once.
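In code, the distinction shows up in the shape of the targets and the choice of loss function. A minimal sketch (the logits here are random placeholders):

import torch

# Multi-class: one class per sample -> integer class ids + CrossEntropyLoss
logits = torch.randn(2, 4)               # (batch_size, num_classes)
mc_targets = torch.tensor([0, 3])        # one class id per sample
mc_loss = torch.nn.CrossEntropyLoss()(logits, mc_targets)

# Multi-label: any number of classes per sample -> 0/1 vectors + BCEWithLogitsLoss
ml_targets = torch.tensor([[1., 0., 1., 0.],
                           [0., 1., 1., 1.]])   # each row may contain several 1s
ml_loss = torch.nn.BCEWithLogitsLoss()(logits, ml_targets)

This post handles the multi-class case, with CrossEntropyLoss as set up further below.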
import pandas as pd
import torch
import transformers
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer

from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
The dataset is the News Aggregator dataset from archive.ics.uci.edu/ml/datasets…
We take the newsCorpora.csv file from it and use only its TITLE and CATEGORY fields.
df = pd.read_csv('./NewsAggregatorDataset/newsCorpora.csv', sep='\t',
                 names=['ID', 'TITLE', 'URL', 'PUBLISHER', 'CATEGORY', 'STORY', 'HOSTNAME', 'TIMESTAMP'])
df = df[["TITLE", "CATEGORY"]]
df.head()
# Map the single-letter category codes to readable labels
my_dict = {
    'e': 'Entertainment',
    'b': 'Business',
    't': 'Science',
    'm': 'Health'
}

df['CATEGORY'] = df['CATEGORY'].apply(lambda x: my_dict[x])
df.head()
# Assign each category an integer id in order of first appearance
encode_dict = {}

def encode_cat(x):
    if x not in encode_dict:
        encode_dict[x] = len(encode_dict)
    return encode_dict[x]

df['ENCODE_CAT'] = df['CATEGORY'].apply(encode_cat)
df.head()
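After this pass, encode_dict maps each category name to an integer id; the exact assignment depends on the order in which categories first appear in the file, for example:

print(encode_dict)  # e.g. {'Business': 0, 'Science': 1, 'Entertainment': 2, 'Health': 3}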
Define the hyperparameters. The tokenizer has to match the model checkpoint: since the network below is built on distilbert-base-uncased, we use the matching uncased tokenizer.
MAX_LEN = 512
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2
EPOCHS = 1
LEARNING_RATE = 1e-05

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
Define the dataset with torch's Dataset class.
Note: DistilBERT does not take a token_type_ids input.
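You can verify this directly (a quick sketch, assuming the tokenizer above has loaded): the DistilBERT tokenizer emits only input_ids and attention_mask by default, which is exactly what DistilBertModel.forward accepts.

enc = tokenizer("US open: Stocks fall after Fed official")
print(enc.keys())  # dict_keys(['input_ids', 'attention_mask']) -- no token_type_ids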
class Triage(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        title = str(self.data.TITLE[index])
        title = " ".join(title.split())  # collapse runs of whitespace
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',  # pad_to_max_length is deprecated
            truncation=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'targets': torch.tensor(self.data.ENCODE_CAT[index], dtype=torch.long)
        }

    def __len__(self):
        return self.len
Split the dataset. To get a result quickly, we only run on 2,000 rows.
train_size = 0.8
train_dataset = df.sample(frac=train_size, random_state=200)
test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

# To keep training fast, truncate to 2000 training rows and 100 test rows
train_dataset = train_dataset[:2000]
test_dataset = test_dataset[:100]

training_set = Triage(train_dataset, tokenizer, MAX_LEN)
testing_set = Triage(test_dataset, tokenizer, MAX_LEN)
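A quick way to confirm the Dataset is wired up correctly is to inspect one item; every sequence should be padded to MAX_LEN:

sample = training_set[0]
print(sample['ids'].shape, sample['mask'].shape, sample['targets'])
# torch.Size([512]) torch.Size([512]) tensor(0)   (the target value depends on the row)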
Build data iterators with torch's DataLoader.
train_params = {'batch_size': TRAIN_BATCH_SIZE, 'shuffle': True, 'num_workers': 0}
test_params = {'batch_size': VALID_BATCH_SIZE, 'shuffle': True, 'num_workers': 0}

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)
Build the network.
From huggingface.co/distilbert-… download the files distilbert-base-uncased needs and put them in a folder named "distilbert-base-uncased".
class DistillBERTClass(torch.nn.Module):
    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 4)

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]   # (batch_size, max_len, 768)
        pooler = hidden_state[:, 0]  # hidden state at the [CLS] position, (batch_size, 768)
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output

model = DistillBERTClass()
model.to(device)
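A dummy forward pass is a cheap sanity check on the wiring (a sketch; the batch size and sequence length here are arbitrary):

dummy_ids = torch.zeros(2, 16, dtype=torch.long).to(device)   # 2 all-[PAD] sequences of length 16
dummy_mask = torch.ones(2, 16, dtype=torch.long).to(device)
print(model(dummy_ids, dummy_mask).shape)  # torch.Size([2, 4]) -- one logit per class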
Set up the loss function and optimizer.
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
def calculate_accu(big_idx, targets):
    # Count predictions that match the targets
    n_correct = (big_idx == targets).sum().item()
    return n_correct
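Note that CrossEntropyLoss takes the raw logits directly (it applies log-softmax internally) together with integer class ids. A toy example with made-up logits:

toy_logits = torch.tensor([[2.0, 0.1, 0.1, 0.1],
                           [0.1, 0.1, 3.0, 0.1]])
toy_targets = torch.tensor([0, 2])
print(loss_function(toy_logits, toy_targets))  # small loss: both argmaxes are correct
_, toy_idx = torch.max(toy_logits, dim=1)
print(calculate_accu(toy_idx, toy_targets))    # 2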
The training loop:
from tqdm import tqdm

def train(epoch):
    model.train()
    for step, data in tqdm(enumerate(training_loader, 0)):
        ids = data['ids'].to(device, dtype=torch.long)
        mask = data['mask'].to(device, dtype=torch.long)
        targets = data['targets'].to(device, dtype=torch.long)

        outputs = model(ids, mask)
        loss = loss_function(outputs, targets)
        if step % 100 == 0:
            print(f'Epoch: {epoch}, Loss: {loss.item()}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

for epoch in range(EPOCHS):
    train(epoch)
Validate the model.
def valid(model, testing_loader):
    model.eval()
    n_correct = 0
    tr_loss = 0
    nb_tr_steps = 0
    nb_tr_examples = 0
    with torch.no_grad():
        for _, data in enumerate(testing_loader, 0):
            ids = data['ids'].to(device, dtype=torch.long)
            mask = data['mask'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.long)

            outputs = model(ids, mask)  # no .squeeze(): it would break on a final batch of size 1
            loss = loss_function(outputs, targets)
            tr_loss += loss.item()
            big_val, big_idx = torch.max(outputs.data, dim=1)
            n_correct += calculate_accu(big_idx, targets)
            nb_tr_steps += 1
            nb_tr_examples += targets.size(0)

    epoch_loss = tr_loss / nb_tr_steps
    epoch_accu = (n_correct * 100) / nb_tr_examples
    print(f"Validation Loss Epoch: {epoch_loss}")
    print(f"Validation Accuracy Epoch: {epoch_accu}")
    return epoch_accu

acc = valid(model, testing_loader)
print("Accuracy on test data = %0.2f%%" % acc)

# Validation Loss Epoch: 1.2351238656044006
# Validation Accuracy Epoch: 58.0
# Accuracy on test data = 58.00%
Save the model and tokenizer.
output_model_file = 'models/pytorch_distilbert_news.bin'
output_vocab_dir = 'models/vocab_distilbert/'

torch.save(model.state_dict(), output_model_file)
tokenizer.save_pretrained(output_vocab_dir)
Load the model and tokenizer.
load_model = DistillBERTClass()
load_model.load_state_dict(torch.load(output_model_file, map_location=device))  # map_location also covers CPU-only machines
load_tokenizer = DistilBertTokenizer.from_pretrained(output_vocab_dir)
Check that the tokenizer loaded correctly.
ori_test_inputs = tokenizer.encode_plus(
    "US open: Stocks fall after Fed official",
    None,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True)

load_test_inputs = load_tokenizer.encode_plus(
    "US open: Stocks fall after Fed official",
    None,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True)

load_test_inputs == ori_test_inputs  # True if the tokenizer round-tripped correctly
Check that the model loaded correctly.
load_model.to(device)
load_model.eval()
model.eval()  # put both models in eval mode so dropout does not make the outputs differ

test_one_ids = torch.tensor([load_test_inputs['input_ids']], dtype=torch.long).to(device)
test_one_mask = torch.tensor([load_test_inputs['attention_mask']], dtype=torch.long).to(device)

res2 = load_model(test_one_ids, test_one_mask)
res1 = model(test_one_ids, test_one_mask)
res2 == res1  # should be all True
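To classify a new headline with the reloaded model, invert the encode_dict built earlier and map the argmax of the logits back to a category name. A minimal sketch (the predict helper is hypothetical, not part of the code above):

# Hypothetical helper for single-headline inference
id2cat = {v: k for k, v in encode_dict.items()}

def predict(text):
    enc = load_tokenizer.encode_plus(
        text, None, add_special_tokens=True,
        max_length=MAX_LEN, padding='max_length', truncation=True)
    ids = torch.tensor([enc['input_ids']], dtype=torch.long).to(device)
    mask = torch.tensor([enc['attention_mask']], dtype=torch.long).to(device)
    with torch.no_grad():
        logits = load_model(ids, mask)
    return id2cat[logits.argmax(dim=1).item()]

print(predict("US open: Stocks fall after Fed official"))  # e.g. 'Business'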