Press "Enter" to skip to content

手把手教程:如何从零开始训练 TF 模型并在安卓系统上运行

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

本教程介绍如何使用 tf.Keras 时序 API 从头开始训练模型,将 tf.Keras 模型转换为 tflite 格式,并在 Android 上运行该模型。我将以 MNIST 数据为例介绍图像分类,并分享一些你可能会面临的常见问题。本教程着重于端到端的体验,我不会深入探讨各种 tf.Keras API 或 Android 开发。

 

下载我的示例代码并执行以下操作:

 

在 colab 中运行:使用 tf.keras 的训练模型,并将 keras 模型转换为 tflite(链接到 Colab notebook)。

 

在 Android Studio 中运行:DigitRecognizer(链接到Android应用程序)。

 

 

1.训练自定义分类器

 

加载数据

 

我们将使用作为tf.keras框架一部分的mnst数据。

 

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

 

预处理数据

 

接下来,我们将输入图像从 28×28 变为 28x28x1 的形状,将其标准化,并对标签进行 one-hot 编码。

 

定义模型体系结构

 

然后我们将用 cnn 定义网络架构。

 

def create_model():  
 
        # Define the model architecture  
       model = keras.models.Sequential([  
              # Must define the input shape in the first layer of the neural network  
              keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', input_shape=(28,28,1)),  
              keras.layers.MaxPooling2D(pool_size=2),  
              keras.layers.Dropout(0.3),  
 
keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),  
       keras.layers.MaxPooling2D(pool_size=2),  
       keras.layers.Dropout(0.3),  
 
keras.layers.Flatten(),  
       keras.layers.Dense(128, activation='relu'),  
       keras.layers.Dropout(0.5),  
       keras.layers.Dense(10, activation='softmax')  
])
 
# Compile the model  
model.compile(loss=keras.losses.categorical_crossentropy,  
       optimizer=keras.optimizers.Adam(),  
       metrics=['accuracy'])  
 
return model

 

训练模型

 

然后我们使用 model.fit()来训练模型。

 

model.fit(x_train,  
                y_train,  
               batch_size=64,  
               epochs=3,  
               validation_data=(x_test, y_test))

 

2.模型保存和转换

 

训练结束后,我们将保存一个 Keras 模型并将其转换为 TFLite 格式。

 

保存一个 Keras 模型

 

下面是保存 Keras 模型的方法-

 

# Save tf.keras model in HDF5 format  
keras_model = "mnist_keras_model.h5"  
keras.models.save_model(model, keras_model)

 

将keras模型转换为tflite

 

当使用 TFLite 转换器将 Keras 模型转换为 TFLite 格式时,有两个选择- 1)从命令行转换,或 2)直接在 python 代码中转换,这个更加推荐。

 

1)通过命令行转换

 

$ tflite_convert \  
$ --output_file=mymodel.tflite \  
$ --keras_model_file=mymodel.h5

 

2)通过 python 代码转换

 

如果你可以访问模型训练代码,则这是转换的首选方法。

 

# Convert the model  
flite_model = converter.convert()  
 
# Create the tflite model file  
tflite_model_name = "mymodel.tflite"  
open(tflite_model_name, "wb").write(tflite_model)

 

你可以将转换器的训练后量化设置为 true。

 

# Set quantize to true  
converter.post_training_quantize=True

 

验证转换的模型

 

将 Keras 模型转换为 TFLite 格式后,验证它是否能够与原始 Keras 模型一样正常运行是很重要的。请参阅下面关于如何使用 TFLite 模型运行推断的 python 代码片段。示例输入是随机输入数据,你需要根据自己的数据更新它。

 

# Load TFLite model and allocate tensors. interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")  
interpreter.allocate_tensors()
 
# Get input and output tensors  
input_details = interpreter.get_input_details() output_details = interpreter.get_output_details()  
 
# Test model on random input data  
input_shape = input_details[0]['shape']  
input_data = np.array(np.random.random_sample(input_shape),  
dtype=np.float32) interpreter.set_tensor(input_details[0]['index'], input_data)  
interpreter.invoke()  
output_data = interpreter.get_tensor(output_details[0]['index'])  
print(output_data)

 

ps:确保在转换后和将 TFLite 模型放到 Android 上面之前始终测试它。否则,当它在你的 Android 应用程序上不能工作时,你无法分清是你的 android 代码有问题还是 ML 模型有问题。

 

3.在 Android 上实现 tflite 模型

 

现在我们准备在 Android 上实现 TFLite 模型。创建一个新的 Android 项目并遵循以下步骤

 

 

将 mnist.tflite 模型放在 assets 文件夹下

 

更新 build.gradle 以包含 tflite 依赖项

 

为用户创建自定义视图

 

创建一个进行数字分类的分类器

 

从自定义视图输入图像

 

图像预处理

 

用模型对图像进行分类

 

后处理

 

在用户界面中显示结果

 

 

Classifier 类是大多数 ML 魔术发生的地方。确保在类中设置的维度与模型预期的维度匹配:

 

28x28x1 的图像

 

10 位数字的 10 个类:0、1、2、3…9

 

要对图像进行分类,请执行以下步骤:

 

预处理输入图像。将位图转换为 bytebuffer 并将像素转换为灰度,因为 MNIST 数据集是灰度的。

 

使用由内存映射到 assets 文件夹下的模型文件创建的解释器运行推断。

 

后处理输出结果以在 UI 中显示。我们得到的结果有 10 种可能,我们将选择在 UI 中显示概率最高的数字。

 

 

过程中的挑战

 

以下是你可能遇到的挑战:

 

在 tflite 转换期间,如果出现「tflite 不支持某个操作」的错误,则应请求 tensorflow 团队添加该操作或自己创建自定义运算符。

 

有时,转换似乎是成功的,但转换后的模型却不起作用:例如,转换后的分类器可能在正负测试中以~0.5 的精度随机分类。(我在 tf 1.10 中遇到了这个错误,后来在 tf1.12 中修复了它)。

 

如果 Android 应用程序崩溃,请查看 logcat 中的 stacktrace 错误:

 

确保输入图像大小和颜色通道设置正确,以匹配模型期望的输入张量大小。

 

确保 in build.gradle aaptoptions 设置为不压缩 tflite 文件。

 

aaptOptions {  
         noCompress "tflite"  
}

 

总体来说,用 tf.Keras 训练一个简单的图像分类器是轻而易举的,保存 Keras 模型并将其转换为 TFLite 也相当容易。目前,我们在 Android 上实现 TFLite 模型的方法仍然有点单调,希望将来能有所改进。

 

via: https://medium.com/@margaretmz/e2e-tfkeras-tflite-android-273acde6588

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注