Press "Enter" to skip to content

[Pytorch系列-60]:循环神经网络 – 中文新闻文本分类详解-2-LSTM网络训练与评估代码详解

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

目录

 

第2章 代码准备 (Jupter)

 

3.2 定义构建数据集API

 

4.2 实例化模型并显示模型结构

 

4.3 初始化模型权重参数

 

5.3 边训练、边评估模型

 

第6章 在测试集上对模型进行评估

 

第1章 预备知识

 

1.1 业务概述

 

[Pytorch系列-59]:循环神经网络 – 中文新闻文本分类详解-1-业务目标分析与总体架构_文火冰糖(王文兵)的博客-CSDN博客 https://blog.csdn.net/HiWangWenBing/article/details/121756744

 

1.2 LSTM网络

 

(1)双向LSTM

 

 

备注:

 

本案例是双向的LSTM,因此隐藏层的输出,是两个方向输出的拼接。

 

因此全连接网络的输入是 2 * 隐藏层特征数。

 

(2) LSTM的层数

 

 

(3)隐藏的输出

 

 

只使用当前的隐层输出送入到全连接网络。

 

第2章 代码准备 (Jupter)

 

2.1 代码与数据集下载

 

2.2 导入库

 

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from sklearn import metrics
import os
import torch
import numpy as np
import pickle as pkl
from tqdm import tqdm
import time
from datetime import timedelta

 

2.3 系统配置

 

(1)系统配置数据结构

 

