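1. Model Introduction
The idea behind this model is to use an attention mechanism: side branches are added to an ordinary ResNet. Through a series of convolution and pooling operations, each side branch gradually extracts high-level features and enlarges the model's receptive field. As mentioned earlier, the positions where high-level features activate reflect the regions that attention should focus on. The feature map carrying this attention information is then upsampled back to the size of the original feature map, so that attention is mapped onto every position of the input. This feature map is called the attention map. It is multiplied element-wise with the original feature map and acts as a weighting: it strengthens meaningful features and suppresses meaningless information.
The structure of the model in the paper is shown in the figure below.
The flow marked by the red arrows at the top is an ordinary residual network (this is the trunk; a conventional ResNet, ResNeXt, Inception, or similar network can be used here). At some positions among the residual blocks, an additional branch (the gray part) is added, and together they form an Attention Module. The Attention Module is analyzed in detail below.
An Attention Module has two branches. The right branch is an ordinary convolutional network, i.e. the trunk, and is called the Trunk Branch. The left branch produces a mask whose role is to give the attention map of the input feature x, so it is called the Mask Branch. The Mask Branch contains a down-sampling stage followed by an up-sampling stage; the up-sampling restores the spatial size so that the mask matches the output of the right branch.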
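To make the size bookkeeping of the Mask Branch concrete, here is a minimal pure-Python sketch. It assumes 2x2 max pooling with stride 2 and nearest-neighbour 2x up-sampling, and the helper names `max_pool_2x2` and `upsample_2x` are made up for this illustration. It only shows that one down-sampling step followed by one up-sampling step returns a map of the original size.

```python
def max_pool_2x2(grid):
    """2x2 max pooling with stride 2 on a 2D list with even height and width."""
    rows, cols = len(grid), len(grid[0])
    return [
        [max(grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1])
         for c in range(0, cols, 2)]
        for r in range(0, rows, 2)
    ]


def upsample_2x(grid):
    """Nearest-neighbour up-sampling: repeat every value twice along both axes."""
    out = []
    for row in grid:
        wide = [value for value in row for _ in range(2)]
        out.append(wide)
        out.append(list(wide))
    return out


feature_map = [
    [1, 2, 0, 1],
    [3, 4, 1, 0],
    [0, 1, 5, 6],
    [1, 0, 7, 8],
]

pooled = max_pool_2x2(feature_map)  # 2x2 map, each cell summarizes a 2x2 region
restored = upsample_2x(pooled)      # back to 4x4, the same size as the trunk output

print(pooled)
print(len(restored), len(restored[0]))
```

Each cell of the pooled map covers a larger area of the input, which is how the mask branch grows its receptive field before the size is restored.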
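Once the mask for the attention map is obtained, a naive approach is to multiply it element-wise with the trunk output, i.e. M(x) * T(x), to weight the features. The problem with this is:
The mask M(x) comes from a final sigmoid, so its values lie in [0, 1]. If several modules are stacked and their masks are multiplied in directly, the feature map values become smaller and smaller. It may also break the properties of the original network and degrade its performance.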
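The shrinking effect can be seen with a tiny numeric sketch. It assumes, purely for illustration, a single feature value of 1.0 and a mask that always outputs sigmoid(0) = 0.5; real masks vary per position, but every factor is at most 1, so the trend is the same.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


feature = 1.0
mask = sigmoid(0.0)  # a mask value of exactly 0.5

history = []
for _ in range(5):  # five stacked modules using the naive product M(x) * T(x)
    feature = mask * feature
    history.append(feature)

print(history)
```

After five modules the value has been halved five times, which is the attenuation described above.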
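Hence the following improvement: Attention Residual Learning
As noted above, a direct element-wise product can reduce performance, so a better approach borrows the identity mapping idea from ResNet:
H(x) = (1 + M(x)) * F(x)
where M(x) is the output of the Soft Mask Branch and F(x) is the output of the Trunk Branch (written T(x) above). When M(x) = 0, the output of this layer equals F(x), so in the worst case the layer performs no worse than the original F(x); this is the same reasoning as identity mapping in ResNet. The added term also makes the salient features in the Trunk Branch's output feature map even more salient, which increases the discriminability of the features. In addition, attention residual learning preserves the properties of the original features while letting them bypass the Soft Mask Branch and feed forward directly to the top layers, which weakens the feature selection of the mask branch. By stacking this residual structure, the model can easily be made very deep while keeping very good performance.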
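The difference between the naive product and attention residual learning can be sketched in a few lines of plain Python. The function names `naive_attention` and `residual_attention` and the sample values are assumptions made for this illustration; the formulas are M(x) * F(x) and (1 + M(x)) * F(x) as described above.

```python
def naive_attention(mask, trunk):
    """Naive weighting: M(x) * F(x), element by element."""
    return [m * f for m, f in zip(mask, trunk)]


def residual_attention(mask, trunk):
    """Attention residual learning: (1 + M(x)) * F(x), element by element."""
    return [(1.0 + m) * f for m, f in zip(mask, trunk)]


trunk = [2.0, -1.0, 0.5]
zero_mask = [0.0, 0.0, 0.0]
some_mask = [1.0, 0.5, 0.0]

print(naive_attention(zero_mask, trunk))     # every feature is wiped out
print(residual_attention(zero_mask, trunk))  # the trunk features pass through unchanged
print(residual_attention(some_mask, trunk))  # features with a high mask value are amplified
```

With a zero mask the residual form returns the trunk output untouched, which is the identity-mapping behaviour the text relies on.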
2. Code Implementation
Attention module:
```python
from keras.layers import (Activation, Add, Conv2D, Lambda, MaxPool2D,
                          Multiply, UpSampling2D)

# residual_block is a helper from the full code linked at the end of this post;
# it is not shown here.


def attention_block(input, input_channels=None, output_channels=None, encoder_depth=1):
    p = 1  # residual blocks before and after the two branches
    t = 2  # residual blocks in the trunk branch
    r = 1  # residual blocks between adjacent pooling layers in the mask branch

    if input_channels is None:
        input_channels = int(input.shape[-1])
    if output_channels is None:
        output_channels = input_channels

    # First Residual Block
    for i in range(p):
        input = residual_block(input)

    # Trunk Branch
    output_trunk = input
    for i in range(t):
        output_trunk = residual_block(output_trunk)

    # Soft Mask Branch

    ## encoder
    ### first down sampling (halves the spatial size)
    output_soft_mask = MaxPool2D(padding='same')(input)
    for i in range(r):
        output_soft_mask = residual_block(output_soft_mask)

    skip_connections = []
    for i in range(encoder_depth - 1):
        ## skip connections
        output_skip_connection = residual_block(output_soft_mask)
        skip_connections.append(output_skip_connection)

        ## down sampling
        output_soft_mask = MaxPool2D(padding='same')(output_soft_mask)
        for _ in range(r):
            output_soft_mask = residual_block(output_soft_mask)

    ## decoder
    skip_connections = list(reversed(skip_connections))
    for i in range(encoder_depth - 1):
        ## upsampling
        for _ in range(r):
            output_soft_mask = residual_block(output_soft_mask)
        output_soft_mask = UpSampling2D()(output_soft_mask)
        ## skip connections
        output_soft_mask = Add()([output_soft_mask, skip_connections[i]])

    ### last upsampling (back to the size of the trunk output)
    for i in range(r):
        output_soft_mask = residual_block(output_soft_mask)
    output_soft_mask = UpSampling2D()(output_soft_mask)

    ## Output
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Activation('sigmoid')(output_soft_mask)

    # Attention: (1 + output_soft_mask) * output_trunk
    output = Lambda(lambda x: x + 1)(output_soft_mask)
    output = Multiply()([output, output_trunk])

    # Last Residual Block
    for i in range(p):
        output = residual_block(output)

    return output
```
The overall structure of the shallow model:
```python
from keras.layers import (Activation, AveragePooling2D, BatchNormalization,
                          Conv2D, Dense, Flatten, Input, MaxPool2D)
from keras.models import Model


def AttentionResNet10(shape=(32, 32, 3), n_channels=32, n_classes=10):
    input_ = Input(shape=shape)

    # Stem: convolution + batch norm + ReLU, then 2x2 max pooling
    x = Conv2D(n_channels, (5, 5), padding='same')(input_)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)  # 16x16

    # Stage 1
    x = residual_block(x, input_channels=32, output_channels=128)
    x = attention_block(x, encoder_depth=2)

    # Stage 2
    x = residual_block(x, input_channels=128, output_channels=256, stride=2)  # 8x8
    x = attention_block(x, encoder_depth=1)

    # Stage 3
    x = residual_block(x, input_channels=256, output_channels=512, stride=2)  # 4x4
    x = attention_block(x, encoder_depth=1)

    # Final residual blocks
    x = residual_block(x, input_channels=512, output_channels=1024)
    x = residual_block(x, input_channels=1024, output_channels=1024)
    x = residual_block(x, input_channels=1024, output_channels=1024)

    # Classification head
    x = AveragePooling2D(pool_size=(4, 4), strides=(1, 1))(x)  # 1x1
    x = Flatten()(x)
    output = Dense(n_classes, activation='softmax')(x)

    model = Model(input_, output)
    return model
```
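The size comments in the model (16x16, 8x8, 4x4, 1x1) can be checked with a little arithmetic. The sketch below assumes the usual output-size rules for 'valid' and 'same' padding, and it assumes that `residual_block` keeps the spatial size for stride 1 and halves it for stride 2, as those comments suggest. The helper names `valid_out` and `same_out` are made up for this illustration.

```python
def valid_out(size, window, stride):
    """Output length of a pooling window without padding ('valid')."""
    return (size - window) // stride + 1


def same_out(size, stride):
    """Output length with 'same' padding: ceil(size / stride)."""
    return -(-size // stride)


sizes = []
size = 32                        # CIFAR-10 input is 32x32
size = same_out(size, 1)         # 5x5 convolution, padding='same'
sizes.append(size)
size = valid_out(size, 2, 2)     # MaxPool2D(pool_size=(2, 2))
sizes.append(size)
size = same_out(size, 2)         # residual_block(..., stride=2), assumed to halve the size
sizes.append(size)
size = same_out(size, 2)         # residual_block(..., stride=2), assumed to halve the size
sizes.append(size)
size = valid_out(size, 4, 1)     # AveragePooling2D(pool_size=(4, 4), strides=(1, 1))
sizes.append(size)

print(sizes)
```

The trace ends at 1, so `Flatten` receives a 1x1x1024 tensor and the dense layer sees 1024 features per image.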
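Training script: here we load the prepackaged CIFAR-10 image recognition dataset. CIFAR-10 contains 60,000 color images of size 32*32*3, divided into 10 classes with 6,000 images per class.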
```python
import keras
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ReduceLROnPlateau, EarlyStopping

# Assumes attention_block and AttentionResNet10 from above are defined in the
# same file or imported from the full code.

# Load the dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# Define generators for training and validation data
train_datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

val_datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True)

# Compute the statistics needed for feature-wise normalization
# (std, mean, and principal components if ZCA whitening is applied).
# Both generators are fitted on the training data.
train_datagen.fit(x_train)
val_datagen.fit(x_train)

# Build the model
model = AttentionResNet10(n_classes=10)

# Define loss, metrics, optimizer
# (newer Keras versions name this argument `learning_rate` instead of `lr`)
model.compile(keras.optimizers.Adam(lr=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# NOTE: the original post never defines `callbacks`.
# The settings below are assumed values, not the author's.
lr_reducer = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=7, verbose=1)
early_stopper = EarlyStopping(monitor='val_loss', patience=15, verbose=1)
callbacks = [lr_reducer, early_stopper]

# Fit the model on batches with real-time data augmentation
# (newer Keras versions accept generators in `model.fit` directly)
batch_size = 32
model.fit_generator(
    train_datagen.flow(x_train, y_train, batch_size=batch_size),
    steps_per_epoch=len(x_train) // batch_size,
    epochs=200,
    validation_data=val_datagen.flow(x_test, y_test, batch_size=batch_size),
    validation_steps=len(x_test) // batch_size,
    callbacks=callbacks,
    initial_epoch=0)
```
Link to the full code:
https://download.csdn.net/download/weixin_40651515/86309657
Reference link:
1. https://zhuanlan.zhihu.com/p/36838135