Press "Enter" to skip to content

神经网络中的数据问题

今日在写程序时,遇到了一个谜之 bug:加载别人训练好的 ResNet18,识别精度很低,只有 16%,但理论上应该有 92%,我也好奇那约 76% 的准确率去哪里了。而程序和数据本身看起来又没有错误,所以来探究一下这是为什么。

 

首先,网络结构和预训练的模型来自这里(见文末 reference)。这里声明一下,原作者提供的网络、参数都是没任何问题的,准确率低是我自己的原因。

 

错误程序

 

import torch
import numpy as np
from resnet18 import ResNet18
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import transforms

class testdataset(Dataset):
    """Dataset over the samples stored from index 50000 onward in two ``.npy`` files.

    Args:
        data_path: path of the ``.npy`` file holding the samples.
        label_path: path of the ``.npy`` file holding the labels.

    Each item is a ``(sample, label)`` pair of tensors built with
    ``torch.from_numpy``. The sample has been divided by 255 and has its
    three axes reversed.
    """

    def __init__(self, data_path, label_path):
        # The model is already trained, so only the samples from index 50000
        # onward are used for testing.
        # Slice BEFORE scaling: dividing the full array first would allocate a
        # float copy of the 50000 samples that are thrown away right after.
        # Scaling by 1/255 maps the data to [0, 1] (presumably the file holds
        # 0-255 pixel values -- TODO confirm).
        self.x_data = np.load(data_path)[50000:] / 255
        self.y_data = np.load(label_path)[50000:]

    def __getitem__(self, index):
        x_ = self.x_data[index]
        # NOTE(review): transpose(2, 1, 0) reverses all three axes. If samples
        # are stored as (H, W, C) -- TODO confirm -- this yields (C, W, H),
        # i.e. the image is mirrored across its diagonal, whereas
        # transpose(2, 0, 1) would yield (C, H, W). Left unchanged on purpose:
        # this listing is the faulty program being investigated.
        x_ = x_.transpose(2, 1, 0)
        # Each label row must itself be an ndarray for torch.from_numpy,
        # so the label file is expected to be at least 2-D, e.g. (N, 1).
        y_ = self.y_data[index]
        return torch.from_numpy(x_), torch.from_numpy(y_)

    def __len__(self):
        return len(self.x_data)

def _error(model, X, y):
    """Count how many samples in a batch ``model`` classifies correctly.

    Despite its name, this returns the number of CORRECT predictions, not
    the number of errors.

    Args:
        model: callable mapping ``X`` to a ``(batch, num_classes)`` score tensor.
        X: input batch passed to ``model`` unchanged.
        y: class-index labels, shaped ``(batch,)`` or ``(batch, 1)``.

    Returns:
        A 0-dim float tensor holding the number of correct predictions.
    """
    out = model(X)
    prediction = torch.argmax(out, 1)
    # Compare on flattened labels. The previous `prediction.unsqueeze(1) == y`
    # was only right for (batch, 1) labels; with (batch,) labels it silently
    # broadcast to a (batch, batch) matrix and over-counted matches.
    labels = y.reshape(-1)
    correct = (prediction == labels).sum().float()
    return correct

def _eval(model, device, test_loader):
    """Print the accuracy of ``model`` over every batch of ``test_loader``.

    Args:
        model: network to evaluate; it is switched to eval mode and moved to
            ``device``.
        device: device the model and each batch are moved to.
        test_loader: iterable yielding ``(data, target)`` tensor batches.

    Prints ``acc:`` followed by correct / total samples seen, or 0.0 when
    the loader yields nothing. Returns nothing.
    """
    model.eval()
    model.to(device)
    correct_total = 0  # `_error` returns the number of correct predictions
    seen = 0
    # Pure inference: skip autograd bookkeeping instead of wrapping the batch
    # in the deprecated Variable(..., requires_grad=True).
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device, dtype=torch.float)
            target = target.to(device, dtype=torch.float)
            correct_total += _error(model, data, target)
            # Count the samples actually evaluated rather than assuming the
            # loader holds exactly 10000 of them.
            seen += target.size(0)
    print('acc: ', correct_total / seen if seen else 0.0)

