Press "Enter" to skip to content

基于 Pytorch 的多类别图像分类实战

 

本篇基于 Pytorch 完成一个多类别图像分类实战。

 

1 简介

 

实现一个完整的图像分类任务,大致需要分为五个步骤:

 

1、选择开源框架

 

目前常用的深度学习框架主要包括 tensorflow、caffe、pytorch、mxnet 等;

 

2、构建并读取数据集

 

根据任务需求搜集相关图像搭建相应的数据集,常见的方式包括:网络爬虫、实地拍摄、公共数据使用等。随后根据所选开源框架读取数据集。

 

3、框架搭建

 

选择合适的网络模型、损失函数以及优化方式,以完成整体框架的搭建

 

4、训练并调试参数

 

通过训练选定合适超参数

 

5、测试准确率

 

在测试集上验证模型的最终性能

 

本文利用 Pytorch 框架,按照上述结构实现一个基本的图像分类任务,并详细阐述其中的细节及注意事项。

 

2 数据集

 

 

本次实战选择的数据集为 Kaggle 竞赛中的细胞数据集,共包含 9961 个训练样本,2491 个测试样本,可以分为嗜曙红细胞、淋巴细胞、单核细胞、中性白细胞 4 个类别,图片大小为 320×240。

 

Pytorch 中封装了相应的数据读取的类函数,通过调用 torch.utils.data.Datasets 函数,则可以实现读取功能。

 

 

init()模块用来定义相关的参数, len ()模块用来获取训练样本个数, getitem ()模块则用来获取每张具体的图片,在读取图片时其可以通过 opencv 库、PIL 库等进行读取,具体代码如下:

 

数据集

 

class dataset(data.Dataset):

 

# 参数预定义

 

def init (self, anno_pd, transforms=None):

 

self.paths = anno_pd[‘ImageName’].tolist()

 

self.labels = anno_pd[‘label’].tolist()

 

self.transforms = transforms

 

# 返回图片个数

 

def len (self):

 

return len(self.paths)

 

# 获取每个图片

 

def getitem (self, item):

 

img_path =self.paths[item]

 

img_id =img_path.split(“/”)[-1]

 

img =cv2.imread(img_path)

 

img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)

 

if self.transforms is not None:

 

img = self.transforms(img)

 

label = self.labels[item]

 

return torch.from_numpy(img).float(), int(label)

 

此外,需要定义图像增强模块,即上述代码中的 transform,通常采取的操作为翻转、剪切等,关于图像增强的具体介绍可以参考公众号前作。

 

【技术综述】深度学习中的数据增强方法都有哪些?

 

需要特别强调的是对图像进行去均值处理,很多同学不明白为何要减去均值,其主要的原因是图像作为一种平稳的数据分布,通过减去数据对应维度的统计平均值,可以消除公共部分,以凸显个体之间的特征和差异。进行去均值前后操作后的图像对比如下:

 

 

3 框架搭建

 

本次实战主要选取了 VGG16、Resnet50、InceptionV4 三个经典网络,也是对前篇文章的一个总结。

 

损失函数则选择交叉熵损失函数: 【技术综述】一文道尽 softmax loss 及其变种

 

优化方式选择 SGD、Adam 优化两种: 【模型训练】SGD 的那些变种,真的比 SGD 强吗

 

4 训练及参数调试

 

初始学习率设置为 0.01,batch size 设置为 8,衰减率设置为 0.00001,迭代周期为 15,在不同框架组合下的最佳准确率和最低 loss 如下图所示:

 

 

可以发现在验证集上 Resnet-50+SGD+Cross Entropy 的组合下取得了 99% 左右的准确率,相反 VGG-16 结果则稍微差一些。

 

最佳组合下的准确率走势曲线如下图所示:

 

 

5 测试

 

对上述模型分别在测试集上进行测试,所获得的结果如下图所示,整体精度比训练集上约下降了一个百分点:

 

 

总结

 

以上就是整个多类别图像分类实战的过程,由于时间限制,本次实战并没有对多个数据集进行练,因此没有列出同一模型在不同数据集上的表现。

 

作者介绍

 

郭冰洋,公众号“有三 AI”作者。该公号聚焦于让大家能够系统性地完成 AI 各个领域所需的专业知识的学习。

 

原文链接

 

https://mp.weixin.qq.com/s/jPpZLYXQBX7l5AUfFV5n3g

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注