
Training an Object Detection Model with TensorFlow Model Maker


1. Environment Setup

 

1.1 Create a new isolated environment with conda

 

I work in conda, so I created a separate environment dedicated to TensorFlow Model Maker.

 

# Create the environment
conda create -n tf_model_maker python=3.9
# Activate the environment
conda activate tf_model_maker

 

# Leave the current environment
conda deactivate
# Delete the environment when it is no longer needed
conda remove -n tf_model_maker --all

 

1.2 Install TensorFlow Model Maker

 

apt -y install libportaudio2
pip install -q --use-deprecated=legacy-resolver tflite-model-maker
pip install -q pycocotools
pip install -q opencv-python-headless==4.1.2.30
pip uninstall -y tensorflow && pip install -q tensorflow==2.8.0

 

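I did not use the nightly build here. With the nightly build some library imports failed (I am not sure which bug is responsible), so I switched back to the non-nightly release.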
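To see which versions pip actually resolved after the commands above, a small check like the following may help. It is only a sketch using the standard library; the helper name installed_versions is my own, and the package list simply mirrors the pip commands in this section.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package name: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Package names taken from the pip install lines in section 1.2
    wanted = ["tflite-model-maker", "tensorflow", "pycocotools", "opencv-python-headless"]
    for name, version in installed_versions(wanted).items():
        print(f"{name}: {version or 'not installed'}")
```

If tensorflow does not report 2.8.0 here, the last pip command in section 1.2 probably did not take effect in the active environment.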

 

1.3 Imports

 

import numpy as np
import os
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

 

Output:

 

/root/anaconda3/envs/env_tflite_model_maker/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

 

2. Preparing the Dataset

 

My data is in COCO format and has already been converted to a CSV file with the following layout:

 

filename,width,height,class,xmin,ymin,xmax,ymax
00232f5be5eb8a0f2c34a4a63f73d678.jpeg,683,1024,ball,224,756,511,1024
.....

 

Target CSV format: cloud.google.com/vision/auto…

 

set,path,label,xmin,ymin,,,xmax,ymax,,

 

TRAIN, VAL, or TEST: marks the row as training, validation, or test data.

 

Image file path: the full (absolute) path must be used here.

 

label: the name of the label.

 

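Bounding box of the object in the image:

Either use 2 vertices, each an x,y coordinate pair, that are diagonally opposite corners of the rectangle (xmin,ymin,,,xmax,ymax,,),
or use all 4 vertices (xmin,ymin,xmax,ymin,xmax,ymax,xmin,ymax).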

These coordinates must be floats in the range 0 to 1, where 0 is the minimum x or y value and 1 is the maximum x or y value.

 

For example, (0,0) is the top-left corner and (1,1) is the bottom-right corner; the bounding box of the whole image is (0,0,,,1,1,,) or (0,0,1,0,1,1,0,1).

 

TRAIN,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
VAL,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
TEST,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
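As a quick check of the arithmetic, the sample rows above can be reproduced from the pixel-coordinate row shown earlier (683 x 1024 image, box 224,756,511,1024). This is a minimal sketch; the function name to_normalized_row is hypothetical and only the two-vertex layout is covered.

```python
def to_normalized_row(split, path, label, box, width, height):
    """Build one row in the set,path,label,xmin,ymin,,,xmax,ymax,, layout.

    box is (xmin, ymin, xmax, ymax) in pixels; width and height are the image size.
    """
    xmin, ymin, xmax, ymax = box
    return [split, path, label,
            xmin / width, ymin / height, "", "",
            xmax / width, ymax / height, "", ""]


row = to_normalized_row(
    "TRAIN", "/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg", "ball",
    box=(224, 756, 511, 1024), width=683, height=1024)
print(",".join(str(value) for value in row))
```

The x values are divided by the image width and the y values by the image height, which is why ymax becomes exactly 1.0 for a box that touches the bottom edge.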

 

Data processing code:

 

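import codecs
import csv
import cv2
import os

image_path = '/root/xxx/images/'

def makeData(old_file, new_file, key):
    """Convert a pixel-coordinate CSV into the normalized CSV layout, tagging every row with key."""
    # newline='' keeps the csv module from emitting blank lines on some platforms
    with open(new_file, 'w', newline='') as file:
        w = csv.writer(file)
        with codecs.open(old_file, encoding='utf-8-sig') as f:
            for row in csv.DictReader(f, skipinitialspace=True):
                width = float(row['width'])
                height = float(row['height'])
                label = row['class']
                # Normalize pixel coordinates to the 0..1 range
                xmin = float(row['xmin']) / width
                ymin = float(row['ymin']) / height
                xmax = float(row['xmax']) / width
                ymax = float(row['ymax']) / height
                filename = row['filename']
                print(filename)
                img_path = os.path.join(image_path, filename)
                # Skip images that no longer exist
                if not os.path.exists(img_path):
                    continue
                # Re-encode as a real JPEG, since some files only had their extension renamed
                name = filename.replace(".jpeg", "").replace(".jpg", "")
                save_path = os.path.join(image_path, name + ".jpg")
                img = cv2.imread(img_path)
                if img is None:  # cv2.imread returns None for unreadable files
                    continue
                cv2.imwrite(save_path, img)
                new_row = [key, save_path, label, xmin, ymin, '', '', xmax, ymax, '', '']
                print(new_row)
                w.writerow(new_row)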

 

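Some images in the library I received had simply been renamed with a different extension, so their real format did not match the suffix; the code above re-encodes them. Some images no longer existed, so those rows are filtered out as well.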

 

makeData('/root/xxx/train.csv',
         '/root/xxx/new_train.csv',
         'TRAIN')
makeData('/root/xxx/test.csv',
         '/root/xxx/new_test.csv',
         'TEST')
# The validation rows are also generated from test.csv
makeData('/root/xxx/test.csv',
         '/root/xxx/new_vaild.csv',
         'VAL')

 

I then took part of the data from new_train.csv, new_test.csv, and new_vaild.csv and manually merged it into a single file named data.csv.
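The manual merge could also be scripted. The sketch below is one possible way, not what I actually ran: merge_csvs and rows_per_file are hypothetical names, and it simply keeps the first rows of each file, since the selection of "part of the data" was done by hand.

```python
import csv


def merge_csvs(input_paths, output_path, rows_per_file=None):
    """Concatenate header-less CSV files, keeping at most rows_per_file rows from each.

    Returns the number of rows written. rows_per_file=None keeps every row.
    """
    written = 0
    with open(output_path, "w", newline="") as out:
        writer = csv.writer(out)
        for path in input_paths:
            with open(path, newline="") as src:
                for index, row in enumerate(csv.reader(src)):
                    if rows_per_file is not None and index >= rows_per_file:
                        break
                    writer.writerow(row)
                    written += 1
    return written


# Example call with the file names from this post (not executed here):
# merge_csvs(['/root/xxx/new_train.csv', '/root/xxx/new_test.csv', '/root/xxx/new_vaild.csv'],
#            '/root/xxx/data.csv')
```

Because every row already carries its TRAIN, VAL, or TEST tag in the first column, the order of the rows in the merged file should not matter.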

 

train_data,validation_data,test_data = object_detector.DataLoader.from_csv('/root/xxx/data.csv')
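Since data.csv was assembled by hand, it may be worth sanity-checking it before loading. The following sketch counts rows per split and rejects rows whose coordinates fall outside 0 to 1. The names check_rows and check_csv are my own, and the check assumes the two-vertex layout (xmin,ymin,,,xmax,ymax,,) used in this post.

```python
import csv
from collections import Counter

VALID_SPLITS = {"TRAIN", "VAL", "TEST"}


def check_rows(rows):
    """Count rows per split; raise ValueError on an unknown split or an invalid box."""
    counts = Counter()
    for line_no, row in enumerate(rows, start=1):
        split = row[0]
        if split not in VALID_SPLITS:
            raise ValueError(f"line {line_no}: unknown split {split!r}")
        # Two-vertex layout: xmin,ymin in columns 3-4 and xmax,ymax in columns 7-8
        xmin, ymin, xmax, ymax = (float(row[i]) for i in (3, 4, 7, 8))
        if not (0.0 <= xmin < xmax <= 1.0 and 0.0 <= ymin < ymax <= 1.0):
            raise ValueError(f"line {line_no}: box outside 0..1 or inverted: {row}")
        counts[split] += 1
    return counts


def check_csv(path):
    """Run check_rows on a header-less CSV file."""
    with open(path, newline="") as f:
        return check_rows(csv.reader(f))


# Example call (not executed here): print(check_csv('/root/xxx/data.csv'))
```

A split that comes back with a count of zero would be a sign that the merge dropped one of the three source files.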

 

3. Preparing the Pretrained Model

 

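The object detection task only supports models from the EfficientDet family. I tried EfficientDet-Lite2 first, but its speed on phones was not ideal: even a high-end device needed roughly 100 ms per detection. In the end I chose the faster EfficientDet-Lite0.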

 

 

Model architecture    Size(MB)*   Latency(ms)**   Average Precision***
EfficientDet-Lite0    4.4         37              25.69%
EfficientDet-Lite1    5.8         49              30.55%
EfficientDet-Lite2    7.2         69              33.97%
EfficientDet-Lite3    11.4        116             37.70%
EfficientDet-Lite4    19.9        260             41.96%

 

* Size of the integer quantized models.
** Latency measured on Pixel 4 using 4 threads on CPU.
*** Average Precision is the mAP (mean Average Precision) on the COCO 2017 validation dataset.
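The trade-off in the table can be expressed as a small lookup: pick the most accurate variant whose listed latency fits a budget. This is only a sketch with a helper name of my own (best_under_latency); the numbers are the Pixel 4 figures from the table, so real latency on another phone will differ, as my own Lite2 measurement did.

```python
# name: (size in MB, latency in ms on Pixel 4 with 4 CPU threads, COCO 2017 val mAP in %)
EFFICIENTDET_LITE = {
    "efficientdet_lite0": (4.4, 37, 25.69),
    "efficientdet_lite1": (5.8, 49, 30.55),
    "efficientdet_lite2": (7.2, 69, 33.97),
    "efficientdet_lite3": (11.4, 116, 37.70),
    "efficientdet_lite4": (19.9, 260, 41.96),
}


def best_under_latency(budget_ms):
    """Return the highest-mAP model name whose listed latency fits the budget, or None."""
    candidates = [
        (ap, name)
        for name, (_size, latency, ap) in EFFICIENTDET_LITE.items()
        if latency <= budget_ms
    ]
    return max(candidates)[1] if candidates else None


print(best_under_latency(50))  # prints efficientdet_lite1
```

Treat the listed latencies as a relative ranking rather than a prediction for a specific device, and measure on the target phone before committing to a variant.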

 

3.1 Select the pretrained model

 

spec = model_spec.get('efficientdet_lite0')

 

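On a server in mainland China this call fails with a timeout error, because the download host is blocked by the firewall. Following the hint in the error message, the source code has to be changed so that the uri points to a mirror of the model file.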

 

3.2 Modify the source code

 

# Open the pretrained model spec file
vim ~/anaconda3/envs/env_tflite_model_maker/lib/python3.9/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py
# Find the efficientdet_lite0_spec definition
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1',
)
# Replace the uri with the mirror path
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://storage.googleapis.com/tfhub-modules/tensorflow/efficientdet/lite0/feature-vector/1.tar.gz',
)

 

The key step is replacing the uri. After that, run spec = model_spec.get('efficientdet_lite0') again.
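The uri replacement follows a simple pattern, which the sketch below captures: swap the tfhub.dev prefix for the storage.googleapis.com/tfhub-modules prefix and append .tar.gz. The helper name to_mirror_uri is hypothetical, and I have only confirmed the pattern for the lite0 uri shown above; that it holds for the other variants is an assumption.

```python
TFHUB_PREFIX = "https://tfhub.dev/"
MIRROR_PREFIX = "https://storage.googleapis.com/tfhub-modules/"


def to_mirror_uri(tfhub_uri):
    """Rewrite a tfhub.dev module URL into the storage.googleapis.com form used in this post."""
    if not tfhub_uri.startswith(TFHUB_PREFIX):
        raise ValueError(f"not a tfhub.dev URL: {tfhub_uri}")
    module_path = tfhub_uri[len(TFHUB_PREFIX):].rstrip("/")
    return MIRROR_PREFIX + module_path + ".tar.gz"


print(to_mirror_uri("https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1"))
```

If the installed version exposes object_detector.EfficientDetLite0Spec and it accepts a uri keyword (an assumption to check against your release), passing the mirror uri there might avoid editing files inside site-packages, which are overwritten whenever the package is reinstalled.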

 

4. Training the Model

 

model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)

 

Epoch 1/50
540/540 [==============================] - 253s 399ms/step - det_loss: 0.6041 - cls_loss: 0.3679 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.6678 - learning_rate: 0.0090 - gradient_norm: 4.1991 - val_det_loss: 1.2947 - val_cls_loss: 0.8470 - val_box_loss: 0.0090 - val_reg_l2_loss: 0.0645 - val_loss: 1.3592
Epoch 2/50
540/540 [==============================] - 214s 397ms/step - det_loss: 0.3937 - cls_loss: 0.2513 - box_loss: 0.0028 - reg_l2_loss: 0.0651 - loss: 0.4588 - learning_rate: 0.0100 - gradient_norm: 3.2312 - val_det_loss: 0.3262 - val_cls_loss: 0.2136 - val_box_loss: 0.0023 - val_reg_l2_loss: 0.0656 - val_loss: 0.3918
Epoch 3/50
540/540 [==============================] - 213s 394ms/step - det_loss: 0.3450 - cls_loss: 0.2250 - box_loss: 0.0024 - reg_l2_loss: 0.0660 - loss: 0.4110 - learning_rate: 0.0099 - gradient_norm: 2.8205 - val_det_loss: 0.2999 - val_cls_loss: 0.2096 - val_box_loss: 0.0018 - val_reg_l2_loss: 0.0664 - val_loss: 0.3663
.....

 

Evaluating the model

 

model.evaluate(test_data)

 

Output:

 

{'AP': 0.82879966,
 'AP50': 0.9893871,
 'AP75': 0.9637165,
 'APs': 0.50417614,
 'APm': 0.83946806,
 'APl': 0.8315978,
 'ARmax1': 0.7818135,
 'ARmax10': 0.8720247,
 'ARmax100': 0.87727976,
 'ARs': 0.7034483,
 'ARm': 0.89498526,
 'ARl': 0.87662005,
 'AP_/ball': 0.82879966}

 

5. Exporting the TFLite Model

 

model.export(export_dir='/root/xxx/tf')

 

This generates a model.tflite file in the /root/xxx/tf folder.

 

Evaluating the exported model:

 

model.evaluate_tflite('model.tflite', test_data)

 

Output:

 

{'AP': 0.817586,
 'AP50': 0.98929125,
 'AP75': 0.95808136,
 'APs': 0.4901086,
 'APm': 0.8326331,
 'APl': 0.81800973,
 'ARmax1': 0.77594024,
 'ARmax10': 0.8460072,
 'ARmax100': 0.84688306,
 'ARs': 0.63793105,
 'ARm': 0.86342186,
 'ARl': 0.84720457,
 'AP_/ball': 0.817586}

 

After exporting to TFLite, the model's AP dropped from 0.82879966 to 0.817586, which is still acceptable.
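To put that drop in perspective, the two evaluation dictionaries can be compared key by key. The sketch below uses a helper name of my own (metric_drop) and copies three of the values from the outputs above.

```python
def metric_drop(before, after):
    """Return {metric: (absolute drop, relative drop in percent)} for keys present in both."""
    return {
        key: (before[key] - after[key],
              100.0 * (before[key] - after[key]) / before[key])
        for key in before
        if key in after
    }


# Values copied from the model.evaluate and model.evaluate_tflite outputs
keras_metrics = {"AP": 0.82879966, "AP50": 0.9893871, "AP75": 0.9637165}
tflite_metrics = {"AP": 0.817586, "AP50": 0.98929125, "AP75": 0.95808136}

for key, (absolute, relative) in metric_drop(keras_metrics, tflite_metrics).items():
    print(f"{key}: drop {absolute:.4f} ({relative:.2f}% relative)")
```

For AP this works out to about 0.011 absolute, or roughly 1.35% relative, while AP50 is almost unchanged.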

 

Testing the TFLite model:

 

# Imports
from tflite_support.task import vision
from tflite_support.task import core
from tflite_support.task import processor
# Initialization
base_options = core.BaseOptions(file_name='/root/xxx/tf/model.tflite')
detection_options = processor.DetectionOptions(max_results=2)
options = vision.ObjectDetectorOptions(base_options=base_options, detection_options=detection_options)
detector = vision.ObjectDetector.create_from_options(options)
# Alternatively, you can create an object detector in the following manner:
# detector = vision.ObjectDetector.create_from_file(model_path)
# Run inference on an image from the dataset
image = vision.TensorImage.create_from_file('/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpeg')
detection_result = detector.detect(image)
print(detection_result)
# Run inference on a second image
image = vision.TensorImage.create_from_file('/root/xxx/11.png')
detection_result = detector.detect(image)
print(detection_result)
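Printing the raw result is hard to read, so a small helper that flattens it into labels, scores, and corner coordinates can be handy. This is a sketch under assumptions: it expects the result to expose detections, each with a bounding_box (origin_x, origin_y, width, height) and a categories list (category_name, score), which is how recent tflite_support releases appear to structure it; older releases may use different attribute names. The helper name summarize and the stand-in object with made-up values are mine, so the example runs without tflite_support.

```python
from types import SimpleNamespace


def summarize(detection_result, min_score=0.5):
    """Return (label, score, (left, top, right, bottom)) for detections at or above min_score."""
    summary = []
    for detection in detection_result.detections:
        top = detection.categories[0]  # assumed: the first category is the best match
        if top.score < min_score:
            continue
        box = detection.bounding_box
        corners = (box.origin_x, box.origin_y,
                   box.origin_x + box.width, box.origin_y + box.height)
        summary.append((top.category_name, top.score, corners))
    return summary


# Stand-in with the assumed attribute layout and made-up values, not real model output
fake_result = SimpleNamespace(detections=[
    SimpleNamespace(
        bounding_box=SimpleNamespace(origin_x=224, origin_y=756, width=287, height=268),
        categories=[SimpleNamespace(category_name="ball", score=0.91)],
    )
])
print(summarize(fake_result))
```

With a real detector the call would be summarize(detection_result), provided the attribute names match the installed tflite_support version.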

 

References

 

tensorflow.google.cn/lite/models…