if __name__ == "__main__":
    # Build the network and read the pretrained checkpoint onto the CPU.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # obtained from a source you trust.
    resnet_path = "resnet18_ckpt.pth"
    resnet = ResNet18()
    checkpoint = torch.load(resnet_path, map_location='cpu')
    print('loaded model...')

    # Rename the saved parameters so they match this model's parameter names:
    # the first 7 characters of every key are dropped (presumably the
    # "module." prefix added by nn.DataParallel -- TODO confirm).
    saved = checkpoint['net']
    net_state = {name[7:]: saved[name] for name in saved}
    resnet.load_state_dict(net_state)
    print('set model...')

    # Wrap the held-out samples in a loader with 128 samples per batch.
    data_path = "cifar10_data.npy"
    label_path = "cifar10_label.npy"
    test_data = testdataset(data_path=data_path, label_path=label_path)
    test_loader = DataLoader(test_data, batch_size=128)
    print('load data...')

    # Accuracy on clean (unperturbed) samples.
    print('natural', end=', ')
    _eval(model=resnet, device='cpu', test_loader=test_loader)

 

在这样操作下,准确率只有 16.89%,我也很奇怪是哪里错了。

 

正确程序

 

从师兄那里找到了一份正确的程序,准确率是 92.84%

 

import torch
from resnet18 import ResNet18
from collections import OrderedDict
from torch.utils.data import Dataset
import numpy as np

class TensorDataset(Dataset):
    """In-memory dataset over the samples from index 50000 onward of two ``.npy`` files.

    Args:
        dataPath: path of the ``.npy`` file holding the samples; the loaded
            array must be 4-D because its axes are permuted with
            ``(0, 3, 1, 2)``.
        labelPath: path of the ``.npy`` file holding the labels.

    Attributes:
        data_tensor: float32 tensor of the samples, divided by 255, with the
            last axis moved to position 1.
        target_tensor: 1-D tensor holding one label per sample.
    """

    def __init__(self, dataPath, labelPath):
        # Keep only the tail of the array, scale it by 1/255 and store it
        # as float32.
        images = np.load(dataPath)[50000:]
        images = (images / 255.).astype("float32")
        # Move the last axis to position 1 (channels-last -> channels-first,
        # assuming the file is stored as (N, H, W, C) -- TODO confirm).
        images = images.transpose(0, 3, 1, 2)

        # Flatten the labels to one entry per sample.
        labels = np.load(labelPath)[50000:]
        labels = labels.reshape((images.shape[0], ))

        self.data_tensor = torch.from_numpy(images)
        self.target_tensor = torch.from_numpy(labels)

    def __getitem__(self, index):
        sample = self.data_tensor[index]
        target = self.target_tensor[index]
        return sample, target

    def __len__(self):
        return self.data_tensor.size(0)

# Build the network and restore the pretrained weights on the CPU.
net = ResNet18()
resnet_path = "resnet18_ckpt.pth"
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# obtained from a source you trust.
d = torch.load(resnet_path, map_location=torch.device('cpu'))['net']
# Drop the first 7 characters of every key so the names match this model
# (presumably the "module." prefix added by nn.DataParallel -- TODO confirm).
d = OrderedDict((k[7:], v) for (k, v) in d.items())
net.load_state_dict(d)

dataPath = "cifar10_data.npy"
labelPath = "cifar10_label.npy"
dataset = TensorDataset(dataPath, labelPath)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=64,
                                         shuffle=True,
                                         num_workers=0)

total = 0
correct = 0
# NOTE(review): net.eval() is never called, so the network stays in training
# mode. If ResNet18 contains BatchNorm or Dropout layers (not visible here),
# they use batch statistics / random masks, and the result then depends on
# batch size and shuffling. Left unchanged on purpose so the script keeps
# reproducing the accuracy it was measured with.
# Gradients are never used here, so skip autograd bookkeeping.
with torch.no_grad():
    for inputs, targets in dataloader:
        outputs = net(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
# Accuracy in percent; guard against an empty dataset.
acc = 100. * correct / total if total else 0.0
print(acc)

 

# 在错误程序的 __getitem__ 中尝试不同的轴顺序:
# x_ = x_.transpose(2, 0, 1)  ->  acc 48 %
# x_ = x_.transpose(2, 1, 0)  ->  acc 16 %

 

目前已经定位到是 eval 和数据增强的问题(错误程序在 `_eval` 中调用了 `model.eval()`,而正确程序始终没有调用 `net.eval()`),明天再看。

https://www.zhihu.com/question/354742972
不是

https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/7
待查阅

reference

 

https://github.com/laisimiao/classification-cifar10-pytorch

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注