5、定义数据预处理及训练模型的一些超参数
7.1 构建多层感知器（MLP）
7.2 创建一个类似卷积层的patch层
7.3 查看由patch层随机生成的图像块
7.4 构建 patch 编码层（encoding layer）

### 1、导入模块

```import os
import math
import numpy as np
import pickle as p
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers
%matplotlib inline```

### 2、定义加载函数

```def load_CIFAR_data(data_dir):

images_train=[]
labels_train=[]
for i in range(5):
f=os.path.join(data_dir,'data_batch_%d' % (i+1))
images_train.append(image_batch)
labels_train.append(label_batch)
Xtrain=np.concatenate(images_train)
Ytrain=np.concatenate(labels_train)
del image_batch ,label_batch

# 返回训练集的图像和标签，测试集的图像和标签
return (Xtrain,Ytrain),(Xtest,Ytest)```

### 3、定义批量加载函数

```def load_CIFAR_batch(filename):
""" load single batch of cifar """
with open(filename, 'rb')as f:
# 一个样本由标签和图像数据组成
#  (3072=32x32x3)
# ...
#
images= data_dict[b'data']
labels = data_dict[b'labels']

# 把原始数据结构调整为: BCWH
images = images.reshape(10000, 3, 32, 32)
# tensorflow处理图像数据的结构：BWHC
# 把通道数据C移动到最后一个维度
images = images.transpose (0,2,3,1)

labels = np.array(labels)

return images, labels```

### 4、加载数据

```data_dir = r'C:\Users\wumg\jupyter-ipynb\data\cifar-10-batches-py'

```train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))```

### 5、定义数据预处理及训练模型的一些超参数

```num_classes = 10
input_shape = (32, 32, 3)

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 10
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
transformer_units = [
projection_dim * 2,
projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier```

### 6、定义数据增强模型

```data_augmentation = keras.Sequential(
[
layers.experimental.preprocessing.Normalization(),
layers.experimental.preprocessing.Resizing(image_size, image_size),
layers.experimental.preprocessing.RandomFlip("horizontal"),
layers.experimental.preprocessing.RandomRotation(factor=0.02),
layers.experimental.preprocessing.RandomZoom(
height_factor=0.2, width_factor=0.2
),
],
name="data_augmentation",
)
# 使预处理层的状态与正在传递的数据相匹配
#Compute the mean and the variance of the training data for normalization.

### 7.1 构建多层感知器（MLP）

```def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x```

### 7.2 创建一个类似卷积层的patch层

```class Patches(layers.Layer):
def __init__(self, patch_size):
super(Patches, self).__init__()
self.patch_size = patch_size

def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches```

### 7.3 查看由patch层随机生成的图像块

```import matplotlib.pyplot as plt

plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

resized_image = tf.image.resize(
tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
ax = plt.subplot(n, n, i + 1)
patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
plt.imshow(patch_img.numpy().astype("uint8"))
plt.axis("off")```

Image size: 72 X 72

Patch size: 6 X 6

Patches per image: 144

Elements per patch: 108

### 7.4 构建 patch 编码层（encoding layer）

```class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
#一个全连接层，其输出维度为projection_dim，没有指明激活函数
self.projection = layers.Dense(units=projection_dim)
#定义一个嵌入层，这是一个可学习的层
#输入维度为num_patches，输出维度为projection_dim
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)

def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded```

### 7.5 构建 ViT 模型

```def create_vit_classifier():
inputs = layers.Input(shape=input_shape)
# Augment data.
augmented = data_augmentation(inputs)
#augmented = augmented_train_batches(inputs)
# Create patches.
patches = Patches(patch_size)(augmented)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

# Create multiple layers of the Transformer block.
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
)(x1, x1)
# Skip connection 1.
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
# Skip connection 2.

# Create a [batch_size, projection_dim] tensor.
representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)
# Classify outputs.
logits = layers.Dense(num_classes)(features)
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=logits)
return model```

### 8、编译、训练模型

```def run_experiment(model):
learning_rate=learning_rate, weight_decay=weight_decay
)

model.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)

#checkpoint_filepath = r".\tmp\checkpoint"
checkpoint_filepath ="model_bak.hdf5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
)

history = model.fit(
x=x_train,
y=y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_split=0.1,
callbacks=[checkpoint_callback],
)

_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

return history```

```vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)```

Epoch 1/10

176/176 [==============================] – 68s 333ms/step – loss: 2.6394 – accuracy: 0.2501 – top-5-accuracy: 0.7377 – val_loss: 1.5331 – val_accuracy: 0.4580 – val_top-5-accuracy: 0.9092

Epoch 2/10

176/176 [==============================] – 58s 327ms/step – loss: 1.6359 – accuracy: 0.4150 – top-5-accuracy: 0.8821 – val_loss: 1.2714 – val_accuracy: 0.5348 – val_top-5-accuracy: 0.9464

Epoch 3/10

176/176 [==============================] – 58s 328ms/step – loss: 1.4332 – accuracy: 0.4839 – top-5-accuracy: 0.9210 – val_loss: 1.1633 – val_accuracy: 0.5806 – val_top-5-accuracy: 0.9616

Epoch 4/10

176/176 [==============================] – 58s 329ms/step – loss: 1.3253 – accuracy: 0.5280 – top-5-accuracy: 0.9349 – val_loss: 1.1010 – val_accuracy: 0.6112 – val_top-5-accuracy: 0.9572

Epoch 5/10

176/176 [==============================] – 58s 330ms/step – loss: 1.2380 – accuracy: 0.5626 – top-5-accuracy: 0.9411 – val_loss: 1.0212 – val_accuracy: 0.6400 – val_top-5-accuracy: 0.9690

Epoch 6/10

176/176 [==============================] – 58s 330ms/step – loss: 1.1486 – accuracy: 0.5945 – top-5-accuracy: 0.9520 – val_loss: 0.9698 – val_accuracy: 0.6602 – val_top-5-accuracy: 0.9718

Epoch 7/10

176/176 [==============================] – 58s 330ms/step – loss: 1.1208 – accuracy: 0.6060 – top-5-accuracy: 0.9558 – val_loss: 0.9215 – val_accuracy: 0.6724 – val_top-5-accuracy: 0.9790

Epoch 8/10

176/176 [==============================] – 58s 330ms/step – loss: 1.0643 – accuracy: 0.6248 – top-5-accuracy: 0.9621 – val_loss: 0.8709 – val_accuracy: 0.6944 – val_top-5-accuracy: 0.9768

Epoch 9/10

176/176 [==============================] – 58s 330ms/step – loss: 1.0119 – accuracy: 0.6446 – top-5-accuracy: 0.9640 – val_loss: 0.8290 – val_accuracy: 0.7142 – val_top-5-accuracy: 0.9784

Epoch 10/10

176/176 [==============================] – 58s 330ms/step – loss: 0.9740 – accuracy: 0.6615 – top-5-accuracy: 0.9666 – val_loss: 0.8175 – val_accuracy: 0.7096 – val_top-5-accuracy: 0.9806

313/313 [==============================] – 9s 27ms/step – loss: 0.8514 – accuracy: 0.7032 – top-5-accuracy: 0.9773

Test accuracy: 70.32%

Test top 5 accuracy: 97.73%

In [15]:

### 9、查看运行结果

```acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss =history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1.1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([-0.1,4.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()```