Press "Enter" to skip to content

可堆叠的残差注意力模块用于图像分类(Residual Attention Network for Image Classification——代…

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

1.模型介绍

 

该模型设计的思想就是利用attention机制,在普通ResNet网络中,增加侧分支,侧分支通过一系列的卷积和池化操作,逐渐提取高层特征并增大模型的感受野,前面已经说过高层特征的激活对应位置能够反映attention的区域,然后再对这种具有attention特征的feature map进行上采样,使其大小回到原始feature map的大小,就将attention对应到原始图片的每一个位置上,这个feature map叫做 attention map,与原来的feature map 进行element-wise product的操作,相当于一个权重器,增强有意义的特征,抑制无意义的信息。

 

论文中模型的结构如下图所示。

 

 

最上面红色箭头标记的流程,就是一个普通的残差网络,(这个分支其实为主干分支可以加传统的 resnet,ResNetXt,Inception 网络等)。然后在残差块的部分位置,加入另外的分支(即为灰色部分),构成一个整体的Attention Module,下面对Attention Module做具体分析。

 

 

一个Attention Module分为两个分支,右边的分支就是普通的卷积网络,即主干分支,叫做Trunk Branch。左边的分支是为了得到一个掩码mask,该掩码的作用是得到输入特征x的attention map,所以叫做Mask Branch,这个Mask Branch包含down sample和up sample的过程,目的是为了保证和右边分支的输出大小一致。

 

得到Attention map的mask以后,一个比较naive的方法就是直接用mask和主干分支进行一个element-wise product的操作,即M(x) * T(x),来对特征做一次权重操作。但是这样导致的问题就是:

 

M(x)的掩码是通过最后的sigmoid函数得到的,M(x)值在[0, 1]之间,连续多个Module模块直接相乘的话会导致feature map的值越来越小,同时也有可能打破原有网络的特性,使得网络的性能降低

 

于是就有了如下的改进: Attention Residual Learning

 

前面已经说了直接进行element-wise product操作会使得性能降低,那幺比较好的方式就借鉴ResNet恒等映射的方法:

 

 

其中M(x)为Soft Mask Branch的输出,F(x)为Trunk Branch的输出,那幺当M(x)=0时,该层的输入就等于F(x),因此该层的效果不可能比原始的F(x)差,这一点也借鉴了ResNet中恒等映射的思想,同时这样的加法,也使得Trunk Branch输出的feature map中显着的特征更加显着,增加了特征的判别性。此外, attention residual learning 既能很好地保留原始特征的特性,又能使原始特征具有绕过soft Mask Branch分支的能力,从而直接前馈(forward)到最顶层来削弱 mask 分支的特征筛选能力。经过这种残差结构的堆叠,能够很容易的将模型的深度达到很深的层次,具有非常好的性能。

 

2.代码实现

 

注意力模块:

 

def attention_block(input, input_channels=None, output_channels=None, encoder_depth=1):
    p = 1
    t = 2
    r = 1
    if input_channels is None:
        input_channels = input.get_shape()[-1].value
    if output_channels is None:
        output_channels = input_channels
    # First Residual Block
    for i in range(p):
        input = residual_block(input)
    # Trunc Branch
    output_trunk = input
    for i in range(t):
        output_trunk = residual_block(output_trunk)
    # Soft Mask Branch
    ## encoder
    ### first down sampling
    output_soft_mask = MaxPool2D(padding='same')(input)  # 32x32
    for i in range(r):
        output_soft_mask = residual_block(output_soft_mask)
    skip_connections = []
    for i in range(encoder_depth - 1):
        ## skip connections
        output_skip_connection = residual_block(output_soft_mask)
        skip_connections.append(output_skip_connection)
        # print ('skip shape:', output_skip_connection.get_shape())
        ## down sampling
        output_soft_mask = MaxPool2D(padding='same')(output_soft_mask)
        for _ in range(r):
            output_soft_mask = residual_block(output_soft_mask)
            ## decoder
    skip_connections = list(reversed(skip_connections))
    for i in range(encoder_depth - 1):
        ## upsampling
        for _ in range(r):
            output_soft_mask = residual_block(output_soft_mask)
        output_soft_mask = UpSampling2D()(output_soft_mask)
        ## skip connections
        output_soft_mask = Add()([output_soft_mask, skip_connections[i]])
    ### last upsampling
    for i in range(r):
        output_soft_mask = residual_block(output_soft_mask)
    output_soft_mask = UpSampling2D()(output_soft_mask)
    ## Output
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Activation('sigmoid')(output_soft_mask)
    # Attention: (1 + output_soft_mask) * output_trunk
    output = Lambda(lambda x: x + 1)(output_soft_mask)
    output = Multiply()([output, output_trunk])  #
    # Last Residual Block
    for i in range(p):
        output = residual_block(output)
    return output

 

整个浅层的模型结构:

 

def AttentionResNet10(shape=(32, 32, 3), n_channels=32, n_classes=10):
    input_ = Input(shape=shape)
    x = Conv2D(n_channels, (5, 5), padding='same')(input_)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)  # 16x16
    x = residual_block(x, input_channels=32, output_channels=128)
    x = attention_block(x, encoder_depth=2)
    x = residual_block(x, input_channels=128, output_channels=256, stride=2)  # 8x8
    x = attention_block(x, encoder_depth=1)
    x = residual_block(x, input_channels=256, output_channels=512, stride=2)  # 4x4
    x = attention_block(x, encoder_depth=1)
    x = residual_block(x, input_channels=512, output_channels=1024)
    x = residual_block(x, input_channels=1024, output_channels=1024)
    x = residual_block(x, input_channels=1024, output_channels=1024)
    x = AveragePooling2D(pool_size=(4, 4), strides=(1, 1))(x)  # 1x1
    x = Flatten()(x)
    output = Dense(n_classes, activation='softmax')(x)
    model = Model(input_, output)
    return model

 

模型调用函数:在这里调用封装好的CIFAR10图形识别数据,CIFAR10数据集共有60000张彩色图像,这些图像式32*32*3,分为10个类,每个类6000张。

 

import keras
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from .models import AttentionResNet
# 加载数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
# define generators for training and validation data
train_datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
val_datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True)
# 计算特征归一化所需的函数
# (std, mean, and principal components if ZCA whitening is applied)
train_datagen.fit(x_train)
val_datagen.fit(x_train)
# build a model
model = AttentionResNet(n_classes=10)
# define loss, metrics, optimizer
model.compile(keras.optimizers.Adam(lr=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
# fits the model on batches with real-time data augmentation
batch_size = 32
model.fit_generator(train_datagen.flow(x_train, y_train, batch_size=batch_size),
                    steps_per_epoch=len(x_train)//batch_size, epochs=200,
                    validation_data=val_datagen.flow(x_test, y_test, batch_size=batch_size),
                    validation_steps=len(x_test)//batch_size,
                    callbacks=callbacks, initial_epoch=0)

 

全部代码链接:

 

https://download.csdn.net/download/weixin_40651515/86309657

 

参考资料链接:

 

1. https://zhuanlan.zhihu.com/p/36838135

 

2. https://arxiv.org/pdf/1704.06904.pdf

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。