class Config(object):
    """配置参数"""
    def __init__(self, dataset, embedding):
        self.model_name = 'TextRNN'
        #数据集路径
        self.train_path = dataset + '/data/train.txt'                                # 训练集
        self.dev_path = dataset + '/data/dev.txt'                                    # 验证集
        self.test_path = dataset + '/data/test.txt'                                  # 测试集
        
        #类别文件
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]                                # 类别名单
        
        #单词表:是单词与其索引的对应表
        self.vocab_path = dataset + '/data/vocab.pkl'                                # 词表
        
        # 词向量表: 是索引与向量编码的对应表
        self.embedding_pretrained = torch.tensor(
            np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\
            if embedding != 'random' else None                                       # 预训练词向量
        self.embed = self.embedding_pretrained.size(1)\
            if self.embedding_pretrained is not None else 300           # 字向量维度, 若使用了预训练词向量,则维度统一
    
        # 训练数据保存
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # 模型训练结果
        self.log_path = dataset + '/log/' + self.model_name
        
        # GPU or CPU
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备
        
        # 模型参数
        self.hidden_size = 128                                          # lstm隐藏层
        self.num_layers = 2                                             # lstm层数
        
        # 训练时的参数
        self.dropout = 0.5                                              # 随机失活
        self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练
        self.num_classes = len(self.class_list)                         # 类别数
        self.n_vocab = 0                                                # 词表大小,在运行时赋值
        self.num_epochs = 20                                            # epoch数
        self.batch_size = 64                                           # mini-batch大小
        self.pad_size = 32                                              # 每句话处理成的长度(短填长切)
        self.learning_rate = 1e-3                                       # 学习率

 

(2)实例化配置对象

 

# 数据集目录
dataset = 'THUCNews'  
# 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
embedding = 'embedding_SougouNews.npz'
#通过空格分隔的英文单词还是中文的字符
word = False
#初始化配置实例
config = Config(dataset, embedding)
#显示配置信息
print(config.device)
print(config.embed)
print(config.embedding_pretrained)

 

第3章 构建数据集

 

3.1 构建单词表API

 

MAX_VOCAB_SIZE = 10000
#新闻标题的填充,固定输入长度为32
UNK, PAD = '<UNK>', '<PAD>'
# 单词表不是词向量表,而是单词与其索引对应关系的字典表。
# 从指定单词表中读取词向量表:
# file_path:单词表的路径
# tokenizer:分词器,与英文不同,中文的单词是仅仅相邻的,中间没有空格,因此需要分词器进行分词。
# max_size:单词的最大数量
# min_freq:单词表排序时的参考词频
def build_vocab(file_path, tokenizer, max_size, min_freq):
    # 单词表是一个字典
    vocab_dic = {}
    with open(file_path, 'r', encoding='UTF-8') as f:
        # 通过tqdm从单词表中读取一行单词,tqdm能够显示进度条
        for line in tqdm(f):
            # 移除字符串头尾指定的字符(默认为空格或换行符)或字符序列
            lin = line.strip()
            if not lin:
                #空行
                continue
            #按照空格或table键,把字符转换成短语列表
            content = lin.split('\t')[0]
            
            # 从列表中提取一个个独立的中文单词(即中文字)
            for word in tokenizer(content):
                # 构建单词字典表
                vocab_dic[word] = vocab_dic.get(word, 0) + 1
        
        #对单词表进行排序
        vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
        
        #还原成字典
        vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
        
        #使用UNK填充单词表的尾部
        #  ,'<UNK>': 4760, '<PAD>': 4761}
        vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic

 

3.2 定义构建数据集API

 

def build_dataset(config, ues_word):
    print("构建单词表")
    # 指定分词器
    print("ues_word=",ues_word)
    if ues_word:
        tokenizer = lambda x: x.split(' ')  # 以空格隔开,word-level
    else:
        tokenizer = lambda x: [y for y in x]  # char-level =》适合中文
    
    # load单词表
    if os.path.exists(config.vocab_path):
        # 如果有现成的单词表,则使用已有的单词表(单词与索引的字典)
        print("使用已有的单词表:", config.vocab_path)
        vocab = pkl.load(open(config.vocab_path, 'rb'))
    else:
        # 如果没有现成的单词表,则基于训练集,构建一个新的词表
        print("基于训练集,新构建单词表:", config.train_path)
        vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
        pkl.dump(vocab, open(config.vocab_path, 'wb'))
    
    print(f"Vocab size: {len(vocab)}")
    print("构建数据集")
    # 定义load和转换数据集的函数
    # 固定长度为32。
    def load_dataset(path, pad_size=32):
        contents = []
        print("数据集:", path)
        with open(path, 'r', encoding='UTF-8') as f:
            # 读取一行文件,并显示进度条
            for line in tqdm(f):
                #去掉头尾标识符
                lin = line.strip()
                if not lin:
                    # 跳过空行
                    continue
                
                #通过空格分离单词和标签
                content, label = lin.split('\t')
                
                words_line = []
                token = tokenizer(content)
                seq_len = len(token)
                
                # 根据填充单词,确定有效字符长度:seq_len
                if pad_size:
                    if len(token) < pad_size:
                        token.extend([vocab.get(PAD)] * (pad_size - len(token)))
                    else:
                        token = token[:pad_size]
                        seq_len = pad_size
                
                #构建一个个样本数据
                for word in token:
                    # 从单词表中获取每个单词对应的索引index,并添加到文字样本对应的列表中
                    # words_line:存放当个样本数据(单词的index列表)
                    words_line.append(vocab.get(word, vocab.get(UNK)))
                
                #contents:存放所有样本数据(单词的index列表)
                contents.append((words_line, int(label), seq_len))
        return contents  # [([...], 0), ([...], 1), ...]
    
    # load训练数据集
    train = load_dataset(config.train_path, config.pad_size)
    
    # load 验证数据集
    dev = load_dataset(config.dev_path, config.pad_size)
    
    # load 测试数据集
    test = load_dataset(config.test_path, config.pad_size)
    
    return vocab, train, dev, test

 

3.3 构建三大数据集

 

(1)构建数据集

 

def get_time_dif(start_time):
    """获取已使用时间"""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))
start_time = time.time()
#构建三大数据集
print("Loading data...")
vocab, train_data, dev_data, test_data = build_dataset(config, word)
# 更新词向量的长度
config.n_vocab = len(vocab)
time_dif = get_time_dif(start_time)
print("Time usage:", time_dif)

 

(2)显示单词表

 

print(vocab)

 

{' ': 0, '0': 1, '1': 2, '2': 3, ':': 4, '大': 5, '国': 6, '图': 7, '(': 8, ')': 9, '3': 10, '人': 11, '年': 12, '5': 13, '中': 14, '新': 15, '9': 16, '生': 17, '金': 18, '高': 19, '《': 20, '》': 21, '4': 22, '上': 23, '8': 24, '不': 25, '考': 26, '一': 27, '6': 28, '日': 29, '元': 30, '开': 31, '美': 32, '价': 33, '发': 34, '学': 35, '公': 36, '成': 37, '月': 38, '将': 39, '万': 40, '7': 41, '基': 42, '市': 43, '出': 44, '子': 45, '行': 46, '机': 47, '业': 48, '被': 49, '家': 50, '股': 51, '的': 52, '在': 53, '网': 54, '女': 55, '期': 56, '平': 57, '房': 58, '名': 59, '三': 60, '-': 61, '会': 62, '地': 63, '场': 64, '全': 65, '小': 66, '现': 67, '有': 68, '分': 69, '后': 70, '称': 71, '组': 72, '为': 73, '下': 74, '盘': 75, '最': 76, '“': 7

 

……..

 

737, '恫': 4738, '诣': 4739, '叁': 4740, '氮': 4741, '曳': 4742, '膑': 4743, '峦': 4744, '攫': 4745, '鹄': 4746, '啄': 4747, '憩': 4748, '鞑': 4749, '垠': 4750, '鹕': 4751, '鄞': 4752, '呸': 4753, 'V': 4754, '玷': 4755, '瘁': 4756, '蚱': 4757, '§': 4758, '霎': 4759, '<UNK>': 4760, '<PAD>': 4761}

 

