## Building the evaluation function

IMDB is a typical binary classification dataset, so we use accuracy as the evaluation metric. The function is defined as follows:

```python
import torch

def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    # apply sigmoid to the raw outputs, then round to the closest integer (0 or 1)
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()  # convert into float for division
    acc = correct.sum() / len(correct)
    return acc
```
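To make the rounding step concrete, here is a small worked example on hand-picked values. The `logits` and `labels` tensors are invented for illustration. It also shows that rounding the sigmoid output gives the same 0/1 predictions as thresholding the raw outputs at zero for these inputs.

```python
import torch

def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    return correct.sum() / len(correct)

# Hypothetical raw model outputs (logits) and ground-truth labels
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
labels = torch.tensor([1.0, 0.0, 0.0, 0.0])

by_round = torch.round(torch.sigmoid(logits))   # tensor([1., 0., 1., 0.])
by_threshold = (logits > 0).float()             # same predictions

acc = binary_accuracy(logits, labels)
print(acc.item())  # 0.75: three of the four predictions match the labels
```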

## Model training

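For each batch, the loss is backpropagated so that gradients flow back to the model parameters, and the optimizer then uses them to update the weights.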

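```python
def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.train()  # training mode: dropout and similar layers are active

    for batch in iterator:
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        text, text_lengths = batch.text
        predictions = model(text, text_lengths).squeeze(1)
        loss = criterion(predictions, batch.label)
        acc = binary_accuracy(predictions, batch.label)
        loss.backward()   # backpropagate the loss
        optimizer.step()  # update the parameters
        epoch_loss += loss.item()
        epoch_acc += acc.item()

    # average loss and accuracy over all batches
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
```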
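PyTorch accumulates gradients in `.grad` across successive `backward()` calls, so a training loop normally clears them once per batch. The sketch below illustrates this on a toy `nn.Linear` model with random inputs. The model, the data, and the choice of `nn.BCEWithLogitsLoss` are assumptions for the illustration, not the setup used in this article.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_model = nn.Linear(3, 1)  # hypothetical stand-in for the real model
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()  # assumed loss for logit outputs

x = torch.randn(4, 3)
y = torch.tensor([1.0, 0.0, 1.0, 0.0])

# First backward pass: .grad holds the gradient of one loss
criterion(toy_model(x).squeeze(1), y).backward()
grad_once = toy_model.weight.grad.clone()

# Second backward pass without clearing: the gradients add up
criterion(toy_model(x).squeeze(1), y).backward()
grad_twice = toy_model.weight.grad.clone()

# Clearing first restores the single-pass gradient
optimizer.zero_grad()
criterion(toy_model(x).squeeze(1), y).backward()
grad_fresh = toy_model.weight.grad.clone()

print(torch.allclose(grad_twice, 2 * grad_once))  # True
print(torch.allclose(grad_fresh, grad_once))      # True
```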

## Model validation

```python
def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.eval()  # evaluation mode: dropout and similar layers are disabled

    with torch.no_grad():  # no gradients are needed for validation
        for batch in iterator:
            text, text_lengths = batch.text
            predictions = model(text, text_lengths).squeeze(1)
            loss = criterion(predictions, batch.label)
            acc = binary_accuracy(predictions, batch.label)
            epoch_loss += loss.item()
            epoch_acc += acc.item()

    # average loss and accuracy over all batches
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
```
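Two PyTorch mechanisms are relevant during validation: `model.eval()` switches layers such as dropout to their inference behavior, and `torch.no_grad()` turns off gradient tracking to save memory and time. A minimal sketch on a hypothetical toy network (not the model from this article):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy network containing a dropout layer
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(2, 4)

net.eval()                 # dropout now acts as the identity
with torch.no_grad():      # no autograd graph is built inside this block
    out_a = net(x)
    out_b = net(x)

deterministic = torch.equal(out_a, out_b)  # repeated passes agree in eval mode
tracks_grad = out_a.requires_grad          # outputs carry no gradient history
print(deterministic, tracks_grad)          # True False
```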

## Training and evaluation

```python
import time

N_EPOCHS = 5
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # keep only the checkpoint with the lowest validation loss
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
```

```
Epoch: 01 | Epoch Time: 1m 50s
	Train Loss: 0.558 | Train Acc: 70.57%
	 Val. Loss: 0.444 |  Val. Acc: 79.54%
Epoch: 02 | Epoch Time: 1m 50s
	Train Loss: 0.393 | Train Acc: 82.70%
	 Val. Loss: 0.383 |  Val. Acc: 83.21%
Epoch: 03 | Epoch Time: 1m 50s
	Train Loss: 0.287 | Train Acc: 88.10%
	 Val. Loss: 0.300 |  Val. Acc: 88.08%
Epoch: 04 | Epoch Time: 1m 50s
	Train Loss: 0.161 | Train Acc: 94.26%
	 Val. Loss: 0.314 |  Val. Acc: 87.84%
Epoch: 05 | Epoch Time: 1m 50s
	Train Loss: 0.122 | Train Acc: 95.53%
	 Val. Loss: 0.367 |  Val. Acc: 87.17%
```
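The training loop calls `epoch_time`, which is not defined in this section. A minimal sketch, assuming it simply converts an elapsed wall-clock interval into whole minutes and seconds (the body below is a guess at that helper, not necessarily the author's version):

```python
def epoch_time(start_time, end_time):
    # Assumed behavior: split the elapsed seconds into whole minutes and seconds
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - elapsed_mins * 60)
    return elapsed_mins, elapsed_secs

print(epoch_time(0, 110))  # (1, 50), i.e. "1m 50s"
```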

```python
model.load_state_dict(torch.load('model.pt'))
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
```

`Test Loss: 0.321 | Test Acc: 87.01%`
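In the log above, validation loss is lowest at epoch 3 (0.300) and rises afterwards while training loss keeps falling, so the checkpoint loaded for testing is the one saved at epoch 3. The save and load round trip can be checked in isolation. The sketch below uses a toy `nn.Linear` model and a temporary file path, both chosen only for illustration:

```python
import os
import tempfile

import torch
import torch.nn as nn

original = nn.Linear(3, 1)  # hypothetical stand-in for the trained model
path = os.path.join(tempfile.mkdtemp(), 'model.pt')
torch.save(original.state_dict(), path)  # store parameters only, not the class

restored = nn.Linear(3, 1)  # a fresh instance with the same architecture
restored.load_state_dict(torch.load(path))

# Every parameter tensor should match after the round trip
same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(),
                    restored.state_dict().values())
)
print(same)  # True
```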

……