Event: CSDN 21-Day Learning Challenge
Project data and source code
Available for download on GitHub:
https://github.com/chenshunpeng/Pytorch-competitor-MNIST-dataset-classification
Task description
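We train a model on the MNIST handwritten digit dataset so that, given an image of a handwritten digit, it can tell which digit the image shows. The model compares how similar the image is to each of the 10 digits 0~9 and picks the digit with the highest similarity as its answer. In the figure below, 9 has the highest similarity of the 10 digits (0.87), so the image is recognized as 9.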
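One way to read the "similarity" above is as a probability per class. The sketch below assumes the network outputs 10 raw scores (logits) and that softmax turns them into probabilities. The scores are made up, and the helper names `softmax` and `predict_digit` are hypothetical. This illustrates the decision rule; it is not the post's own code.

```python
import numpy as np


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    shifted = scores - np.max(scores)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


def predict_digit(scores):
    """Return the digit with the highest probability, plus that probability."""
    probs = softmax(np.asarray(scores, dtype=np.float64))
    digit = int(np.argmax(probs))
    return digit, float(probs[digit])


# Hypothetical raw scores for the digits 0..9 (not produced by the trained model)
scores = [0.1, -1.2, 0.3, 0.0, 0.5, -0.4, 0.2, 0.9, 1.1, 3.0]
digit, confidence = predict_digit(scores)
print(digit, round(confidence, 2))
```

The network built later outputs raw scores. `F.cross_entropy` applies the softmax step internally during training, and `argmax` over the raw scores picks the same digit as `argmax` over the probabilities.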
Loading the MNIST dataset
Dataset URL:
http://yann.lecun.com/exdb/mnist/ (also included in the GitHub project)
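Dataset introduction:
See the article "Dataset之MNIST: MNIST (handwritten digit image recognition + ubyte.gz files): dataset overview, download, and usage (including data augmentation), a detailed guide"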
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
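MNIST is a very well-known dataset of grayscale images of handwritten digits, and many tutorials use it as the introductory example for deep learning.
MNIST is a subset of the NIST dataset and consists of images of the digits 0~9, each labeled with its digit. It has 60,000 training images for training models and 10,000 test images for evaluating them. The mnist.pkl.gz file loaded below splits the 60,000 training images further, into 50,000 for training and 10,000 for validation, and also includes the 10,000 test images.
Each image is a 28×28-pixel grayscale image with 1 channel. In the raw files each pixel is an integer from 0 to 255. In mnist.pkl.gz the pixels are stored as floats scaled into [0, 1). Every image carries a digit label.
Each image is represented by a 28×28 matrix, with the digit centered in the image. After preprocessing, each image becomes a one-dimensional array of length 784 (28×28 = 784), and each element of this array corresponds to one entry of the pixel matrix.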
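The flattening described above is a row-major reshape. In the minimal sketch below, the image is random and hypothetical. Dividing by 256 is an assumption: it appears to match the values printed further down (for example, 0.9961 ≈ 255/256).

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(28, 28))  # hypothetical raw image, integer pixels 0..255

flat = image.reshape(-1)  # row-major: row 0 -> positions 0..27, row 1 -> 28..55, ...
scaled = flat.astype(np.float32) / 256.0  # into [0, 1); /256 is an assumption, see above
restored = flat.reshape(28, 28)  # back to the 28x28 grid, no information lost

print(flat.shape, float(scaled.min()), float(scaled.max()))
```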
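# Embed matplotlib figures directly in the notebook (Jupyter magic; remove it in a plain .py script)
%matplotlib inline
from pathlib import Path
import gzip
import pickle

import requests  # only needed if you download the file instead of copying it from the GitHub project

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)
FILENAME = "mnist.pkl.gz"

# The file must already be at data/mnist/mnist.pkl.gz
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    # The pickle holds (train, valid, test) splits; the test split is ignored here
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")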
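The loading code above assumes mnist.pkl.gz is already in place. The sketch below wraps the same gzip/pickle call in a function so that a missing file fails early with a clear message. The name `load_mnist_pkl` is hypothetical and is not part of the GitHub project.

```python
import gzip
import pickle
from pathlib import Path


def load_mnist_pkl(path):
    """Load the (train, valid, test) splits from a gzipped pickle such as mnist.pkl.gz."""
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"{path} not found; copy mnist.pkl.gz into this folder first")
    with gzip.open(path, "rb") as f:
        # encoding="latin-1" lets Python 3 read a pickle that was written by Python 2
        train, valid, test = pickle.load(f, encoding="latin-1")
    return train, valid, test
```

Usage would be `(x_train, y_train), (x_valid, y_valid), _ = load_mnist_pkl(PATH / FILENAME)`.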
Inspect the dataset:
from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)  # (50000, 784): 50,000 samples, each a flattened 28*28*1 image
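Printing x_train[0] shows the digit as numbers. Because the 784 values are not laid out as a 28×28 grid, the outline of the digit 5 is hard to see. The values below are printed as a tensor, that is, after the conversion to tensors further down: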
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0117, 0.0703, 0.0703, 0.0703, 0.4922, 0.5312, 0.6836, 0.1016, 0.6484, 0.9961, 0.9648, 0.4961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1172, 0.1406, 0.3672, 0.6016, 0.6641, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.8789, 0.6719, 0.9883, 0.9453, 0.7617, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1914, 0.9297, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.9805, 0.3633, 0.3203, 0.3203, 0.2188, 0.1523, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0703, 0.8555, 0.9883, 0.9883, 0.9883, 0.9883, 0.9883, 0.7734, 0.7109, 0.9648, 0.9414, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3125, 0.6094, 0.4180, 0.9883, 0.9883, 0.8008, 0.0430, 0.0000, 0.1680, 0.6016, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0547, 0.0039, 0.6016, 0.9883, 0.3516, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5430, 0.9883, 0.7422, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0430, 0.7422, 0.9883, 0.2734, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1367, 0.9414, 0.8789, 0.6250, 0.4219, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3164, 0.9375, 0.9883, 0.9883, 0.4648, 0.0977, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1758, 0.7266, 0.9883, 0.9883, 0.5859, 0.1055, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0625, 0.3633, 0.9844, 0.9883, 0.7305, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9727, 0.9883, 0.9727, 0.2500, 0.0000, 0.0000, 
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1797, 0.5078, 0.7148, 0.9883, 0.9883, 0.8086, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1523, 0.5781, 0.8945, 0.9883, 0.9883, 0.9883, 0.9766, 0.7109, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0938, 0.4453, 0.8633, 0.9883, 0.9883, 0.9883, 0.9883, 0.7852, 0.3047, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0898, 0.2578, 0.8320, 0.9883, 0.9883, 0.9883, 0.9883, 0.7734, 0.3164, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0703, 0.6680, 0.8555, 0.9883, 0.9883, 0.9883, 0.9883, 0.7617, 0.3125, 0.0352, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2148, 0.6719, 0.8828, 0.9883, 0.9883, 0.9883, 0.9883, 0.9531, 0.5195, 0.0430, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5312, 0.9883, 0.9883, 0.9883, 0.8281, 0.5273, 0.5156, 0.0625, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
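To actually see the shape in the dump above, one option is to lay the 784 values back out as 28 text rows. The sketch below uses a hypothetical `to_ascii` helper and a threshold of 0.5, both of which are my own choices. The demo image is a synthetic vertical bar, not real MNIST data.

```python
import numpy as np


def to_ascii(pixels, threshold=0.5):
    """Render a flat 784-value image as 28 text rows: '#' where pixel > threshold."""
    grid = np.asarray(pixels, dtype=np.float32).reshape(28, 28)
    return "\n".join("".join("#" if v > threshold else "." for v in row) for row in grid)


# Synthetic example: a vertical bar in column 14, rows 4..23
demo = np.zeros(784, dtype=np.float32)
demo[[r * 28 + 14 for r in range(4, 24)]] = 0.99
print(to_ascii(demo))
```

With the post's data, `print(to_ascii(x_train[0]))` should draw the outline of the 5. `np.asarray` also accepts a CPU tensor.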
Convert the data to tensors:
import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape  # n = number of samples, c = 784 pixels per sample
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
Result:
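`map(torch.tensor, ...)` above copies each NumPy array. The small sketch below, using hypothetical arrays, contrasts this with `torch.from_numpy`, which shares memory with the array instead of copying. It also checks dtypes. For class-index targets, `F.cross_entropy` expects int64 (`torch.long`) labels.

```python
import numpy as np
import torch

x_np = np.zeros((2, 784), dtype=np.float32)  # hypothetical pixel rows
y_np = np.array([5, 0], dtype=np.int64)      # hypothetical labels

x_copy = torch.tensor(x_np)        # copies the data
x_shared = torch.from_numpy(x_np)  # shares memory with the NumPy array
y = torch.tensor(y_np)

x_np[0, 0] = 1.0  # modify the NumPy array after conversion
print(x_copy[0, 0].item(), x_shared[0, 0].item(), x_copy.dtype, y.dtype)
```

The copy keeps the old value and the shared tensor sees the change. Copying is the safer default when the original arrays may be modified later.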
Designing a fully connected neural network
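A fully connected network expects its input as a matrix. Each 1×28×28 third-order tensor is therefore turned into a first-order vector by joining the rows of the image end to end into a single 1×784 vector. With N handwritten digit images as input, the input matrix has shape (N, 784). The images in mnist.pkl.gz are already stored in this flattened form, as the shape (50000, 784) shows. With this in place we can design the model, as shown in the figure below.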
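When images arrive as an (N, 1, 28, 28) batch instead of already flattened, the batch dimension must be kept and only the rest merged. The sketch below assumes a random, hypothetical batch. It shows three equivalent PyTorch ways to get (N, 784).

```python
import torch
from torch import nn

images = torch.rand(4, 1, 28, 28)  # hypothetical batch: N=4 grayscale 28x28 images

flat_a = images.view(images.size(0), -1)     # keep the batch dimension, merge the rest
flat_b = torch.flatten(images, start_dim=1)  # same result as a function call
flat_c = nn.Flatten()(images)                # same result as a layer (start_dim=1 by default)

print(flat_a.shape)
```

`nn.Flatten` is convenient when the flattening step should live inside the model itself.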
Building the Mnist_NN class and defining its functions
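Notes:
The Mnist_NN class must inherit from nn.Module, and its constructor must call the nn.Module constructor.
There is no need to write a backward function: nn.Module uses autograd to implement backpropagation automatically.
The learnable parameters of a Module are returned as iterators by named_parameters() or parameters().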
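The notes above can be checked on a toy model. In the sketch below, `TinyNet` is hypothetical and much smaller than Mnist_NN. It defines only `forward`, yet `loss.backward()` fills a gradient for every parameter. Without the `super().__init__()` call, assigning `nn.Linear` attributes would raise an error.

```python
import torch
from torch import nn


class TinyNet(nn.Module):  # hypothetical two-layer model, only for this demo
    def __init__(self):
        super().__init__()  # sets up nn.Module so submodules and their parameters get registered
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


net = TinyNet()
x = torch.randn(5, 4)
loss = net(x).pow(2).mean()  # any scalar works as a loss for this demo
loss.backward()              # autograd fills .grad; no backward() method was written

for name, p in net.named_parameters():
    print(name, tuple(p.shape), p.grad is not None)
```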
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import numpy as np


class Mnist_NN(nn.Module):  # inherit from nn.Module
    def __init__(self):
        super().__init__()  # call the nn.Module constructor
        self.hidden1 = nn.Linear(784, 128)  # hidden layer 1
        self.hidden2 = nn.Linear(128, 256)  # hidden layer 2
        self.out = nn.Linear(256, 10)       # output layer: one score per digit

    # Forward pass
    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)  # raw scores (logits); F.cross_entropy applies softmax internally
        return x
Create an Mnist_NN object named net and print it:
net = Mnist_NN()
print(net)
Output:
We can also print the name, values, and size of each weight and bias we defined:
for name, parameter in net.named_parameters():
    print(name, parameter, parameter.size())
Result:
hidden1.weight Parameter containing: tensor([[-0.0107, 0.0176, 0.0235, ..., 0.0040, -0.0234, 0.0087], [ 0.0177, -0.0273, 0.0112, ..., -0.0134, 0.0282, -0.0013], [ 0.0139, -0.0125, 0.0143, ..., -0.0239, 0.0263, -0.0089], ..., [-0.0204, 0.0160, 0.0061, ..., -0.0239, -0.0082, -0.0247], [ 0.0070, -0.0266, -0.0093, ..., -0.0144, 0.0022, 0.0010], [ 0.0227, 0.0055, 0.0275, ..., -0.0272, 0.0136, -0.0164]], requires_grad=True) torch.Size([128, 784]) hidden1.bias Parameter containing: tensor([-0.0097, 0.0237, 0.0018, -0.0330, -0.0280, -0.0191, -0.0255, 0.0288, 0.0225, 0.0101, -0.0063, -0.0276, 0.0091, 0.0075, -0.0313, 0.0057, -0.0356, -0.0265, 0.0286, -0.0057, -0.0100, -0.0276, 0.0178, -0.0170, -0.0174, 0.0337, 0.0259, -0.0143, 0.0314, 0.0331, 0.0341, 0.0189, -0.0315, -0.0170, 0.0237, 0.0156, -0.0345, 0.0154, 0.0197, 0.0305, 0.0349, -0.0326, 0.0193, -0.0336, 0.0142, 0.0262, 0.0215, 0.0004, 0.0243, 0.0236, -0.0195, -0.0208, 0.0333, -0.0104, 0.0033, 0.0118, 0.0113, -0.0340, 0.0155, 0.0261, -0.0089, 0.0287, -0.0242, 0.0022, -0.0165, -0.0296, 0.0008, 0.0316, -0.0224, -0.0037, 0.0105, 0.0057, 0.0285, -0.0158, -0.0013, -0.0340, 0.0287, -0.0043, -0.0148, -0.0273, -0.0066, 0.0082, -0.0170, -0.0021, -0.0280, 0.0211, -0.0165, -0.0103, 0.0152, -0.0128, -0.0211, -0.0180, -0.0097, 0.0089, 0.0338, 0.0322, -0.0210, -0.0235, -0.0123, -0.0219, -0.0201, 0.0003, -0.0106, -0.0303, -0.0003, -0.0157, 0.0188, 0.0179, 0.0237, -0.0351, -0.0146, -0.0205, -0.0284, 0.0218, 0.0107, -0.0353, 0.0253, -0.0196, -0.0317, -0.0294, 0.0184, 0.0201, 0.0059, 0.0260, 0.0134, -0.0217, 0.0091, -0.0089], requires_grad=True) torch.Size([128]) hidden2.weight Parameter containing: tensor([[-0.0658, 0.0262, 0.0356, ..., 0.0520, -0.0872, 0.0459], [-0.0443, -0.0812, -0.0046, ..., 0.0819, -0.0386, -0.0344], [-0.0703, 0.0753, -0.0350, ..., -0.0035, 0.0188, 0.0194], ..., [ 0.0556, 0.0688, -0.0311, ..., -0.0033, 0.0832, -0.0497], [ 0.0164, 0.0710, 0.0368, ..., 0.0303, 0.0231, 0.0512], [-0.0437, 0.0875, 0.0315, ..., 0.0002, 
0.0679, -0.0412]], requires_grad=True) torch.Size([256, 128]) hidden2.bias Parameter containing: tensor([ 7.7913e-03, -5.2409e-02, 3.7981e-02, 6.4097e-02, 6.5983e-02, -1.2665e-02, -5.3630e-02, 1.8194e-02, 2.8534e-02, 8.3733e-02, 5.3927e-02, 2.3522e-02, -2.2915e-02, 7.9818e-02, -4.8618e-02, -4.9321e-02, -6.4636e-02, 4.5667e-02, 6.2186e-02, 2.9977e-02, -3.8158e-02, 6.4900e-02, -5.5211e-02, -4.5465e-02, -7.5447e-02, -1.3676e-03, 1.8499e-02, 2.6505e-02, -1.3459e-02, 6.3754e-02, -3.7523e-02, 5.7949e-02, -5.9734e-02, -8.6329e-02, 2.9193e-02, 2.0645e-02, 2.8751e-02, 6.2095e-02, 6.5391e-02, -1.3178e-02, 5.2374e-02, -5.1765e-02, -5.7692e-02, -4.6615e-02, -1.6571e-02, -6.7677e-02, -6.8337e-02, -4.4569e-02, -1.3499e-02, -7.0806e-02, 1.7268e-02, 7.9308e-02, -9.2949e-03, 8.3358e-02, -2.8339e-03, 3.6183e-02, -3.0781e-03, -7.8056e-02, -2.5781e-02, -6.1548e-02, -4.2550e-03, 8.4365e-02, 7.6643e-02, 2.6072e-03, 3.8844e-02, -9.1026e-03, 1.7072e-02, 1.5069e-02, -1.5344e-02, -7.1375e-02, -2.4087e-02, 4.8563e-02, 4.3171e-02, 3.7335e-02, 3.9004e-02, 4.7122e-02, 6.3475e-02, 4.2615e-02, -6.1060e-02, 1.4865e-02, 4.5167e-02, -8.0974e-02, 5.3717e-03, -3.9014e-02, 8.3588e-02, 6.5867e-02, -3.4913e-02, 5.8872e-02, 6.7077e-02, -6.3365e-02, 8.6366e-02, 3.5593e-02, 4.6238e-02, 8.3289e-02, -1.4793e-02, 7.2298e-02, 6.0482e-02, 4.2920e-02, 3.9899e-02, 8.2298e-02, 4.3614e-02, 8.3762e-03, 6.7424e-02, -5.9824e-02, -5.2346e-02, 5.3317e-02, -1.8010e-02, 7.9718e-03, 4.9618e-02, 5.7588e-03, 2.6586e-02, 4.7773e-02, -7.4746e-02, -4.2066e-03, 6.3242e-02, -8.4219e-03, -7.7916e-02, -7.9803e-02, 1.4334e-02, 5.2814e-02, -7.5703e-02, 8.8523e-03, 6.0214e-03, 5.8813e-02, 4.3685e-02, 3.1810e-03, 5.6022e-02, -6.4101e-02, -6.3819e-02, -8.0192e-02, 2.3717e-02, 9.3828e-03, -2.4051e-02, -1.5994e-02, -6.8268e-02, -8.3660e-02, -7.3033e-02, -6.6568e-02, 3.7064e-02, -3.3497e-02, -8.7144e-02, 8.3359e-02, -1.3661e-02, 3.5242e-02, 3.0770e-02, -2.1677e-02, -7.5600e-02, -2.8537e-02, -1.9357e-02, -5.9502e-02, 7.9158e-02, -2.8801e-02, 
-2.2144e-02, 8.5924e-04, 7.5870e-02, 6.6614e-02, 1.4565e-02, -5.7472e-02, 8.0418e-02, 6.6934e-02, 3.2934e-02, 5.2901e-03, -7.0742e-03, 4.2174e-02, 5.4780e-02, -6.9979e-02, 5.7612e-02, 4.3069e-02, -1.9059e-02, 5.2661e-02, 3.0751e-02, -5.5104e-02, -5.3951e-02, 9.0439e-03, -2.0585e-02, 2.0851e-02, -3.0479e-02, 4.0783e-03, 2.2134e-02, 6.5000e-02, 8.0417e-02, -4.5733e-02, 3.5371e-02, 2.2602e-02, 3.9445e-02, 5.0051e-02, 1.1277e-02, 8.4714e-03, -3.4974e-02, 1.4301e-02, 5.3342e-02, 2.7742e-02, -8.6245e-02, 4.0869e-02, -8.0224e-02, -3.9399e-02, 8.7867e-02, 5.3911e-02, 4.4785e-02, -8.7924e-02, 5.3280e-02, 5.5927e-02, 3.0065e-02, 4.8404e-02, 5.4177e-02, -6.6974e-02, 3.5416e-02, 8.9249e-03, 7.0158e-02, 2.6166e-02, 6.6212e-04, 8.5239e-02, 3.1147e-02, 2.9362e-02, 8.2084e-02, -8.0664e-02, -3.9999e-02, 4.9067e-02, 6.4668e-02, -6.9497e-02, -4.6120e-02, 3.0965e-02, -5.0559e-02, 4.8063e-02, -6.1079e-02, 4.0454e-02, 7.1121e-02, 6.7732e-02, 1.7263e-02, 3.8927e-02, 3.4393e-02, 2.5543e-02, -7.6177e-02, 1.5727e-02, -3.0954e-02, 6.5176e-02, 8.5865e-03, 4.0888e-02, -7.4767e-05, 6.3285e-02, 2.6874e-02, -4.7549e-02, -2.6836e-02, -5.2410e-02, -4.1517e-02, -6.4450e-03, -5.6177e-02, 3.9314e-02, -5.7746e-02, 4.6241e-02, -7.3782e-02, 8.7160e-02, 8.6259e-02, 8.5354e-02, -2.9345e-02, 1.3077e-02], requires_grad=True) torch.Size([256]) out.weight Parameter containing: tensor([[-0.0613, -0.0281, -0.0492, ..., 0.0526, 0.0189, -0.0455], [-0.0086, -0.0281, -0.0385, ..., -0.0198, -0.0447, -0.0342], [ 0.0407, 0.0162, -0.0182, ..., 0.0353, -0.0350, 0.0405], ..., [ 0.0398, 0.0623, -0.0503, ..., 0.0261, -0.0479, -0.0239], [-0.0221, -0.0278, 0.0564, ..., 0.0249, -0.0339, -0.0200], [ 0.0242, -0.0149, 0.0027, ..., -0.0408, 0.0173, -0.0111]], requires_grad=True) torch.Size([10, 256]) out.bias Parameter containing: tensor([-0.0526, 0.0188, 0.0049, -0.0456, -0.0164, -0.0436, 0.0448, 0.0018, -0.0373, -0.0142], requires_grad=True) torch.Size([10])
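The shapes printed above (128×784, 256×128, 10×256, plus one bias per output unit) make it easy to count the trainable parameters. The sketch below rebuilds the same layer sizes with `nn.Sequential` so it runs on its own. It is equivalent in size to Mnist_NN, but it is not the post's class.

```python
from torch import nn

# Same layer sizes as Mnist_NN, rebuilt with nn.Sequential so this block is self-contained
model = nn.Sequential(
    nn.Linear(784, 128), nn.ReLU(),
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 10),
)

# Each Linear(in, out) holds an out x in weight matrix plus out biases
total = sum(p.numel() for p in model.parameters())
expected = (784 * 128 + 128) + (128 * 256 + 256) + (256 * 10 + 10)
print(total, expected)
```

The count is 100,480 + 33,024 + 2,570 = 136,074 parameters. Running `sum(p.numel() for p in net.parameters())` on the post's `net` should give the same number.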
Use TensorDataset and DataLoader to simplify data handling:
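get_data() function:
shuffle controls whether the dataset is shuffled (bool, default False).
Shuffling the order of the inputs makes consecutive samples more independent of each other. If the order itself carries information, as in sequential data, do not set it to True.
Usually the training set is shuffled and the validation/test set keeps its original order. Raw data may be stored in some order even when the classes are balanced, for example with one class in the first half and another class in the second half. Shuffling gives each batch a more random mix of samples, which reduces oscillation during training.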
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),  # shuffled training batches
        DataLoader(valid_ds, batch_size=bs * 2),            # validation: fixed order, larger batches
    )
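The DataLoader behavior that get_data() relies on can be seen on a tiny hypothetical dataset of 10 samples whose labels are their own indices. The shuffled loader covers every sample in a new order each epoch. The last batch is smaller unless `drop_last=True` is passed. The validation batch size is doubled presumably because no gradients are stored there, which leaves more memory room.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical tiny dataset: 10 samples whose labels are their indices 0..9
x = torch.arange(10, dtype=torch.float32).unsqueeze(1)
y = torch.arange(10)
ds = TensorDataset(x, y)

train_dl = DataLoader(ds, batch_size=4, shuffle=True)  # new random order every epoch
valid_dl = DataLoader(ds, batch_size=4 * 2)            # fixed order, larger batches

train_batches = [yb.tolist() for _, yb in train_dl]
valid_batches = [yb.tolist() for _, yb in valid_dl]
print(train_batches)
print(valid_batches)
```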
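get_model() function:
PyTorch's torch.optim package provides many classes that optimize parameters automatically. They implement algorithms such as SGD, Adagrad, RMSprop, and Adam, and they can be used directly.
This experiment uses the most basic optimizer, SGD.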
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)  # plain SGD, learning rate 0.001
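get_model() hard-codes plain SGD with lr=0.001. The validation log further down falls only gradually over 25 passes, and SGD with momentum or Adam are common alternatives to try. I have not rerun this post's experiment with them. The sketch below uses a hypothetical `make_optimizer` helper that picks one of these by name. Only the torch.optim classes are real API.

```python
from torch import nn, optim


def make_optimizer(params, name="sgd", lr=0.001):
    """Hypothetical helper: build one of a few torch.optim optimizers by name."""
    if name == "sgd":
        return optim.SGD(params, lr=lr)
    if name == "sgd_momentum":
        return optim.SGD(params, lr=lr, momentum=0.9)
    if name == "adam":
        return optim.Adam(params, lr=lr)
    raise ValueError(f"unknown optimizer: {name}")


model = nn.Linear(784, 10)  # stand-in model for the demo
opt = make_optimizer(model.parameters(), "sgd_momentum", lr=0.01)
print(type(opt).__name__, opt.param_groups[0]["lr"], opt.param_groups[0]["momentum"])
```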
loss_batch() function:
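def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)
    if opt is not None:  # training: update the weights; validation: only compute the loss
        loss.backward()  # compute gradients
        opt.step()       # update the parameters
        opt.zero_grad()  # clear gradients so they do not accumulate into the next batch
    return loss.item(), len(xb)  # batch loss and batch size (used for a weighted average)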
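For plain SGD (no momentum, no weight decay), `opt.step()` in loss_batch() should amount to `p = p - lr * p.grad` for every parameter. The sketch below checks this on a hypothetical one-layer model with random data and an MSE loss chosen only for the demo.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(3, 1)  # hypothetical stand-in model
opt = optim.SGD(model.parameters(), lr=0.1)
xb, yb = torch.randn(8, 3), torch.randn(8, 1)

loss = nn.functional.mse_loss(model(xb), yb)
loss.backward()  # fills p.grad for every parameter

# What plain SGD should produce: p - lr * grad
expected = [(p - 0.1 * p.grad).detach().clone() for p in model.parameters()]

opt.step()       # apply the update in place
opt.zero_grad()  # clear gradients (recent PyTorch sets them to None)

print(all(torch.allclose(p, e) for p, e in zip(model.parameters(), expected)))
```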
fit() function:
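Call model.train() before training so that layers such as Batch Normalization and Dropout run in training mode.
Call model.eval() before evaluating. Dropout is then switched off, and Batch Normalization uses its running statistics instead of per-batch statistics. The validation data is passed through the network without any training, to measure how well the model performs on data it was not trained on.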
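The effect of train/eval mode is easiest to see on a single Dropout layer. The sketch below uses p=0.5 and an all-ones input, both chosen for illustration. Mnist_NN itself has no Dropout or BatchNorm layers, so for this particular model the two modes should give the same outputs. The calls are still a good habit.

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 1000)

drop.train()     # training mode: zero about half the values, scale the rest by 1/(1-p)
y_train = drop(x)

drop.eval()      # evaluation mode: dropout becomes the identity
y_eval = drop(x)

print(sorted(set(y_train.flatten().tolist())), torch.equal(y_eval, x))
```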
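def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):  # each step is one full pass (epoch) over the training set
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():  # no gradients are needed for validation
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        # Weight each batch's mean loss by its size, because the last batch can be smaller
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('Current step:' + str(step), 'Validation loss:' + str(val_loss))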
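The `np.multiply(losses, nums)` line in fit() computes a size-weighted mean. Each batch loss is already a mean over its batch, so a plain average would over-weight a short last batch. With 10,000 validation images and a batch size of 2 × 64 = 128, there are 78 full batches and a last batch of 16. The batch losses in this pure-Python sketch are hypothetical.

```python
def weighted_mean_loss(losses, nums):
    """Average per-batch mean losses, weighting each by its batch size."""
    total = sum(loss * n for loss, n in zip(losses, nums))
    return total / sum(nums)


# Hypothetical batch results: two full batches of 128 and a last batch of 16
losses = (0.50, 0.40, 0.90)
nums = (128, 128, 16)
print(weighted_mean_loss(losses, nums))
```

Here the weighted mean is about 0.476, while the unweighted mean of the three batch losses would be 0.6.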
Training
bs is the batch_size (int). In deep learning the dataset is usually split into batches, each holding a fixed number of samples. bs sets how many samples go into one batch.
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)
bs = 64
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
loss_func = F.cross_entropy  # cross-entropy loss
fit(25, model, loss_func, opt, train_dl, valid_dl)  # 25 passes over the training set
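Each "step" printed by fit() is a full pass over the training data, not a single weight update. This short arithmetic sketch (with a hypothetical helper `batches_per_epoch`) counts the updates implied by bs = 64 and the 50,000/10,000 split.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of DataLoader batches in one pass over the data."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


train_steps = batches_per_epoch(50_000, 64)   # training set, bs = 64
valid_steps = batches_per_epoch(10_000, 128)  # validation set, bs * 2 = 128
print(train_steps, valid_steps, 25 * train_steps)
```

That is 782 updates per pass, and 19,550 SGD updates over the 25 passes.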
Result:
Current step:0 Validation loss:2.2809557510375975
Current step:1 Validation loss:2.2500623081207274
Current step:2 Validation loss:2.202859774017334
Current step:3 Validation loss:2.123643782043457
Current step:4 Validation loss:1.9911612365722657
Current step:5 Validation loss:1.7912375587463378
Current step:6 Validation loss:1.5452837438583373
Current step:7 Validation loss:1.3032891147613526
Current step:8 Validation loss:1.1027766933441163
Current step:9 Validation loss:0.949706922531128
Current step:10 Validation loss:0.8340907591819763
Current step:11 Validation loss:0.7464724873542785
Current step:12 Validation loss:0.6767623687744141
Current step:13 Validation loss:0.622122283744812
Current step:14 Validation loss:0.5775999296188354
Current step:15 Validation loss:0.5417200242042541
Current step:16 Validation loss:0.5122299160003662
Current step:17 Validation loss:0.4875089702606201
Current step:18 Validation loss:0.46718254098892215
Current step:19 Validation loss:0.4494625943660736
Current step:20 Validation loss:0.4347919206619263
Current step:21 Validation loss:0.4215654832363129
Current step:22 Validation loss:0.41056136293411255
Current step:23 Validation loss:0.4001917915582657
Current step:24 Validation loss:0.39120743613243103
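The log reports only cross-entropy loss, which is harder to interpret than the fraction of digits classified correctly. Accuracy is not measured in this post. The sketch below shows a hypothetical `accuracy` helper, tested on made-up logits for 4 samples and 3 classes.

```python
import torch


def accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean().item()


# Hypothetical logits for 4 samples over 3 classes
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.1, 0.2, 3.0],
                       [1.0, 0.9, 0.8]])
labels = torch.tensor([0, 1, 2, 1])
print(accuracy(logits, labels))
```

With the post's variables, `with torch.no_grad(): print(accuracy(model(x_valid), y_valid))` would report validation accuracy after training.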
Visualizing the predictions
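import matplotlib.pyplot as plt

model.eval()
with torch.no_grad():  # inference only, no gradients needed
    predicted = model(x_train[:30]).numpy()  # only the 30 images shown below
res = np.argmax(predicted, axis=1)  # predicted digit = index of the highest score

plt.figure(figsize=(12, 5))  # a single figure holding all 30 subplots
for i in range(30):
    plt.subplot(5, 6, i + 1)
    plt.imshow(x_train[i].reshape((28, 28)), cmap="gray")
    # .item() turns the 0-d label tensor into a plain int for the title
    plt.title("True value: {} Predicted value: {}".format(y_train[i].item(), res[i]))
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()  # adjust spacing once, after all subplots exist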
Result: