Press "Enter" to skip to content

PyTorch入门:使用PyTorch搭建神经网络LeNet5

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

在本文中,我们基于PyTorch构建一个简单的神经网络LeNet5。

 

在阅读本文之前,建议您了解一些卷积神经网络的前置知识,比如卷积、Max Pooling和全连接层等等,可以看我写的相关文章: 李宏毅机器学习课程笔记-7.1CNN入门详解

 

通过阅读本文,您可以学习到如何使用PyTorch构建神经网络LeNet5。

 

模型说明

 

在本例中,我们使用如下图所示的神经网络模型:LeNet5。

 

 

该模型有1个输入层、2个卷积层、2次Max Pooling、2个全连接层和1个输出层。

 

输入层INPUT

1个channel,图片size是32×32。

 

卷积层C1

6个channel,特征图的size是28×28,即每个卷积核的size为(5,5),stride为1。

 

下采样操作S2

6个channel,特征图的size是14×14,即Max Pooling窗口size为(2,2)。

 

卷积层C3

16个channel,特征图的size是10×10,即每个卷积核的size为(5,5),stride为1。

 

下采样操作S4

16个channel,特征图的size是5×5,即Max Pooling窗口size为(2,2)。

 

全连接层F5

120个神经元。

 

全连接层F6

84个神经元。

 

输出层OUTPUT

10个神经元。

 

另外,除了输入层和输出层,剩下的卷积层、最大池化操作和全连接层后面都要加上Relu激活函数,下采样操作S4之后需要进行Flatten以和全连接层F5衔接起来。

 

代码实现

 

import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # 
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv3 = nn.Conv2d(6, 16, 5)
        # 全连接层
        self.fc5 = nn.Linear(in_features=16*5*5, out_features=120)
        self.fc6 = nn.Linear(120, 84)
        self.OUTPUT = nn.Linear(84, 10)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, (2, 2)) # Max pooling over a (2, 2) window
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2) # If the size is a square you can only specify a single number
        x = x.view(-1, 16*5*5) # Flatten
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = self.OUTPUT(x)
        return x
net = LeNet5()
output = net(torch.rand(1, 1, 32, 32))
# print(output)

 

参考链接

 

https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html

 

其实本文内容主要是PyTorch的官方教程。

 

PyTorch官方教程中代码实现与图片所示的LeNet5不符(PyTorch官方教程代码中是3×3的卷积核,而图片中LeNet5是5×5的卷积核),本文中我是按照图片所示模型结构实现的。

 

其实PyTorch开发者和其他开发者也注意到了这一问题,详见:

 

https://github.com/pytorch/tutorials/pull/515

 

https://github.com/pytorch/tutorials/commit/630802450c13c78f02f744af1c47d1033b6fe206

 

https://github.com/pytorch/tutorials/pull/1257

 

Github(github.com): @chouxianyu

 

Github Pages(github.io):@臭咸鱼

 

知乎(zhihu.com): @臭咸鱼

 

博客园(cnblogs.com): @臭咸鱼

 

B站(bilibili.com): @绝版臭咸鱼

 

微信公众号: @臭咸鱼

 

转载请注明出处,欢迎讨论和交流!

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注