今日在写程序时,遇到了一个蜜汁 bug:加载别人训练好的 ResNet18,识别精度很低,只有 16%,但理论上而言应该有 92%,我也好奇那将近 80% 的准确率去哪里了。而程序和数据本身又无错误,所以来探究一下这是为什么。
首先,网络结构和预训练的模型来自这里,这里声明一下,他提供的网络、参数都是没任何问题的,准确率低是我自己的原因。
错误程序
import torch
import numpy as np
from resnet18 import ResNet18
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import transforms


class testdataset(Dataset):
    """Test split built from two ``.npy`` files (samples and labels).

    Only the entries from index 50000 onward are kept; the original author
    notes that the model is pretrained and the last 10000 samples are used
    for testing.
    """

    def __init__(self, data_path, label_path):
        """Load samples and labels and keep the tail starting at index 50000.

        Args:
            data_path: path of the ``.npy`` file holding the samples.
            label_path: path of the ``.npy`` file holding the labels.
        """
        # Scale the samples into [0, 1] (assumes raw values are in 0..255
        # -- TODO confirm against the data file).
        scaled = np.load(data_path) / 255
        self.x_data = scaled[50000:]

        labels = np.load(label_path)
        self.y_data = labels[50000:]

    def __getitem__(self, index):
        """Return ``(sample, label)`` at ``index`` as two torch tensors."""
        sample = self.x_data[index]
        label = self.y_data[index]
        # NOTE(review): transpose(2, 1, 0) reverses the axis order. If a
        # sample is stored as (H, W, C) this yields (C, W, H), not the usual
        # (C, H, W). The article's own notes report 16% accuracy with this
        # order and 48% with (2, 0, 1). It is kept unchanged because this
        # listing is the article's faulty version.
        sample = sample.transpose(2, 1, 0)
        return torch.from_numpy(sample), torch.from_numpy(label)

    def __len__(self):
        """Return the number of samples kept for testing."""
        return len(self.x_data)


def _error(model, X, y):
    """Count the predictions in one batch that match ``y``.

    Despite the name, the value returned is the number of *correct*
    predictions (as a float tensor), not an error count.

    Args:
        model: callable producing per-class scores for ``X``.
        X: input batch.
        y: labels; presumably shaped (batch, 1), since the predictions are
            unsqueezed to that shape before comparing -- TODO confirm.
    """
    scores = model(X)
    # argmax over dim 1 picks the predicted class; unsqueeze makes it a column.
    predicted = torch.argmax(scores, 1).unsqueeze(1)
    return (predicted == y).sum().float()


def _eval(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the accuracy.

    The model is switched to eval mode and moved to ``device``. The printed
    accuracy divides by a hard-coded 10000, so it is only meaningful when the
    loader yields exactly 10000 samples.
    """
    model.eval()
    model.to(device)

    correct_total = 0
    for batch, labels in test_loader:
        batch = batch.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        X = Variable(batch, requires_grad=True)
        y = Variable(labels)
        correct_total += _error(model, X, y)

    print('acc: ', correct_total / 10000)


if __name__ == "__main__":
    # Load the ResNet and its pretrained checkpoint.
    resnet = ResNet18()
    resnet_path = "resnet18_ckpt.pth"
    checkpoint = torch.load(resnet_path, map_location='cpu')
    print('loaded model...')

    # Drop the first 7 characters of every key so the names line up with the
    # model's parameters (presumably a "module." prefix -- TODO confirm).
    net_state = {name[7:]: weight for name, weight in checkpoint['net'].items()}
    resnet.load_state_dict(net_state)
    print('set model...')

    data_path = "cifar10_data.npy"
    label_path = "cifar10_label.npy"
    test_data = testdataset(data_path=data_path, label_path=label_path)
    test_loader = DataLoader(test_data, batch_size=128)
    print('load data...')

    # Accuracy on clean (unperturbed) samples.
    print('natural', end=', ')
    _eval(model=resnet, device='cpu', test_loader=test_loader)
在这样操作下,准确率只有 16.89%,我也很奇怪是哪里错了。
正确程序
从师兄那里找到了一份正确的程序,准确率是 92.84%。
import torch
from resnet18 import ResNet18
from collections import OrderedDict
from torch.utils.data import Dataset
import numpy as np


class TensorDataset(Dataset):
    """In-memory dataset holding the samples and labels from index 50000 on.

    Everything is converted to torch tensors once, at construction time.
    """

    def __init__(self, dataPath, labelPath):
        """Load both ``.npy`` files and prepare the tensors.

        Args:
            dataPath: path of the ``.npy`` file holding the samples.
            labelPath: path of the ``.npy`` file holding the labels.
        """
        raw = np.load(dataPath)
        # Keep the tail, scale by 1/255 and store as float32.
        scaled = (raw[50000:] / 255.).astype("float32")
        # Move the last axis to position 1 (channels-first if the file is
        # stored as (N, H, W, C) -- TODO confirm).
        samples = np.transpose(scaled, (0, 3, 1, 2))

        # Flatten the labels to one entry per sample.
        targets = np.load(labelPath)[50000:]
        targets = targets.reshape((samples.shape[0], ))

        self.data_tensor = torch.from_numpy(samples)
        self.target_tensor = torch.from_numpy(targets)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``."""
        sample = self.data_tensor[index]
        target = self.target_tensor[index]
        return sample, target

    def __len__(self):
        """Return the number of stored samples."""
        return self.data_tensor.size(0)


# Build the network and load the checkpoint. The first 7 characters of every
# key are dropped so the names match the model (presumably a "module."
# prefix -- TODO confirm).
net = ResNet18()
resnet_path = "resnet18_ckpt.pth"
d = torch.load(resnet_path, map_location=torch.device('cpu'))['net']
d = OrderedDict((name[7:], weight) for name, weight in d.items())
net.load_state_dict(d)

dataPath = "cifar10_data.npy"
labelPath = "cifar10_label.npy"
dataset = TensorDataset(dataPath, labelPath)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

# NOTE(review): net.eval() is never called here, so the network presumably
# stays in its default (training) mode while being evaluated. The article
# attributes the accuracy gap to eval mode, so this is left unchanged.
total = 0
correct = 0
for inputs, targets in dataloader:
    outputs = net(inputs)
    predicted = outputs.max(1)[1]  # index of the highest score per sample
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

acc = 100. * correct / total
print(acc)
# acc 48 %
# x_ = x_.transpose(2, 0, 1)
# acc 16 %
# x_ = x_.transpose(2, 1, 0)
目前已经定位到是 eval 和数据增强的问题,明天再看。
https://www.zhihu.com/question/354742972
不是
reference: https://github.com/laisimiao/classification-cifar10-pytorch
Be First to Comment