1. Task Introduction
The data is structured as follows:
data
├── cat   (folder containing 1000 images)
├── chook (folder containing 1000 images)
├── dog   (folder containing 1000 images)
└── horse (folder containing 1000 images)
The data needs to be split into a training set (train) and a validation set (val). We then train on the training set so that, given an image from the validation set (say, a cat or a dog), the model can recognize which class it belongs to.
2. Data Processing
2.1. Data Preprocessing
Set up the GPU environment for training:
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  # let GPU memory grow on demand
    tf.config.set_visible_devices([gpus[0]], "GPU")

# Print the GPU info to confirm that a GPU is available
print(gpus)
Output:
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Import the image data:
import matplotlib.pyplot as plt
# Support Chinese text in plots
plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # display minus signs correctly

import os, PIL

# Set random seeds so the results are as reproducible as possible
import numpy as np
np.random.seed(1)

import tensorflow as tf
tf.random.set_seed(1)

import pathlib

data_dir = "./data"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)
Output:
Total number of images: 4000
Next, initialize the parameters and use the image_dataset_from_directory method to load the data from disk into a tf.data.Dataset.
Function signature:
tf.keras.preprocessing.image_dataset_from_directory(
    directory,
    labels="inferred",
    label_mode="int",
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    image_size=(256, 256),
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear",
    follow_links=False,
)
Official documentation: tf.keras.utils.image_dataset_from_directory
Code:
batch_size = 4
img_height = 299
img_width = 299

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Output:
Found 4000 files belonging to 4 classes.
Using 3200 files for training.
Configure the validation set in the same way:
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Output:
Found 4000 files belonging to 4 classes.
Using 800 files for validation.
We can print the dataset's labels via class_names; the labels correspond to the directory names in alphabetical order:
class_names = train_ds.class_names
print(class_names)
Output:
['cat', 'chook', 'dog', 'horse']
Check the shape of one batch:
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
Output:
(4, 299, 299, 3)
(4,)
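That is, image_batch is a batch of 4 images of shape 299×299×3 (height, width, RGB channels), and labels_batch holds the 4 corresponding integer labels.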
2.2. Visualizing the Data
plt.figure(figsize=(10, 5))  # figure width 10, height 5
plt.suptitle("Data preview")

num = -1
for images, labels in train_ds.take(2):
    for i in range(4):
        num = num + 1
        ax = plt.subplot(2, 4, num + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

plt.savefig('pic1.jpg', dpi=600)  # save at the specified resolution
Output: (a 2×4 grid of sample images, each titled with its class name)
2.3. Configuring the Dataset
shuffle(): shuffles the data. For details see: Understanding buffer_size in the Dataset shuffle() method
prefetch(): prefetches data to speed up execution. For details see: Better performance with the tf.data API
cache(): caches the dataset in memory to speed up execution
AUTOTUNE = tf.data.AUTOTUNE

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    # .map(train_preprocessing)  # a preprocessing function could be applied here
    # .batch(batch_size)         # batch_size was already set in image_dataset_from_directory
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    # .map(val_preprocessing)    # a preprocessing function could be applied here
    # .batch(batch_size)         # batch_size was already set in image_dataset_from_directory
    .prefetch(buffer_size=AUTOTUNE)
)
3. Network Design
3.1. A Brief Introduction to Xception
For more detail see: Zhihu
Paper: Xception: Deep Learning with Depthwise Separable Convolutions
Reference implementation: https://github.com/keras-team/keras-applications/blob/master/keras_applications/xception.py
Xception was proposed by Google in October 2016, after MobileNet v1 and before MobileNet v2. It absorbs design ideas from ResNet, Inception, and MobileNet v1: taking Inception v3 as its template, it replaces the convolutions inside the basic Inception modules with depthwise separable convolutions and adds residual connections.
Xception's structure is based on ResNet, and the whole network is divided into three parts: Entry, Middle, and Exit.
As for the overall flow of the network: the Xception architecture has 36 convolutional layers that form its feature-extraction base. These 36 layers are organized into 14 modules, and all modules except the first and last have residual connections around them.
In short, the Xception architecture is a linear stack of depthwise separable convolution layers. The architecture is easy to modify and can be implemented in only 30-40 lines of code.
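To make the parameter savings of depthwise separable convolutions concrete, here is a quick comparison (a sketch of my own, not from the original post) between a regular 3x3 convolution and a SeparableConv2D on the 19x19x728 tensors used in the middle flow:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, SeparableConv2D, Input
from tensorflow.keras.models import Model

inp = Input(shape=(19, 19, 728))
regular = Model(inp, Conv2D(728, (3, 3), padding='same', use_bias=False)(inp))
separable = Model(inp, SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(inp))

# A regular 3x3 convolution needs 3*3*728*728 = 4,769,856 weights; the depthwise
# separable version needs 3*3*728 (depthwise) + 728*728 (pointwise 1x1) = 536,536,
# roughly 9x fewer -- the same 536,536 that appears in the model summary below.
print(regular.count_params())    # 4769856
print(separable.count_params())  # 536536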
3.2. Building the Network Model
#====================================#
#   The Xception network
#====================================#
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Input, BatchNormalization, Activation, Conv2D, SeparableConv2D, MaxPooling2D
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D
from tensorflow.keras import backend as K
from tensorflow.keras.applications.imagenet_utils import decode_predictions

def Xception(input_shape=[299, 299, 3], classes=1000):
    img_input = Input(shape=input_shape)

    #=================#
    #   Entry flow
    #=================#
    # block1
    # 299,299,3 -> 147,147,64
    x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input)
    x = BatchNormalization(name='block1_conv1_bn')(x)
    x = Activation('relu', name='block1_conv1_act')(x)
    x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
    x = BatchNormalization(name='block1_conv2_bn')(x)
    x = Activation('relu', name='block1_conv2_act')(x)

    # block2
    # 147,147,64 -> 74,74,128
    residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
    x = BatchNormalization(name='block2_sepconv1_bn')(x)
    x = Activation('relu', name='block2_sepconv2_act')(x)
    x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
    x = BatchNormalization(name='block2_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
    x = layers.add([x, residual])

    # block3
    # 74,74,128 -> 37,37,256
    residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block3_sepconv1_act')(x)
    x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
    x = BatchNormalization(name='block3_sepconv1_bn')(x)
    x = Activation('relu', name='block3_sepconv2_act')(x)
    x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
    x = BatchNormalization(name='block3_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
    x = layers.add([x, residual])

    # block4
    # 37,37,256 -> 19,19,728
    residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block4_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
    x = BatchNormalization(name='block4_sepconv1_bn')(x)
    x = Activation('relu', name='block4_sepconv2_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
    x = BatchNormalization(name='block4_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
    x = layers.add([x, residual])

    #=================#
    #   Middle flow
    #=================#
    # block5--block12
    # 19,19,728 -> 19,19,728
    for i in range(8):
        residual = x
        prefix = 'block' + str(i + 5)

        x = Activation('relu', name=prefix + '_sepconv1_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
        x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
        x = Activation('relu', name=prefix + '_sepconv2_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
        x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
        x = Activation('relu', name=prefix + '_sepconv3_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
        x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)

        x = layers.add([x, residual])

    #=================#
    #   Exit flow
    #=================#
    # block13
    # 19,19,728 -> 10,10,1024
    residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block13_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
    x = BatchNormalization(name='block13_sepconv1_bn')(x)
    x = Activation('relu', name='block13_sepconv2_act')(x)
    x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
    x = BatchNormalization(name='block13_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
    x = layers.add([x, residual])

    # block14
    # 10,10,1024 -> 10,10,2048
    x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
    x = BatchNormalization(name='block14_sepconv1_bn')(x)
    x = Activation('relu', name='block14_sepconv1_act')(x)

    x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
    x = BatchNormalization(name='block14_sepconv2_bn')(x)
    x = Activation('relu', name='block14_sepconv2_act')(x)

    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    inputs = img_input
    model = Model(inputs, x, name='xception')

    return model
Print the model summary:
model = Xception()  # note: this keeps the default classes=1000 output head; classes=4 would match this 4-class dataset
# Print the model summary
model.summary()
Output:
Model: "xception"
__________________________________________________________________________________________________
 Layer (type)                              Output Shape           Param #    Connected to
==================================================================================================
 input_1 (InputLayer)                      [(None, 299, 299, 3)]  0          []
 block1_conv1 (Conv2D)                     (None, 149, 149, 32)   864        ['input_1[0][0]']
 block1_conv1_bn (BatchNormalization)      (None, 149, 149, 32)   128        ['block1_conv1[0][0]']
 block1_conv1_act (Activation)             (None, 149, 149, 32)   0          ['block1_conv1_bn[0][0]']
 block1_conv2 (Conv2D)                     (None, 147, 147, 64)   18432      ['block1_conv1_act[0][0]']
 block1_conv2_bn (BatchNormalization)      (None, 147, 147, 64)   256        ['block1_conv2[0][0]']
 block1_conv2_act (Activation)             (None, 147, 147, 64)   0          ['block1_conv2_bn[0][0]']
 block2_sepconv1 (SeparableConv2D)         (None, 147, 147, 128)  8768       ['block1_conv2_act[0][0]']
 block2_sepconv1_bn (BatchNormalization)   (None, 147, 147, 128)  512        ['block2_sepconv1[0][0]']
 block2_sepconv2_act (Activation)          (None, 147, 147, 128)  0          ['block2_sepconv1_bn[0][0]']
 block2_sepconv2 (SeparableConv2D)         (None, 147, 147, 128)  17536      ['block2_sepconv2_act[0][0]']
 block2_sepconv2_bn (BatchNormalization)   (None, 147, 147, 128)  512        ['block2_sepconv2[0][0]']
 conv2d (Conv2D)                           (None, 74, 74, 128)    8192       ['block1_conv2_act[0][0]']
 block2_pool (MaxPooling2D)                (None, 74, 74, 128)    0          ['block2_sepconv2_bn[0][0]']
 batch_normalization (BatchNormalization)  (None, 74, 74, 128)    512        ['conv2d[0][0]']
 add (Add)                                 (None, 74, 74, 128)    0          ['block2_pool[0][0]',
                                                                              'batch_normalization[0][0]']
 ...
 block14_sepconv1 (SeparableConv2D)        (None, 10, 10, 1536)   1582080    ['add_11[0][0]']
 block14_sepconv1_bn (BatchNormalization)  (None, 10, 10, 1536)   6144       ['block14_sepconv1[0][0]']
 block14_sepconv1_act (Activation)         (None, 10, 10, 1536)   0          ['block14_sepconv1_bn[0][0]']
 block14_sepconv2 (SeparableConv2D)        (None, 10, 10, 2048)   3159552    ['block14_sepconv1_act[0][0]']
 block14_sepconv2_bn (BatchNormalization)  (None, 10, 10, 2048)   8192       ['block14_sepconv2[0][0]']
 block14_sepconv2_act (Activation)         (None, 10, 10, 2048)   0          ['block14_sepconv2_bn[0][0]']
 avg_pool (GlobalAveragePooling2D)         (None, 2048)           0          ['block14_sepconv2_act[0][0]']
 predictions (Dense)                       (None, 1000)           2049000    ['avg_pool[0][0]']
==================================================================================================
Total params: 22,910,480
Trainable params: 22,855,952
Non-trainable params: 54,528
__________________________________________________________________________________________________
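The total of 22,910,480 parameters matches the reference Keras implementation of Xception, a useful sanity check on the hand-written network.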
Setting a dynamic learning rate
# Set the initial learning rate
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=300,   # note: this counts steps, not epochs
    decay_rate=0.96,   # each decay multiplies the lr by decay_rate
    staircase=True)

# Pass the exponentially decaying learning rate to the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
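With staircase=True, the learning rate at step s is initial_learning_rate * decay_rate ** (s // decay_steps). A quick check (my own illustration, not part of the original post):

# After 900 steps, three full decay intervals of 300 steps have elapsed
step = 900
print(initial_learning_rate * 0.96 ** (step // 300))  # 8.84736e-05
print(float(lr_schedule(step)))                       # same value, computed by the schedule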
Compiling the model
Loss function (loss): measures how far the model's predictions are from the true labels during training. Here we use sparse_categorical_crossentropy, which works the same way as categorical_crossentropy (multi-class cross-entropy loss) except that the true labels are integer-encoded (e.g., class 0 is represented by the number 0 and class 3 by the number 3; official docs: tf.keras.losses.SparseCategoricalCrossentropy).
Optimizer (optimizer): determines how the model is updated based on the data it sees and its loss function. Here we use Adam (official docs: tf.keras.optimizers.Adam).
Metrics (metrics): used to monitor the training and validation steps. Here we use accuracy, the fraction of images that are correctly classified (official docs: tf.keras.metrics.Accuracy).
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
Training the model
epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
Training results:
Epoch 1/20
800/800 [==============================] - 464s 564ms/step - loss: 1.4314 - accuracy: 0.4584 - val_loss: 1.0577 - val_accuracy: 0.5475
Epoch 2/20
800/800 [==============================] - 447s 559ms/step - loss: 0.9087 - accuracy: 0.6228 - val_loss: 0.8191 - val_accuracy: 0.6612
Epoch 3/20
800/800 [==============================] - 446s 558ms/step - loss: 0.6728 - accuracy: 0.7403 - val_loss: 0.8190 - val_accuracy: 0.6687
Epoch 4/20
800/800 [==============================] - 447s 559ms/step - loss: 0.3362 - accuracy: 0.8841 - val_loss: 0.8249 - val_accuracy: 0.6913
Epoch 5/20
800/800 [==============================] - 447s 559ms/step - loss: 0.1415 - accuracy: 0.9566 - val_loss: 0.9374 - val_accuracy: 0.6975
Epoch 6/20
800/800 [==============================] - 446s 558ms/step - loss: 0.0840 - accuracy: 0.9809 - val_loss: 1.2619 - val_accuracy: 0.6737
Epoch 7/20
800/800 [==============================] - 447s 558ms/step - loss: 0.0574 - accuracy: 0.9862 - val_loss: 0.7897 - val_accuracy: 0.7738
Epoch 8/20
800/800 [==============================] - 446s 558ms/step - loss: 0.0369 - accuracy: 0.9912 - val_loss: 0.8976 - val_accuracy: 0.7350
Epoch 9/20
800/800 [==============================] - 446s 557ms/step - loss: 0.0276 - accuracy: 0.9966 - val_loss: 0.7896 - val_accuracy: 0.7725
Epoch 10/20
800/800 [==============================] - 446s 558ms/step - loss: 0.0223 - accuracy: 0.9969 - val_loss: 0.7084 - val_accuracy: 0.7812
Epoch 11/20
800/800 [==============================] - 446s 558ms/step - loss: 0.0108 - accuracy: 0.9978 - val_loss: 0.8445 - val_accuracy: 0.7588
Epoch 12/20
800/800 [==============================] - 446s 557ms/step - loss: 0.0102 - accuracy: 0.9975 - val_loss: 0.7577 - val_accuracy: 0.7850
Epoch 13/20
800/800 [==============================] - 446s 558ms/step - loss: 0.0062 - accuracy: 0.9991 - val_loss: 0.7447 - val_accuracy: 0.7837
Epoch 14/20
800/800 [==============================] - 445s 557ms/step - loss: 0.0034 - accuracy: 0.9987 - val_loss: 1.0870 - val_accuracy: 0.7063
Epoch 15/20
800/800 [==============================] - 445s 557ms/step - loss: 0.0100 - accuracy: 0.9978 - val_loss: 0.8212 - val_accuracy: 0.7725
Epoch 16/20
800/800 [==============================] - 446s 557ms/step - loss: 0.0089 - accuracy: 0.9981 - val_loss: 0.8604 - val_accuracy: 0.7688
Epoch 17/20
800/800 [==============================] - 446s 557ms/step - loss: 0.0068 - accuracy: 0.9984 - val_loss: 0.7941 - val_accuracy: 0.7887
Epoch 18/20
800/800 [==============================] - 446s 557ms/step - loss: 0.0037 - accuracy: 0.9994 - val_loss: 0.9039 - val_accuracy: 0.7650
Epoch 19/20
800/800 [==============================] - 446s 557ms/step - loss: 0.0013 - accuracy: 1.0000 - val_loss: 0.8278 - val_accuracy: 0.7812
Epoch 20/20
800/800 [==============================] - 446s 557ms/step - loss: 6.7889e-04 - accuracy: 1.0000 - val_loss: 0.8216 - val_accuracy: 0.7812
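Note that training accuracy reaches 1.0000 by the final epochs while validation accuracy plateaus around 0.78, a clear sign of overfitting on this 3200-image training set.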
4. Model Evaluation
4.1. Accuracy Evaluation
Plot the accuracy and loss curves:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
4.2. Plotting the Confusion Matrix
An introduction to confusion_matrix() can be found at: sklearn.metrics.confusion_matrix
Seaborn: a higher-level API wrapper around the Matplotlib core library; its advantages are more pleasant default color schemes and more refined styling of plot elements.
Define a function plot_cm that draws the confusion matrix:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd

# Define a function that draws a confusion-matrix plot
def plot_cm(labels, predictions):
    # Build the confusion matrix
    conf_numpy = confusion_matrix(labels, predictions)
    # Convert the matrix into a DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names, columns=class_names)

    plt.figure(figsize=(8, 7))

    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")

    plt.title('Confusion matrix', fontsize=15)
    plt.ylabel('True labels', fontsize=14)
    plt.xlabel('Predicted labels', fontsize=14)
    plt.savefig('pic3.jpg', dpi=600)  # save at the specified resolution
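The post does not show the call itself; a minimal sketch (my own, with illustrative variable names val_labels and val_pred) that gathers validation labels and predictions and feeds them to plot_cm:

# Collect true labels and predicted labels over the whole validation set
val_labels = []
val_pred = []
for images, labels in val_ds:
    preds = model.predict(images, verbose=0)   # (batch, num_classes) class probabilities
    val_pred.extend(np.argmax(preds, axis=1))  # predicted class indices
    val_labels.extend(labels.numpy())          # true class indices

plot_cm(val_labels, val_pred)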
Output: (the confusion-matrix heatmap, saved as pic3.jpg)
Save the model:
# Save the model
model.save('model/model.h5')

# Load the model
new_model = tf.keras.models.load_model('model/model.h5')
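As a quick sanity check (my addition, not in the original post), the reloaded model can be evaluated on the validation set and should reproduce the final epoch's validation metrics:

# Evaluate the restored model; accuracy should be close to the last val_accuracy (~0.78)
loss, acc = new_model.evaluate(val_ds, verbose=1)
print("Restored model accuracy: {:.2%}".format(acc))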
4.3. Making Predictions
plt.figure(figsize=(15, 7))  # figure width 15, height 7
plt.suptitle("Prediction results")

num = -1
for images, labels in val_ds.take(2):
    for i in range(4):
        num = num + 1
        plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.2, hspace=0.2)
        if num >= 8:
            break
        ax = plt.subplot(2, 4, num + 1)

        # Display the image
        plt.imshow(images[i].numpy().astype("uint8"))

        # The image needs an extra batch dimension
        img_array = tf.expand_dims(images[i], 0)

        # Use the model to predict the class of the image
        predictions = model.predict(img_array)
        plt.title("True: {}  Predicted: {}".format(class_names[labels[i]], class_names[np.argmax(predictions)]))
        plt.axis("off")

plt.savefig('pic4.jpg', dpi=400)  # save at the specified resolution
Result: (a 2×4 grid of validation images, each titled with its true and predicted class)