Press "Enter" to skip to content

【神经网络】(22) ConvMixer 代码复现,网络解析,附TensorFlow完整代码

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

大家好,今天和各位分享一下如何使用 TensorFlow 构建 ConvMixer 卷积神经网络模型.

 

我偶然间找到了这个网络,这是一个实现起来非常简单的模型,但是能够实现较好的精度表现,超过了 Vision Transformer 模型,有种大道至简的感觉。

 

论文地址: https://openreview.net/forum?id=TVHS5Y4dNvM

 

1. 引言

 

近年来 Transformer 模型在 CV 领域中不断挑战卷积神经网络的统治地位,出现了能和 CNN 扳手腕的 VisionTransformer 以及划时代的 SwinTransformer。这篇文章作者主要针对的是 VIT 模型, 他提出了一个问题: ViT的性能是由于其强大的Transformer结构产生的,还是由于使用patch作为输入表示产生的 。

 

在论文中, 作者证明了PatchEmbedding对VIT的精度影响更大 ,并提出了一个非常简单的模型 ConvMixer ,在思想上类似于ViT和MLP-Mixer。模型直接将patch作为输入, 分离空间和通道尺寸的混合建模 , 并在整个网络中保持相同大小的分辨率 。

 

尽管ConvMixer的设计很简单,但是实验证明了ConvMixer在相似的参数计数和数据集大小方面优于ViT、MLP-Mixer及其一些变体,以及经典的视觉模型,如ResNet。

 

2. 模型构建

 

我们先导入需要用到的工具包

 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

 

2.1 Patch Embedding

 

patchembedding 的主要功能是 对原始输入图像(h, w)划分图像块 。首先指定 每个图像块的size为(patch_size, patch_size) , 将每张图像划分出(h//patch_size, w//patch_size)个图像块 。

 

它的实现方法就是 通过一个 kernel_size 和 stride 都等于 patch_size 的卷积层来划分图像块 。

 

 

代码如下:

 

# ---------------------------------------------- #
#(1)patchembedding层
'''out_channel代表输出通道数, patch_size代表每个图像块的宽高'''
# ---------------------------------------------- #
def patchembed(inputs, out_channel, patch_size):
    
    # 卷积核大小为patch_size*patch_size,步长为patch_size的标准卷积划分图像块
    x = layers.Conv2D(filters = out_channel,   # 输出通道数
                      kernel_size = patch_size,  # 卷积核尺寸
                      strides = patch_size,  # 卷积步长
                      padding = 'same',  # 
                      use_bias = False)(inputs)
    # GELU激活函数、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    return x

 

2.2 特征提取层

 

这里的特征提取层由三部分组成, 深度卷积(depthwise conv)、逐点卷积(pointwise conv)、残差连接(shortcut) 。如下图ConvMixer Layer所示。

 

关于深度可分离卷积的原理,看我这篇博文: https://blog.csdn.net/dgvv4/article/details/123476899

 

 

首先输入特征图,经过 深度卷积提取特征图长宽方向的信息 ,其中卷积核的个数和输入特征图的通道数相同,且 输入和输出特征图的shape相同 ;接着残差连接输入和输出;然后经过 1*1逐点卷积融合通道方向的信息 ,其 中卷积核的个数和输出特征图的个数相同 。

 

代码如下:

 

# ---------------------------------------------- #
#(2)单个特征提取模块
'''out_channel代表逐点卷积的输出通道数, kernel_size代表深度卷积的卷积核大小'''
# ---------------------------------------------- #
def layer(inputs, out_channel, kernel_size):
    # 9*9深度卷积提取特征
    x = layers.DepthwiseConv2D(kernel_size = kernel_size,  # 卷积核大小
                               strides = 1,  # 不经过下采样
                               padding = 'same',  # 卷积前后size不变
                               use_bias = False)(inputs)
    # GELU激活、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    # 残差连接
    x = x + inputs
    
    # 1*1逐点卷积
    x = layers.Conv2D(filters = out_channel,  # 输出通道数
                      kernel_size = 1,  # 1*1卷积
                      strides = 1)(x)
    # GELU激活、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    return x
# ---------------------------------------------- #
#(3)堆叠多个特征提取模块
'''depth代表堆叠的次数'''
# ---------------------------------------------- #
def blocks(x, depth, out_channel, kernel_size):
    # 堆叠多个特征提取模块
    for _ in range(depth):
        x = layer(x, out_channel, kernel_size)
    
    return x

 

2.3 主干网络

 

ConvMixer的网络结构非常简单。首先图像经过 PatchEmbedding 划分图像块,然后经过12个特征提取模块,最后经过一个全连接层得到输出结果。

 

这里构建 ConvMixer-1536/20 网络模型 ,其中 1536 代表patchembedding 层的输出通道数 , 20 代表堆叠20个特征提取模块 , 每个图像块patch_size的大小为7*7 , 特征提取模块中深度卷积的卷积核尺寸为 9*9

 

 

代码如下:

 

# ---------------------------------------------- #
#(4)主干网络
'''input_shape代表输入图像的尺寸(不包含batch维度), num_classes代表分类数'''
# ---------------------------------------------- #
def convmixer(input_shape, num_classes):
    # 构造输入层[b,224,224,3]
    inputs = keras.Input(shape=input_shape)
    # patchembedding层[b,224//7,224//7,1536]
    x = patchembed(inputs, out_channel=1536, patch_size=7)
    # 经过20个特征提取层[b,224//7,224//7,1536]
    x = blocks(x, depth=20, out_channel=1536, kernel_size=9)
    # 全局平均池化[b,1536]
    x = layers.GlobalAveragePooling2D()(x)
    # 全连接分类[b,num_classes]
    outputs = layers.Dense(num_classes)(x)
    # 构造网络
    model = keras.Model(inputs, outputs)
    return model

 

2.4 查看网络架构

 

以1000分类为例查看网络结构

 

# ---------------------------------------------- #
#(5)查看网络结构
# ---------------------------------------------- #
if __name__ == '__main__':
    # 接受模型
    model = convmixer(input_shape=[224,224,3],num_classes=1000)
    # 查看网络结构
    model.summary()

 

网络结构如下:

 

conv2d_20 (Conv2D)             (None, 32, 32, 1536  2360832     ['tf.__operators__.add_19[0][0]']
                                )
 activation_40 (Activation)     (None, 32, 32, 1536  0           ['conv2d_20[0][0]']
                                )
 batch_normalization_40 (BatchN  (None, 32, 32, 1536  6144       ['activation_40[0][0]']
 ormalization)                  )
 global_average_pooling2d (Glob  (None, 1536)        0           ['batch_normalization_40[0][0]']
 alAveragePooling2D)
 dense (Dense)                  (None, 1000)         1537000     ['global_average_pooling2d[0][0]'
                                                                 ]
==================================================================================================
Total params: 51,719,656
Trainable params: 51,593,704
Non-trainable params: 125,952
__________________________________________________________________________________________________

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。