Press "Enter" to skip to content

如何入门Pytorch之四:搭建神经网络训练MNIST

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

上一节我们学习了Pytorch优化网络的基本方法,本节我们将以MNIST数据集为例,通过搭建一个完整的神经网络,来加深对Pytorch的理解。

 

一、数据集

 

MNIST是一个非常经典的数据集,下载链接: http://yann.lecun.com/exdb/mnist/

 

下载下来的文件如下:

 

 

该手写数字数据库具有60,000个示例的训练集和10,000个示例的测试集。它是NIST提供的更大集合的子集。数字已经过尺寸标准化,并以固定尺寸的图像为中心。

 

手写数字识别是一个比较简单的任务,它是一个10分类问题,(0-9),之所以选这个数据集,是因为识别难度低,计算量小,数据容易获得。

 

二、模型搭建

 

1、网络节点的确定

 

对于不同的目的,网络的选择也是不一样的。一般来说,网络容量和数据集大小是对应的。一个小型数据集也只需要一个小型的网络。

 

这里有一个经验值:

 

1)model_size=sqrt(in_size*out_size)

 

2)model_size=log(in_size)

 

3)  model_size=sqrt(in_size*out_size)

 

model_size:网络的节点量

 

in_size:输入的节点量

 

out_size输出的节点量

 

2、导入pytorch包

 

import torch
import torchvision
import trochvision import datasets
import trochvision import transforms
from torch.autograd import Variable

 

3、获取训练集和测试集

 

#root用于指定数据集下载后的存放路径
#transform用于指定导入数据集需要对数据进行变换操作
#train指定在数据集下载后需要载入哪部分数据,true为训练集,false为测试集
data_train=datasets.MNIST(root="./data/",transform=transform,train=True,download=True)
data_test=datasets.MNIST(root='./data/',transform=transform,train=False)

 

4、数据预览和装载

 

#数据装载,可以理解为对图片的处理
#处理完成后,将图片送给模型训练,装载就是打包的过程
#dataset 用于指定载入的数据集名称
#batch_size设置了每个包的图片数据数据个数
#shuffle 装载过程将数据随机打乱并打包
data_loader_train=torch.utils.data.DataLoader(dataset=data_train,batch_size=64,shuffle=True)
data_loader_test=torch.utils.data.DataLoader(dataset=data_test,batch_size=64,shuffle=True)

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注