Press "Enter" to skip to content

Improving Model Performance with Data Augmentation



 

1. Import the libraries

 

import os
import math
import numpy as np
import pickle as p
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
%matplotlib inline

 

2. Define the data-loading function

 

def load_CIFAR_data(data_dir):
    """Load the full CIFAR-10 dataset."""
    images_train = []
    labels_train = []
    for i in range(5):
        f = os.path.join(data_dir, 'data_batch_%d' % (i + 1))
        print('loading ', f)
        # Call load_CIFAR_batch() to get one batch of images and their labels
        image_batch, label_batch = load_CIFAR_batch(f)
        images_train.append(image_batch)
        labels_train.append(label_batch)
        del image_batch, label_batch
    # Concatenate once, after all five batches have been read
    Xtrain = np.concatenate(images_train)
    Ytrain = np.concatenate(labels_train)

    Xtest, Ytest = load_CIFAR_batch(os.path.join(data_dir, 'test_batch'))
    print('finished loading CIFAR-10 data')

    # Return the training images/labels and the test images/labels
    return (Xtrain, Ytrain), (Xtest, Ytest)

 

3. Define the batch-loading function

 

def load_CIFAR_batch(filename):
    """Load a single batch of CIFAR-10."""
    with open(filename, 'rb') as f:
        # Each sample consists of a label and the raw image data
        # (3072 = 32 x 32 x 3 bytes per image)
        data_dict = p.load(f, encoding='bytes')
        images = data_dict[b'data']
        labels = data_dict[b'labels']

        # Reshape the flat data to BCHW: (batch, channels, height, width)
        images = images.reshape(10000, 3, 32, 32)
        # TensorFlow expects image data in BHWC order,
        # so move the channel axis C to the last dimension
        images = images.transpose(0, 2, 3, 1)

        labels = np.array(labels)

        return images, labels
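The labels returned here are integers from 0 to 9. If human-readable class names are needed, they live in the batches.meta file of the standard CIFAR-10 python archive; a small optional helper (my addition, not in the original post) could read them like this:

def load_CIFAR_label_names(data_dir):
    """Optional helper: read the class names from batches.meta."""
    with open(os.path.join(data_dir, 'batches.meta'), 'rb') as f:
        meta = p.load(f, encoding='bytes')
    # b'label_names' is a list of 10 byte strings such as b'airplane'
    return [name.decode('utf-8') for name in meta[b'label_names']]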

 

4. Load the data

 

data_dir = r'C:\Users\wumg\jupyter-ipynb\data\cifar-10-batches-py'
(x_train,y_train),(x_test,y_test) = load_CIFAR_data(data_dir)
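A quick sanity check on the loaded arrays (a small sketch I added) should show 50,000 training images and 10,000 test images, each of shape 32x32x3:

print('x_train:', x_train.shape, 'y_train:', y_train.shape)  # (50000, 32, 32, 3) (50000,)
print('x_test: ', x_test.shape, ' y_test: ', y_test.shape)   # (10000, 32, 32, 3) (10000,)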

 

Convert the arrays to tf.data.Dataset format:

 

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
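Each element of these datasets is an (image, label) pair; as a quick optional check (my addition), element_spec confirms the structure:

print(train_dataset.element_spec)
# e.g. (TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None),
#       TensorSpec(shape=(), dtype=tf.int64, name=None))
# The exact label dtype can vary by platform.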

 

5. Define the data augmentation methods

 

def convert(image, label):
    # Cast to float32 and scale pixel values to [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)
    return image, label
 
def augment(image, label):
    image, label = convert(image, label)
    image = tf.image.resize_with_crop_or_pad(image, 34, 34)   # pad 1 pixel on each side
    image = tf.image.random_crop(image, size=[32, 32, 3])     # randomly crop back to 32x32
    image = tf.image.random_brightness(image, max_delta=0.5)  # randomly adjust brightness
    return image, label
 
batch_size = 64
 
augmented_train_batches = (train_dataset
                          .cache()
                          .shuffle(5000)
                          .map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                          .batch(batch_size)
                          .prefetch(tf.data.experimental.AUTOTUNE))
 
non_augmented_train_batches = (train_dataset
                              .cache()
                              .shuffle(5000)
                              .map(convert, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                              .batch(batch_size)
                              .prefetch(tf.data.experimental.AUTOTUNE))
 
validation_batches = (test_dataset
                     .map(convert, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                     .batch(2 * batch_size))
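Before training, it is worth eyeballing what augment actually produces. The short sketch below (my addition; it runs eagerly on a single sample) draws five random augmentations of the first training image. Since random_brightness can push values outside [0, 1], they are clipped for display:

# Visualize a few random augmentations of one training image (optional sketch)
sample_image, sample_label = x_train[0], y_train[0]
plt.figure(figsize=(10, 2))
for i in range(5):
    aug_image, _ = augment(sample_image, sample_label)
    plt.subplot(1, 5, i + 1)
    # Clip to [0, 1] because random_brightness may exceed that range
    plt.imshow(tf.clip_by_value(aug_image, 0.0, 1.0).numpy())
    plt.axis('off')
plt.show()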

 

6. Build the model

 

class MyCNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(
            filters=32,             # number of convolution kernels
            kernel_size=[3, 3],     # receptive field size
            padding='same',         # padding strategy ('valid' or 'same')
            activation=tf.nn.relu   # activation function
        )
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.conv2 = tf.keras.layers.Conv2D(
            filters=64,
            kernel_size=[3, 3],
            padding='same',
            activation=tf.nn.relu
        )
        self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.flatten = tf.keras.layers.Reshape(target_shape=(8 * 8 * 64,))
        self.dense1 = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)
 
    def call(self, inputs):                     # inputs: [batch_size, 32, 32, 3]
        x = self.conv1(inputs)                  # [batch_size, 32, 32, 32]
        x = self.pool1(x)                       # [batch_size, 16, 16, 32]
        x = self.conv2(x)                       # [batch_size, 16, 16, 64]
        x = self.pool2(x)                       # [batch_size, 8, 8, 64]
        x = self.flatten(x)                     # [batch_size, 8 * 8 * 64]
        x = self.dense1(x)                      # [batch_size, 256]
        x = self.dense2(x)                      # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output
 
    def model01(self):
        # Wrap the subclassed model in a functional Model so summary() can show layer shapes
        x = tf.keras.Input(shape=(32, 32, 3))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

 

Instantiate the model:

 

model_no_augment = MyCNN()

 

View the detailed model structure:

 

model_no_augment.model01().summary()

 

7. Compile the model

 

model_no_augment.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

 

8. Train the model (without data augmentation)

 

For comparison, we first train the model without using data augmentation.

 

epochs = 10
history_non_augment = model_no_augment.fit(non_augmented_train_batches, epochs=epochs, validation_data=validation_batches)

 

9. View the results

 

acc = history_non_augment.history['accuracy']
val_acc = history_non_augment.history['val_accuracy']
 
loss = history_non_augment.history['loss']
val_loss = history_non_augment.history['val_loss']
 
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1.1])
plt.title('Training and Validation Accuracy')
 
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([-0.1,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

 

Results (figure: training and validation accuracy/loss curves for the non-augmented model)

10. Train with data augmentation

 

model_augment = MyCNN()
 
model_augment.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
 
history_with_augment = model_augment.fit(augmented_train_batches, epochs=epochs, validation_data=validation_batches)

 

11. View the results with data augmentation

 

acc = history_with_augment.history['accuracy']
val_acc = history_with_augment.history['val_accuracy']
 
loss = history_with_augment.history['loss']
val_loss = history_with_augment.history['val_loss']
 
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1.1])
plt.title('Training and Validation Accuracy')
 
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([-0.1,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

 

Results (figure: training and validation accuracy/loss curves for the augmented model)

12. Analysis of the results

 

Comparing the runs with and without data augmentation shows that augmentation improves model performance: validation accuracy rises from about 71% without augmentation to about 74% with it. Generalization also improves, as the training and validation accuracy curves stay much closer together when augmentation is used.
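To make the comparison easier to see, the two validation-accuracy curves can be plotted together. This is a small sketch I added; it assumes history_non_augment and history_with_augment from the earlier cells are still in scope:

plt.figure(figsize=(8, 5))
plt.plot(history_non_augment.history['val_accuracy'], label='Without augmentation')
plt.plot(history_with_augment.history['val_accuracy'], label='With augmentation')
plt.xlabel('epoch')
plt.ylabel('Validation Accuracy')
plt.title('Validation Accuracy: With vs. Without Augmentation')
plt.legend(loc='lower right')
plt.show()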
