Press "Enter" to skip to content

0802-编程实战_猫和狗二分类_深度学习项目架构

0802-编程实战_猫和狗二分类_深度学习项目架构

 

目录

pytorch完整教程目录: https://www.cnblogs.com/nickchen121/p/14662511.html

 

一、比赛介绍

 

接下来我们将通过 pytorch 完成 Kaggle 上的经典比赛: Dogs vs. Cats

 

Dogs vs. Cats 是一个传统的二分类问题,它的训练集包含 25000 张图片,这些图片都放在同一个文件夹中,命名格式为 <category>.<num>.jpg
,例如 cat.10000.jpg
dog.100.jpg
,测试集包含 12500 张图片,命名为 <num>.jpg
,例如 1000.jpg

 

参赛者需要根据训练集的图片训练模型,并在测试集上进行预测,输出它是狗的概率。最后提交的 csv 文件如下,第一列是图片的 <num>
,第二列是图片为狗的概率。

idlabel
100010.889
100020.01

 

二、数据加载

 

数据的相关处理主要保存在 data/dataset.py
中。

 

关于数据加载,之前提过,基本原理就是先使用 Dataset 封装数据集,再使用 Dataloader 实现数据并行加载。

 

Kaggle 提供的数据包括训练集和测试集,但是在我们使用的时候,还需要从训练集中抽取一部分作为验证集。

 

对于上述所说的三个数据集,虽然它们的相应操作不太一样,但是如果专门写出三个 Dataset,则会显得复杂并冗余,因此在这里通过添加一些判断来区分三者。比如我们希望对训练集做一些数据增强处理,如随机裁剪、随机翻转、加噪声等,但是对于验证集和测试集则不需要。

 

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/3 10:15
# Filename:dataset.py
# Toolby: PyCharm
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T

class DogCat(data.Dataset):
    def __init__(self, root, transforms=None, train=True, test=False):
        """
        目标:获取所有图片地址,并根据训练、验证、测试划分数据
        """
        self.test = test  # 获取测试集
        imgs = [os.path.join(root, img)
                for img in os.listdir(root)]  # 拼接所有图片路径,路径地址如下所示
        """
        test1: data/test1/8973.jpg
        train: data/train/cat.10004.jpg
        """
        # 区分数据集是否为测试集,并对数据集的图片进行排序
        if self.test:
            imgs = sorted(
                imgs,
                key=lambda x: int(x.split('.')[-2].split('/')[-1]))  # 切割出 8973
        else:
            imgs = sorted(imgs,
                          key=lambda x: int(x.split('.')[-2]))  # 切割出 10004
        # 划分训练、验证集,验证:训练 = 3:7
        imgs_num = len(imgs)
        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]  # 训练集来自数据集的前 70%
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]
        # 数据转换操作,测试验证和训练的数据转换有所区别
        if transforms is None:
            # Normalize给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
            # 测试集和验证集
            if self.test or not train:
                self.transforms = T.Compose([
                    T.Scale(224),  # 让图片统一大小为:224*224
                    T.CenterCrop(224),  # 中心切割
                    T.ToTensor(),
                    normalize
                ])
            # 训练集
            else:
                self.transforms = T.Compose([
                    T.Scale(256),  # 让图片统一大小为:256*256
                    T.RandomSizedCrop(224),  # 随机切割图片后,resize成给定的大小 224*224
                    T.RandomHorizontalFlip(),  # 一半的概率翻转,一半的概率不翻转
                    T.ToTensor(),
                    normalize
                ])
    def __getitem__(self, index):
        """
        返回一张图片的数据
        如果是测试集,没有图片 id,如 8973.jpg 返回 8973
        test1: data/test1/8973.jpg
        train: data/train/cat.10004.jpg
        """
        img_path = self.imgs[index]
        if self.test:
            label = self.imgs[index].split('.')[-2]  # type:str # 切割出 8973.jpg
            label = int(label.split('/')[-1])  # 切割出 8973
        else:
            label = 1 if 'dog' in img_path.split(
                '/')[-1] else 0  # 切割出 cat.10004.jpg,通过判断对图片增加标签
        data = Image.open(img_path)
        data = self.transforms(data)  # 对图片进行处理
        return data, label
    def __len__(self):
        """
        返回数据集中所有图片的个数
        """
        return len(self.imgs)