(3)显示训练数据集

 

# 训练集索引是单词的索引
# 样本:
# 第一组数: 输入:32个单词序列的索引,文本新闻标题样本,转换成其索引,固定长度为32个单词,不足填充=》4760:PAD
# 第二个数:分类的类别
# 第三个数:有效字符的长度(不包括填充字符)
# 训练集输入数据的长度(包括填充字符)
print(len(train_data[0][0]))
#中华女子学院:本科层次仅1专业招男生3
print(train_data[0])
#两天价网站背后重重迷雾:做个网站究竟要多少钱4
print(train_data[1])

 

32
([14, 125, 55, 45, 35, 307, 4, 81, 161, 941, 258, 494, 2, 175, 48, 145, 97, 17, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760], 3, 18)
([135, 80, 33, 54, 505, 1032, 70, 95, 95, 681, 2288, 4, 486, 179, 54, 505, 626, 1156, 180, 115, 421, 561, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760], 4, 22)

 

(3)显示验证数据集

 

print(dev_data[0])
print(dev_data[1])

 

([173, 714, 3, 186, 1844, 889, 0, 2641, 80, 2061, 416, 478, 382, 5, 308, 15, 1264, 1344, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760], 8, 18)
([28, 1, 12, 567, 1371, 31, 365, 899, 846, 1300, 1095, 256, 1311, 8, 72, 7, 9, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760], 5, 17)

 

(4)显示测试数据集

 

print(test_data[0])
print(test_data[1])

 

([1393, 686, 1350, 656, 110, 232, 1138, 0, 1, 24, 12, 26, 216, 1533, 56, 123, 434, 270, 742, 65, 112, 236, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760], 3, 22)
([14, 6, 11, 156, 36, 211, 5, 35, 3, 1, 2, 3, 12, 830, 324, 216, 626, 17, 334, 291, 461, 659, 334, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760, 4760], 3, 23)

 

3.4 构建迭代器

 

(1)定义类或函数

 

# 迭代器类
class DatasetIterater(object):
    def __init__(self, batches, batch_size, device):
        self.batch_size = batch_size
        self.batches = batches
        self.n_batches = len(batches) // batch_size
        self.residue = False  # 记录batch数量是否为整数
        if len(batches) % self.n_batches != 0:
            self.residue = True
        self.index = 0
        self.device = device
    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
        # pad前的长度(超过pad_size的设为pad_size)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        return (x, seq_len), y
    # 迭代函数
    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
        elif self.index > self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
    def __iter__(self):
        return self
    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches
# 构建迭代器的API
def build_iterator(dataset, config):
    iter = DatasetIterater(dataset, config.batch_size, config.device)
    return iter

 

(2)实例化

 

# 训练集loader
train_iter = build_iterator(train_data, config)
# 验证集loader
dev_iter = build_iterator(dev_data, config)
# 测试集loader
test_iter = build_iterator(test_data, config)
print(train_iter)

 

<__main__.DatasetIterater object at 0x0000022804DAFD30>

 

第4章 构建模型:LSTM

 

4.1 定义模型类

 

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        # 词向量网络
        if config.embedding_pretrained is not None:
            # 使用不需要重新训练的、预训练好的词向量,加快训练速度、提升性能
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            # 使用新定义的可训练的词词向量
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        
        # LSTM网络
        # config.embed:词向量的输出长度=300,它是LSTM的输入
        # config.hidden_size:隐藏层输出特征的长度
        # config.num_layers:隐藏层的层数
        # bidirectional:双向网络
        # batch_first: [batch_size, seq_len, embeding]
        # dropout:随机丢弃
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        
        # 全连接分类网络
        # 使用隐层当前时刻的输出作为全连接的输入。
        # config.hidden_size * 2:双向LSTM的输出是隐层特征输出的2倍
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)
    def forward(self, x):
        x, _ = x
        out = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden state
        return out

 

