Press "Enter" to skip to content

深度学习中的Data Augmentation- Test Time Augmentation (TTA)

本文主要学习深度学习中常用的数据增强(Data Augmentation)技术以及在Keras中如何实现它。

 

Deep Learning中存在很多通过改变神经网络(Neural Network)的训练(Training)方式改善神经网络结果的方法,Data Augmentation是其中一种常用的方法。Data Augmentation是一种增加Training Dataset的数据量大小的方法,也是一种正则化(Regularization)技术,还可以让模型对轻微变化的输入数据更加鲁棒(Robust)。

 

Data Augmentation

 

Data Augmentation是对输入数据随机应用旋转(Rotation)、缩放(Zoom)、平移(Shift)、反转(Flip)等操作的处理过程。

 

通过Data Augmentation,模型(Model)可以从数据集(Dataset)中学到待识别对象更多通用的Features。

Example of Data Augmentation on the CIFAR10 dataset

Test Time Augmentation (TTA) 是一种测试阶段提升神经网络模型效果的Data Augmentation技术。

 

Test Time Augmentation

 

与Training Set上的Data Augmentation相同,Test Time Augmentation也是对测试数据集图片进行随机的修改,然后把这些图片喂给神经网络,将同一张图片对应生成的多个图片的预测结果的平均值作为最终的网络预测结果。

 

从一个例子更清楚的理解Test Time Augmentation。我们在CIFAR10上训练神经网络,并用如下的测试图片(Test Image)进行测试。

The test image (a boat)

下面是模型对测试图片(Test Image)的预测结果,数字代表每个Class的Confidence,Confidence最大的Class就是最终预测的分类结果。

该Test Image对应的Ground Truth Label如下:

可以看到,模型输出了错误的预测结果。模型认为测试图片(Test Image)是Cars,而不是Boat。

 

然后我们应用Test Time Augmentation,对测试图片进行5次轻微的修改,然后分别输送给神经网络进行分类预测。

Modified version of the test image

对应的5次的分类预测结果如下:

Prediction 1

Prediction 2

Prediction 3

Prediction 4

Prediction 5

对上述的5个预测结果求均值:

Average of the 5 predictions

我们看到预测均值给出了正确的答案,Cars分类对应的Confidence最大。为什幺这种的方式可以Work呢?这是因为通过随机修改图片、多次预测求均值,某种程度上消除了预测误差。单次预测的误差可能比较大,多次预测就可能把误差抵消掉。

 

Keras中实现Test Time Augmentation

 

Keras可以很容易的实现Test Time Augmentation。我们首先实现一个简单的卷积神经网络(Convolutional Neural Network),然后在CIFAR10上进行训练。

 

model = Sequential()
model.add(Conv2D(64,(3,3), activation='relu', input_shape=(32,32,3)))
model.add(Conv2D(128,(3,3), activation='relu'))
model.add(Conv2D(128,(3,3), activation='relu'))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

 

然后使用ImageDataGenerator在训练图片上进行Data Augmentation。

 

train_datagen = ImageDataGenerator(
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        rotation_range=10.,
        fill_mode='reflect',
        width_shift_range = 0.1, 
        height_shift_range = 0.1)
train_datagen.fit(x_train)

 

训练神经网络。

 

history = model.fit_generator(train_datagen.flow(x_train, y_train,
                              batch_size=bs),
                              epochs=15,
                              steps_per_epoch=len(x_train)/bs,
                              validation_data=(x_val, y_val))

 

训练模型在验证集图片上的最终Accuracy为: 0.7415。

 

下一步我们在验证集(Validation Images)图片上应用Test Time Augmentation,即对验证集使用与训练集相同的DataGenerator方法。

 

tta_steps = 10
predictions = []
for i in tqdm(range(tta_steps)):
    preds = model.predict_generator(train_datagen.flow(x_val, batch_size=bs, shuffle=False), steps = len(x_val)/bs)
    predictions.append(preds)
pred = np.mean(predictions, axis=0)
np.mean(np.equal(np.argmax(y_val, axis=-1), np.argmax(pred, axis=-1)))

 

应用Test Time Augmentation之后,模型的Accuray达到了0.7743。可以看到,不对模型做任何改变,仅仅应用Test Time Augmentation之后,模型的Accuray提升了3个点。

 

Data Augmentation并不总是有效

 

在很多场景下Data Augmentation是一种获得更好网络效果的有效技术。但是也要谨慎使用,因为在一些场景下,它会伤害模型的Accuracy。

Bad data augmentation on MNIST

比如在MNIST数据集中,显然你不能对图像进行随机反转或者旋转,因为6可能变成9,这会让你的网络模型非常的困惑。

 

在CIFAR10数据集中,你可以对图像进行水平的反转,因为这并不影响图像本身,一匹马不管从左侧看还是从右侧看,它都仍然是一匹马。但是垂直的反转就没有意义了,因为你大概率不想让你的神经网络识别一个上下反转的Ship。

 

还有一些情况下,比如卫星图片,无论怎幺上下左右反转都不改变图像本身的含义,所以你可以对这类图片做任意的旋转和上下左右反转。

 

总而言之,只要使用得当,Data Augmentation不仅可以在训练(Trainning)阶段提升神经网络效果,而且可以在测试(Testing)阶段提升网络的预测准确率。

 

参考材料

 

https://towardsdatascience.com/test-time-augmentation-tta-and-how-to-perform-it-with-keras-4ac19b67fb4d

 

除非注明,否则均为[半杯茶的小酒杯]原创文章,转载必须以链接形式标明本文链接

 

本文链接: http://www.banbeichadexiaojiubei.com/index.php/2020/12/19/%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0%e4%b8%ad%e7%9a%84data-augmentation-test-time-augmentation-tta/

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注