Press "Enter" to skip to content

TensorFlow2提升模型性能的几种有效方法

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

文章目录

 

这里介绍几种TensorFlow2提升模型性能的几种有效方法,如数据增强、使用经典网络、使用迁移学习、使用新架构等。

 

一、数据增强

 

数据增强是提升模型性能的常用方法之一,尤其当训练数据集较小时,通过数据增强方法,效果更加明确。通过数据增强不但可以增加数据量、丰富数据的多样性,从而有效提升模型的泛化能力。

 

TensorFlow2提供了多种数据增强方法,常用的几种方法为:

 

 有监督的数据增强:

 

1、使用tf.image的预处理

 

该方法一般基于dataset,然后把数据增强方法嵌入到map中

 

2、使用tf..preprocessing.image.ImageDataGenerator

 

利用实时数据增强技术生成批量的张量图像数据

 

使用tf.keras.preprocessing.image.ImageDataGenerator,获取数据源的方法主要有:

 

(1).flow(x, y)

 

(2).flow_from_directory(directory)

 

3、使用tf.keras.layers.experimental.preprocessing

 

更多信息可参考官网:https://keras.io/guides/preprocessing_layers/

 

 无监督的数据增强

 

通过模型学习数据的分布,随机生成与训练数据集分布一致的图片,代表方法,GAN。

 

GAN模型包含两个网络,一个是生成网络(G),一个是判别网络(D),基本原理如下:

 

(1)G是一个生成图片的网络,它接收随机的噪声z,通过噪声生成图片,记做G(z)。

 

(2)D是一个判别网络,判别一张图片是不是“真实的”,即是真实的图片,还是由G生成的图片。其架构图如下:

 

 

下载所有实例使用数据(提取码为:fg29)

具体实例

二、使用现代经典模型

 

现代经典模型主要有: VGG、GoogLeNet、Inception、Xception、ResNet、MobileNet、DenseNet、NASNet等,各种网咯结构可参考:

 

https://blog.csdn.net/Forrest97/article/details/105630719

 

具体实例(待续)

 

三、利用迁移方法

 

Tensorflow.keras.application下载已经训练好的模型,包括如下预训练模型:

 

VGG、GoogLeNet、Inception、Xception、ResNet、MobileNet、DenseNet、NASNet等

 

具体实例(待续)

 

四、使用新架构如Transformer等

 

Transformer)模型在NLP领域取得了SOTA成绩,目前人们正把Transformer引用到CV领域,在图像识别、图像分类等领域取得不俗的表现。在目标检测(如DETR)、图像分类(如ViT)、图像分割(如SETR)都取得不错的效果。这里我们重点介绍ViT(Vision Transformer)。如何把Transformer引入CV领域?需要对CV中的图像做哪些处理?

 

NLP处理的语言数据是序列化的,而CV中处理的图像数据是三维的(height、width和channels)。所以需要通过某种方法将图像这种三维数据转化为序列化的数据。

 

Vision Transformer将CV和NLP领域知识结合起来,对原始图片进行分块,然后展平成序列,输入进原始Transformer模型的编码器Encoder部分,最后接入一个全连接层对图片进行分类。具体步骤为:

 

(1)原始图片(H,W,C)进行分块(patches)

 

进行一个类似卷积操作,把原始图片进行分块,分块数=H*W/P*P(P为块的大小)。

 

(2)展平每块(Flatten the patches)

 

把每块展平为一维向量,大小为:P*P*C(对应实例中的patch_dims)

 

(3)把展平后的块映射为更低维的向量

 

通过一个全连接层,把展平后块映射为一个更低维的向量(对应实例中的patch embedding),其大小为D(对应实例中的projection_dim),这个维度在各层保存不变,主要为便于使用残差连接。

 

(4)在patch embedding基础上添加一个类标签

 

类似BERT的[class] token,在patch embedding的序列之前添加一个可学习的embedding向量xclass

 

(5)加上位置嵌入(Add positional embeddings)

 

低维的向量+位置嵌入,

 

(6)把(5)的结果作为标准transformer encoder的输入

 

(7)在一个大数据集上训练模型

 

(8)在下游数据集中,进行微调。

 

ViT架构图如下:

 

 

具体实例(待续)

 

五、练习

 

1、利用ResNet经典模型提升性能

 

2、使用keras的数据增强方法提升性能

 

3、viT中Transformer的输入数据的形状(shape)与哪些因素有关?包括哪些数据?如何生成Transformer的输入数据?

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注