## 1.模型介绍

M(x) 的掩码是通过最后的 sigmoid 函数得到的，M(x) 的值在 [0, 1] 之间。连续多个 Module 模块的输出直接相乘的话，会导致 feature map 的值越来越小，同时也有可能破坏原有网络的特性，使得网络的性能降低。因此采用 (1 + M(x)) * T(x) 的残差式注意力组合方式。

## 2.代码实现

```def attention_block(input, input_channels=None, output_channels=None, encoder_depth=1):
p = 1
t = 2
r = 1
if input_channels is None:
input_channels = input.get_shape()[-1].value
if output_channels is None:
output_channels = input_channels
# First Residual Block
for i in range(p):
input = residual_block(input)
# Trunc Branch
output_trunk = input
for i in range(t):
output_trunk = residual_block(output_trunk)
## encoder
### first down sampling
for i in range(r):
skip_connections = []
for i in range(encoder_depth - 1):
## skip connections
skip_connections.append(output_skip_connection)
# print ('skip shape:', output_skip_connection.get_shape())
## down sampling
for _ in range(r):
## decoder
skip_connections = list(reversed(skip_connections))
for i in range(encoder_depth - 1):
## upsampling
for _ in range(r):
## skip connections
### last upsampling
for i in range(r):
## Output
# Attention: (1 + output_soft_mask) * output_trunk
output = Lambda(lambda x: x + 1)(output_soft_mask)
output = Multiply()([output, output_trunk])  #
# Last Residual Block
for i in range(p):
output = residual_block(output)
return output```

```def AttentionResNet10(shape=(32, 32, 3), n_channels=32, n_classes=10):
input_ = Input(shape=shape)
x = Conv2D(n_channels, (5, 5), padding='same')(input_)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPool2D(pool_size=(2, 2))(x)  # 16x16
x = residual_block(x, input_channels=32, output_channels=128)
x = attention_block(x, encoder_depth=2)
x = residual_block(x, input_channels=128, output_channels=256, stride=2)  # 8x8
x = attention_block(x, encoder_depth=1)
x = residual_block(x, input_channels=256, output_channels=512, stride=2)  # 4x4
x = attention_block(x, encoder_depth=1)
x = residual_block(x, input_channels=512, output_channels=1024)
x = residual_block(x, input_channels=1024, output_channels=1024)
x = residual_block(x, input_channels=1024, output_channels=1024)
x = AveragePooling2D(pool_size=(4, 4), strides=(1, 1))(x)  # 1x1
x = Flatten()(x)
output = Dense(n_classes, activation='softmax')(x)
model = Model(input_, output)
return model```

```import keras
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from .models import AttentionResNet
# 加载数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
# define generators for training and validation data
train_datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
val_datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True)
# 计算特征归一化所需的函数
# (std, mean, and principal components if ZCA whitening is applied)
train_datagen.fit(x_train)
val_datagen.fit(x_train)
# build a model
model = AttentionResNet(n_classes=10)
# define loss, metrics, optimizer
# fits the model on batches with real-time data augmentation
batch_size = 32
model.fit_generator(train_datagen.flow(x_train, y_train, batch_size=batch_size),
steps_per_epoch=len(x_train)//batch_size, epochs=200,
validation_data=val_datagen.flow(x_test, y_test, batch_size=batch_size),
validation_steps=len(x_test)//batch_size,
callbacks=callbacks, initial_epoch=0)```