# train_dataset = DogCat(opt.train_data_root, train=True)  # opt 是未来会讲到的配置对象
# trainloader = DataLoader(train_dataset,
#                          batch_size=opt.batch_size,
#                          shuffle=True,
#                          num_workers=opt.num_workers)
# 
# for ii, (data, label) in enumerate(trainloader):
#     train()

 

上述代码中我们需要注意三个点:

 

__getitem__

 

三、模型定义

 

模型的定义主要保存在 models 目录下,其中 BasicModule 是对 nn.Module
的简易封装,提供快速加载和保存模型的接口。

 

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/3 10:22
# Filename:BasicModule.py
# Toolby: PyCharm
import time
import torch as t

class BasicModule(t.nn.Module):
    """
    封装了 nn.Module,主要提供 save 和 load 两个方法
    """
    def __init__(self):
        super(BasicModule, self).__init__()
        self.model_name = str(type(self))  # 模型的默认名字
    def load(self, path):
        """
        可加载指定路径的模型
        :param path:
        :return:
        """
        self.load_state_dict(t.load(path))
    def save(self, name=None):
        """
        保存模型,默认使用“模型名字+时间”作为文件名,
        如 AlexNet_0710_23:57:29.pth
        :param name:
        :return:
        """
        if name is None:
            prefix = 'checkpoints/' + self.model_name + '.'
            name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
        t.save(self.state_dict(), name)
        return name

 

在实际使用中,直接调用 model.save()
以及 model.load(opt.load_path)
即可。

 

其他自定义模型一般继承 BasicModule,然后实现自己的模型。由于实现了 AlexNet 和 ResNet34,在 models/__init__.py
中,可以写下下述代码:

 

from .AlexNet import AlexNet
from .ResNet34 import ResNet34

 

这样主函数中就可以写:

 

from models import AlexNet
# 或
import models
model = models.AlexNet()
# 或
import models
model = getattr('models', 'AlexNet')()

 

上述在主函数中的代码中,其中最后写法最关键,这样意味着我们可以通过字符串直接指定使用的模型,而不需要使用判断语句,同时也不需要在每次新增加模型后都修改代码。

 

但是最好的方法,就是在新增模型后需要在 models.__init__.py
中加上 from .new_module import new_module
,避免使用第一种方法时报错,或者避免使用 model = getattr('models', 'AlexNet')()
时找不到该对象。

 

最后,在模型定义的时候,需要注意以下三点:

 

nn.Sequenetial

 

四、工具函数

 

在项目中,我们可能需要用到一些经常使用的方法,这些方法可以统一放入到 utils 文件夹中,需要时再导入。

 

在这个项目中,主要封装了可视化工具 visdom 的一些操作。

 

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/3 10:23
# Filename:visualize.py
# Toolby: PyCharm
import visdom
import time
import numpy as np

class Visualizer(object):
    """
    封装了 visdom 的基本操作,但仍然可以通过 `self.vis.function`
    或者 `self.function` 调用原生的 visdom 接口
    例如:
    self.text('hello visdom')
    self.histogram(t.randn(1000))
    self.line(t.arange(0, 10), t.arange(1, 11))
    """
    def __init__(self, env='default', **kwargs):
        self.vis = visdom.Visdom(env=env, **kwargs)
        # 保存('loss', 23) 即 loss 的第 23 个点
        self.index = {}
        self.log_text = ''
    def reinit(self, env='default', **kwargs):
        """
        修改 visdom 的配置
        :param env:
        :param kwargs:
        :return:
        """
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self
    def plot_many(self, d: dict):
        """
        一次 plot 多个
        :param d: dict(name, value) i.e. ('loss', 0.11)
        :return:
        """
        for k, v in d.items():
            self.plot(k, v)
    def img_many(self, d: dict):
        """
        处理多张图片
        :param d:
        :return:
        """
        for k, v in d.items():
            self.img(k, v)
    def plot(self, name, y, **kwargs):
        """
        self.plot('loss', 1.00)
        :param name: 
        :param y: 
        :param kwargs: 
        :return: 
        """
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]),
                      X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x == 0 else 'append',
                      **kwargs)
        self.index[name] = x + 1
    def img(self, name, img_, **kwargs):
        """
        self.img('input_img', t.Tensor(64, 64))
        self.img('input_imgs', t.Tensor(3, 64, 64))
        self.img('input_img', t.Tensor(100, 1, 64, 64))
        self.img('input_imgs', t.Tensor(100, 3, 64, 64), nrows=10)
        :param name:
        :param img_:
        :param kwargs:
        :return:
        """
        self.vis.images(img_.cpu().numpy,
                        win=name,
                        opts=dict(title=name),
                        **kwargs)
    def log(self, info, win='log_text'):
        """
        self.log({'loss':1, 'lr':0.0001}
        :param info:
        :param win:
        :return:
        """
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info
        ))
        self.vis.text(self.log_text, win)
    def __getattr__(self, name):
        """
        自定义的 plot,image,log,plot_many 等除外
        self.function 等价于 self.vis.function
        :param name:
        :return:
        """
        return getattr(self.vis, name)

 