4.2 实例化模型并显示模型结构

 

# 构建模型
#设定随机种子,确保每次随机初始化的结果是一样的
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
torch.backends.cudnn.deterministic = True  # 保证每次结果一样
model_name = "TextRNN"
# 创建模型实例
model = Model(config).to(config.device)
#显示网络参数
for name, w in model.named_parameters():
    print(name)
    
print(model.parameters)

 

embedding.weight
lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
lstm.weight_ih_l0_reverse
lstm.weight_hh_l0_reverse
lstm.bias_ih_l0_reverse
lstm.bias_hh_l0_reverse
lstm.weight_ih_l1
lstm.weight_hh_l1
lstm.bias_ih_l1
lstm.bias_hh_l1
lstm.weight_ih_l1_reverse
lstm.weight_hh_l1_reverse
lstm.bias_ih_l1_reverse
lstm.bias_hh_l1_reverse
fc.weight
fc.bias
<bound method Module.parameters of Model(
  (embedding): Embedding(4762, 300)
  (lstm): LSTM(300, 128, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (fc): Linear(in_features=256, out_features=10, bias=True)
)>

 

4.3 初始化模型权重参数

 

# 权重初始化:不同的初始化方法,导致精确性和收敛时间不同
# 默认xavier
# xavier:“Xavier”初始化方法是一种很有效的神经网络初始化方法
# kaiming:何凯明初始化
# normal_: 正态分布初始化
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        if exclude not in name:
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass
#初始化网络
init_network(model)

 

第5章 模型训练、评估

 

5.1 模型评估方法

 

# 模型评估方法
def evaluate(config, model, data_iter, test=False):
    # 设置在评估模式
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    
    # 不进行梯度更新
    with torch.no_grad():
        # 数据集迭代
        for texts, labels in data_iter:
            
            # 模型预测输出
            outputs = model(texts)
            
            # 计算当前的loss
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            
            # 计算当前的精度
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            
            # 记录当前的label的数目
            labels_all = np.append(labels_all, labels)
            
            # 记录当前正确预测的数目
            predict_all = np.append(predict_all, predic)
    
    # 计算整个数据集上的平均精度
    acc = metrics.accuracy_score(labels_all, predict_all)
    
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    
    # 返回整个数据集上的平均精度与平均loss
    return acc, loss_total / len(data_iter)

 

5.2 模型训练方法

 

# 训练方法
writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
def train(config, model, train_iter, dev_iter, writer):
    start_time = time.time()
    
    # 设定在模式下
    model.train()
    
    #设定优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    # 学习率指数衰减,每次epoch:学习率 = gamma * 学习率
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    total_batch = 0  # 记录进行到多少batch,一个训练集包含多个batch
    
    #记录当前最好的loss值
    dev_best_loss = float('inf')
    last_improve = 0  # 记录上次验证集loss下降的batch数
    flag = False  # 记录是否很久没有效果提升
    
    # 启动一个SummaryWriter对象,用于 记录训练过程
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    
    #开始训练
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        
        #自动调整学习率
        #scheduler.step() # 学习率衰减
        
        # 迭代数据集
        for i, (trains, labels) in enumerate(train_iter):
            #print (trains[0].shape)
            
            #获取模型输出
            outputs = model(trains)
            
            #复位模型梯度
            model.zero_grad()
            
            # 计算模型loss
            loss = F.cross_entropy(outputs, labels)
            
            # 根据loss计算梯度
            loss.backward()
            
            # 反向迭代,更新W参数
            optimizer.step()
            
            # 对迭代进行测试与评估
            # 每100次迭代输出,在训练集和验证集上的评估一次效果
            if total_batch % 100 == 0:
                # 获取训练集上的精度
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                
                # 获取验证集上的精度
                dev_acc, dev_loss = evaluate(config, model, dev_iter, test=False)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    # 保存当前精度更高时候的模型
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    
                    # 记录模型更新时的batch数
                    last_improve = total_batch
                else:
                    improve = ''
                
                # 打印log信息
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                
                #重新进入训练模式
                model.train()
            
            # batch数++
            total_batch += 1
            
            # 如果连续迭代后,精度没有得到进一步的提升,当次数得到一定的设定值后,自动停止迭代。
            # total_batch:连续进行了多少次batch
            # last_improve:记录模型更新时的batch数
            # config.require_improvement
            if total_batch - last_improve > config.require_improvement:
                # 验证集loss超过1000batch 没下降,结束训练
                print("No optimization for a long time, auto-stopping...")
                print("total_batch=", total_batch)
                print("last_improve=", last_improve)
                print("require_improvement=", config.require_improvement)
                flag = True
                break
        if flag:
            break
    writer.close()

 

5.3 边训练、边评估模型

 

在训练集上训练,在验证集上评估

 

# 一边训练,一边评估
train(config, model, train_iter, dev_iter, writer)

 

Epoch [1/20]
Iter:      0,  Train Loss:   2.3,  Train Acc: 17.19%,  Val Loss:   2.3,  Val Acc: 10.00%,  Time: 0:00:01 *
Iter:    100,  Train Loss:   1.7,  Train Acc: 34.38%,  Val Loss:   1.8,  Val Acc: 28.56%,  Time: 0:00:02 *
Iter:    200,  Train Loss:   1.3,  Train Acc: 56.25%,  Val Loss:   1.4,  Val Acc: 46.30%,  Time: 0:00:03 *
Iter:    300,  Train Loss:   1.1,  Train Acc: 60.94%,  Val Loss:   1.1,  Val Acc: 61.21%,  Time: 0:00:04 *
.............................................................................
Iter:   5300,  Train Loss:   0.3,  Train Acc: 90.62%,  Val Loss:  0.34,  Val Acc: 88.94%,  Time: 0:01:04 
Iter:   5400,  Train Loss:  0.33,  Train Acc: 89.06%,  Val Loss:  0.32,  Val Acc: 89.61%,  Time: 0:01:05 
Iter:   5500,  Train Loss:  0.43,  Train Acc: 84.38%,  Val Loss:  0.34,  Val Acc: 89.69%,  Time: 0:01:07 
Iter:   5600,  Train Loss:  0.31,  Train Acc: 89.06%,  Val Loss:  0.32,  Val Acc: 89.91%,  Time: 0:01:08 
Epoch [3/20]
Iter:   5700,  Train Loss:  0.22,  Train Acc: 90.62%,  Val Loss:  0.33,  Val Acc: 89.70%,  Time: 0:01:09 
No optimization for a long time, auto-stopping...
total_batch= 5701
last_improve= 4700
require_improvement= 1000

 

第6章 在测试集上对模型进行评估

 

6.1 测试方法的定义

 

# 在测试集上对模型进行评估
def test(config, model, test_iter):
    # test
    # 获取保存的最佳精度的模型
    model.load_state_dict(torch.load(config.save_path))
    
    # 进入评估模式
    model.eval()
    
    start_time = time.time()
    
    # 测试测试集进行评估
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    
    # 打印测试集的评估结果
    # 测试集的loss和精度
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    
    # 打印准确率、召回率、F1-Score的分数
    print("Precision, Recall and F1-Score...")
    print(test_report)
    
    # 打印混淆矩阵
    print("Confusion Matrix...")
    print(test_confusion)
    
    
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

 

备注:

 

至于模型的评分指标:Loss、accuracy、Precision, Recall and F1-Score,请参看相关文章。

 

6.2 开始测试

 

# 对训练好的模型进行测试
test(config, model, test_iter)

 

Test Loss:  0.31,  Test Acc: 89.73%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support
      finance     0.9240    0.8630    0.8925      1000
       realty     0.8796    0.9280    0.9032      1000
       stocks     0.8547    0.8060    0.8296      1000
    education     0.9477    0.9420    0.9448      1000
      science     0.8188    0.8540    0.8360      1000
      society     0.8817    0.9020    0.8917      1000
     politics     0.8645    0.8680    0.8663      1000
       sports     0.9662    0.9730    0.9696      1000
         game     0.9351    0.9080    0.9214      1000
entertainment     0.9055    0.9290    0.9171      1000
     accuracy                         0.8973     10000
    macro avg     0.8978    0.8973    0.8972     10000
 weighted avg     0.8978    0.8973    0.8972     10000
Confusion Matrix...
[[863  24  65   2  18  11  10   2   2   3]
 [ 10 928  14   0  13  16   3   1   3  12]
 [ 38  44 806   1  54   2  41   2   9   3]
 [  0   4   2 942   8  14   9   0   3  18]
 [  4   8  26   8 854  22  25   1  32  20]
 [  3  21   2  18   2 902  31   2   3  16]
 [ 11  15  19  11  24  37 868   4   1  10]
 [  1   2   1   2   4   3   6 973   1   7]
 [  1   2   6   5  55   7   4   4 908   8]
 [  3   7   2   5  11   9   7  18   9 929]]
Time usage: 0:00:00

 

作者主页(文火冰糖的硅基工坊): 文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注