五、配置文件

 

在模型定义、数据处理和训练过程中会产生许多变量,这些变量应该提供默认值,并且统一放在配置文件中。如此做的话,在后期调试、修改代码的时候会方便很多,在这里,我们把所有课配置项都放在 config.py 中。

 

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/3 10:20
# Filename:config.py
# Toolby: PyCharm
class DefaultConfig(object):
    env = 'default'
    model = 'AlexNet'  # 使用的模型,名字必须与 models/__init__.py 中的名字一致
    train_data_root = './data/train/'  # 训练集存放路径
    test_data_root = './data/test1'  # 测试集存放路径
    load_model_path = 'checkpoints/model.pth'  # 加载预训练模型的路径,为 None 代表不加载
    batch_size = 128  # batch_size
    use_gpu = False  # use GPU or not
    num_workers = 4  # num of workers for loading data
    print_freq = 20  # print info every N batch
    debug_file = '/tmp/debug'  # if os.path.exists(debug_file): enter ipdb
    result_file = 'result.csv'
    max_epoch = 10
    lr = 0.1  # initial learning rate
    lr_decay = 0.95  # when val_loss increase, lr = lr*lr_decay
    weight_decay = 1e-4  # 损失函数

 

从上述代码中可以看出可配置的参数主要包括以下三类:

数据集参数(文件路径、batch_size 等)
训练参数(学习率、训练 epoch 等)
模型参数

定义好了上述配置参数后,可以在程序中这样使用配置参数:

 

import models
from config import DefaultConfig
opt = DefaultConfig()
lr = opt.lr
model = getattr(models, opt.model)
dataset = DogCat(opt.traini_data_error)

 

上述所说的都是默认参数,在默认配置类中,我们还可以提供一个更新函数,根据字典更新配置参数。

 

def parse(self, kwargs: dict):
        """
        根据字典 kwargs 更新 config 参数
        :param kwargs:
        :return:
        """
        # 更新配置参数
        for k, v in kwargs.items():
            if not hasattr(self, k):
                warnings.warn(f"Warning: opt has not attribut {k}")
            setattr(self, k, v)
        # 打印配置信息
        print('user config: ')
        for k, v in self.__class__.__dict__.items():  # type:str
            if not k.startswith('__'):
                print(k, getattr(self, k))

 

当然,在实际使用时没必要每次修改 config.py,只需要通过命令行传入所需要的参数,覆盖默认配置就行,例如

 

opt = DefaultConfig()
new_config = {'lr': 0.1, 'use_gpu': False}
opt.parse(new_config)
opt.lr == 0.1

 

六、main.py

 

6.1 命令行工具 fire

 

在讲解 main 文件前,我们先熟悉一个我们可能可以用到的一个 命令行工具 fire
,可以通过 pip install fire
安装,下面介绍下 fire 的基础用法,假设 example.py 文件代码如下:

 

# example.py
import file

def add(x, y):
    return x + y

def mul(**kwargs):
    a = kwargs['a']
    b = kwargs['b']
    return a * b

if __name__ == '__main__':
    fire.Fire()

 

那我们可以在命令行中通过以下语句调用 example 文件中定义的函数:

 

python example.py add 1 2  # 执行 add(1, 2)
python example.py mul --a=1 --b=2  # 执行 mul(a=1, b=2), kwargs={'a':1, 'b':2}
python example.py add --x=1 --y=2  # 执行 add(x=1, y=2)

 

从上述代码可以看出,只要在程序中运行了 fire.Fire(),就可以通过命令行参数 `python file
[args,] {–kwargs,}。当然,fire 还支持更多的高级功能,具体可以参考 官方指南

 

6.2 main.py的代码组织结构

 

在我们这个项目的 main.py 中主要包括以下四个函数,其中三个需要命令行执行,main.py 的代码组织结构如下所示:

 

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/3 10:20
# Filename:main.py
# Toolby: PyCharm
import os
import csv
import ipdb
import fire
import torch as t
from torchnet import meter
from inspect import getsource
from torch.nn import functional
from torch.autograd import Variable
from torch.utils.data import DataLoader
import models
from config import opt
from data.dataset import DogCat
from utils.visualize import Visualizer

def train(**kwargs):
    """
    训练
    :param kwargs:
    :return:
    """
    pass

def val(model, dataloader):
    """
    计算模型在验证集上的准确率等信息,用来辅助训练
    :param model:
    :param dataloader:
    :return:
    """
    pass

def test(**kwargs):
    """
    测试(inference)
    :param kwargs:
    :return:
    """
    pass

def dc_help():
    """
    打印帮助的信息
    :return:
    """
    print('help')

if __name__ == '__main__':
    fire.Fire()

 

main.py 搭建好这样的组织结构后,可以通过 python main.py <function> --args==xx
的方式执行训练或测试。

 

6.3 训练

 

训练的主要步骤如下:

定义网络
定义数据
定义损失函数和优化器
计算重要指标

开始训练
训练网络
可视化各种指标
计算在验证集上的指标

其中训练函数的代码如下:

 

def train(**kwargs):
    """
    训练
    :param kwargs:
    :return:
    """
    # 根据命令行参数更新配置
    opt.parse(kwargs)
    vis = Visualizer(opt.env)
    # step1:模型
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu: model.cuda()
    # step2:数据
    train_data = DogCat(opt.train_data_root, train=True)
    val_data = DogCat(opt.train_data_root, train=False)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)
    # step3:目标函数和优化器
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(),
                             lr=lr,
                             weight_decay=opt.weight_decay)
    # step4:统计指标:平滑处理之后的损失,还有混淆矩阵
    loss_meter = meter.AverageValueMeter()  # 平均损失
    confusion_matrix = meter.ConfusionMeter(2)  # 混淆矩阵
    previous_loss = 1e100
    # 训练
    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        confusion_matrix.reset()
        for ii, (data, label) in enumerate(train_dataloader):
            # 训练模型参数
            inp = Variable(data)
            target = Variable(label)
            if opt.use_gpu:
                inp = inp.cuda()
                target = target.cuda()
            optimizer.zero_grad()
            score = model(inp)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()
            # 更新统计指标及可视化
            loss_meter.add(loss.data[0])
            confusion_matrix.add(score.data, target.data)
            if ii % opt.print_freq == opt.print_freq - 1:
                vis.plot('loss', loss_meter.value()[0])
                # 如果需要的话,进入 debug 模式
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
        model.save()
        # 计算验证集上的指标及可视化
        val_cm, val_accuracy = val(model, val_dataloader)
        vis.plot('val_accuracy', val_accuracy)
        vis.log('epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm{val_cm}'
                .format(epoch=epoch,
                        loss=loss_meter.value()[0],
                        val_cm=str(val_cm.value()),
                        train_cm=str(confusion_matrix.value()),
                        lr=lr))
        # 如果损失不再下降,则降低学习率
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter.value()[0]

 

6.3.1 torchnet 中的 meter

 

在训练的代码中,这里用到了 PyTorchNet
里的一个工具:meter。由于 PyTorchNet 是从 TorchNet 中迁移来的,提供了很多有用的工具,但目前的开发和文档都不是特别完善,这里不多做赘述,只讲上述用到的几个方法。

 

mter 提供了一些轻量级工具,可以帮助用户快速的统计训练过程中的一些指标。

 

* AverageValueMeter 能够计算所有数的平均值和标准差,可以用来统计一个 epoch 中损失的平均值

 

* confusionmeter 用来统计分类问题中的分类情况,是一个比准确率更详细的统计指标,给出的是一个混淆矩阵

 

混淆矩阵举例:

样本判为狗判为猫
实际是猫3515
实际是狗991

 

注: 想详细了解混淆矩阵的在第七小节

 

6.4 验证

 

验证相比较训练来说简单很多,但是需要注意把模型置于验证模式( model.eval()
),验证完成后还需要把它设置回训练模式( model.train()
),这两句代码会影响 BatchNorm 和 Dropout 等层的运行模式。验证模型准确率的代码如下:

 

def val(model, dataloader):
    """
    计算模型在验证集上的准确率等信息,用来辅助训练
    :param model:
    :param dataloader:
    :return:
    """
    # 把模型设置为验证模式
    model.eval()
    confusion_matrix = meter.ConfusionMeter(2)
    for ii, data in enumerate(dataloader):
        inp, label = data
        val_inp = Variable(inp, volatile=True)
        val_label = Variable(label.long(), volatile=True)
        if opt.use_gpu:
            val_inp = val_inp.cuda()
            val_label = val_label.cuda()
        score = model(val_inp)
        confusion_matrix.add(score.data.squeeze(), label.long())
    # 把模型恢复为训练模式
    model.train()
    cm_value = confusion_matrix.value()
    accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())
    return confusion_matrix, accuracy

 

6.5 测试

 

测试的时候,需要计算每个样本属于狗的概率,并把结果保存为 csv 文件,测试的代码和验证比较相似,但需要自己加载模型和数据。

 

def write_csv(results, file_name):
    with open(file_name, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(['id', 'label'])
        writer.writerows(results)

def test(**kwargs):
    """
    测试(inference)
    :param kwargs:
    :return:
    """
    opt.parse(kwargs)
    # 模型
    model = getattr(models, opt.model)().eval()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu: model.cuda()
    # 数据
    train_data = DogCat(opt.test_data_root, test=True)
    test_dataloader = DataLoader(train_data,
                                 batch_sampler=opt.batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers)
    results = []
    for ii, (data, path) in enumerate(test_dataloader):
        inp = Variable(data, volatile=True)
        if opt.use_gpu: inp = inp.cuda()
        score = model(inp)
        probability = probability = functional.softmax(score, dim=1)[:, 0].detach().tolist()
        batch_results = [(path_, probability_) for path_, probability_ in zip(path, probability)]
        results += batch_results
    write_csv(results, opt.result_file)
    return results

 

6.6 帮助函数

 

为了让他人方便使用,程序中应该还需要提供一个帮助函数,用于说明函数是如何使用的。

 

程序的命令行接口有很多参数,如果手动用字符串表示不仅复杂,而且后期修改 config 文件时还需要修改对应的帮助信息。为此,这里使用 Python 标准库中的 inspect 方法,可以自动获取 config 的源代码。

 

dg_help 的代码如下:

 

def dc_help():
    """
    打印帮助的信息
    :return:
    """
    print('''
    usage:python{0} <function> [--args=value,]
    <function> := train | test | help
    example:
        python {0} train --env='env0701' --lr=0.01
        python {0} test --dataset='path/to/dataset/root/'
        python {0} help
    avaiable args:
    '''.format(__file__))
    source = (getsource(opt.__class__))  # 获取配置信息
    print(source)

 

七、使用

 

如 dc_help 函数打印的信息描述的一样,可以通过命令行参数指定变量名。下面是三个使用例子,fire 会把包含 “-” 命令行参数自动转成下划线 “_”,也会把非数字的数值转成字符串,所以 --train--data-root=data/train
--train_data_root = 'data/train'
是等价的。

 

感兴趣的可以把数据集下载下来进行测试: 猫狗分类数据集

 

由于本章只是讲解项目架构,我就不做测试,但是代码应该没什幺大问题,修修补补就行了。

 

想要具体代码的可以加我微信:chenyoudea,但是没必要找我要,我也没有尝试去跑通这个代码,并且我也没有下载数据集,因为这一章没必要。

 

# 训练模型
python main.py train
    --train-data-root=data/train/
    --load-model-path=None
    --lr=0.005
    --batch-size=32
    --model='ResNet34'
    --max-epoch=20
    
python main.py train --train-data-root=data/train/ --load-model-path=None --lr=0.005 --batch-size=32 --model='ResNet34' --max-epoch=20
    
# 测试模型
python main.py test
    --test-data-root=data/test1
    --load-model-path=None
    --batch-szie=128
    --model='ResNet34'
    --num-workers=12
    
# 打印帮助信息
python main.py dc_help

 

八、争议

 

这里还是多说一嘴,因为这个风格更多的是书籍作者陈云老师的风格,并不是说以后你写的代码都要以这个为标准,这个项目架构更多的是作为一个题意或一种参考。

 

也就是说,不要把本篇文章的观点作为一个必须遵守的规范,但是前期的学习可以按照这个架构来,这样不容易犯错。但是,对于未来你遇到的很多项目,尤其对于每个公司的项目,项目架构相信都是不一样的,不唯经验主义,不唯教条主义,这才是一个码农想进阶的必经之路。

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注