Press "Enter" to skip to content

使用TensorFlow对象检测API训练自定义对象检测模型

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

出发地:https://makeoptim.com/en/deep-learning/yiai-object-detection https://makeoptim.com/en/deep-learning/yiai-object-detection

 

前言

 

本文将介绍对象检测的概念,并通过案例说明如何使用TensorFlow对象检测API训练一个自定义的对象检测器,包括数据集的采集和处理、TensorFlow对象检测API的安装、模型训练。 TensorFlow Object Detection API

 

案例效果如下图所示:

目标检测

 

如上图所示,图像分类解决的问题是图片中的对象是什幺,而对象检测可以识别图片中的对象和对象的位置(坐标)。

 

位置

 

目标检测的位置信息一般有两种格式:

极坐标(xmin,ymin,xmax,ymax):xmin,ymin:x,y坐标的最小值;xmin,ymin:x,y坐标的最大值
CENTER POINT:(x_center,y_center,w,h):x_center,y_center:目标检测框的中心点坐标;w,h:目标检测框的宽度和高度

里程碑

 

 

传统方法(区域建议+人工特征提取+分类器)

 

HOG+支持向量机、DPM

 

地区建议书+CNN(两阶段)

 

R-CNN、SPP-NET、快速R-CNN、更快R-CNN

 

端到端(一阶段)

 

YOLO、固态硬盘

 

TensorFlow对象检测API

 

TensorFlow对象检测API是一个构建在TensorFlow之上的开源框架,它使得构建、训练和部署对象检测模型变得很容易。此外,TensorFlow对象检测API还提供了Model Zoo,以方便我们选择和切换预先训练的模型。 TensorFlow Object Detection API TensorFlow Object Detection API Model Zoo

 

安装依赖项

孔达
协议

使用以下命令检查安装是否成功。

 

$ conda --version
conda 4.9.2
$ protoc --version
libprotoc 3.17.1

 

安装API

 

TensorFlow对象检测API提供的官方安装步骤比较繁琐。作者编写了一个脚本,只需一步即可直接安装。 TensorFlow Object Detection API

 

执行git克隆https://github.com/CatchZeng/object-detection-api.git下载repo,然后转到repo所在的目录(以下称为ODA repo),如果您看到以下输出,则执行以下命令,表明安装成功。

 

$  conda create -n  od python=3.8.5 && conda activate od && make install

 

$ pip install --upgrade tf-models-official==2.4.0
$ pip install --upgrade tensorflow==2.4.1

 

创建工作区

 

$ conda activate od
$ conda env list
# conda environments:
#
od                    *  /Users/catchzeng/.conda/envs/od
tensorflow               /Users/catchzeng/.conda/envs/tensorflow
base                     /Users/catchzeng/miniconda3

 

转到ODA repo目录并执行以下命令以创建工作区目录结构。

 

$ make workspace-box SAVE_DIR=workspace NAME=test

 

数据集

 

图像

 

我喜欢喝茶。今天我将以杯子、茶壶、加湿器为例。

将采集到的图片放入项目目录中图片的三个子目录中。

注解

 

收集图片后,您需要对训练和评估集中的图像进行注释。

 

我们选择LabelImg作为注释工具。 LabelImg

 

按照安装说明安装LabelImg,然后执行labelImg选择要注释的Train和Val文件夹。 installation

批注完成后,将生成图片对应的XML批注文件,如下图所示:

 

workspace/test/images
├── test
│   ├── 15.jpg
│   └── 16.jpg
├── train
│   ├── 1.jpg
│   ├── 1.xml
│   ├── 10.jpg
│   ├── 10.xml
│   ├── 2.jpg
│   ├── 2.xml
│   ├── 3.jpg
│   ├── 3.xml
│   ├── 4.jpg
│   ├── 4.xml
│   ├── 5.jpg
│   ├── 5.xml
│   ├── 6.jpg
│   ├── 6.xml
│   ├── 7.jpg
│   ├── 7.xml
│   ├── 8.jpg
│   ├── 8.xml
│   ├── 9.jpg
│   └── 9.xml
└── val
    ├── 11.jpg
    ├── 11.xml
    ├── 12.jpg
    ├── 12.xml
    ├── 13.jpg
    ├── 13.xml
    ├── 14.jpg
    └── 14.xml

 

LabelMap

 

在文件夹workspace/test/notation下创建label_map.pbtxt,内容是模型需要识别的对象。

 

item {
    id: 1
    name: 'cup'
}

 

创建TFRecord

 

TensorFlow Object Detection接口仅支持TFRecord格式,因此需要对数据集进行转换。 TensorFlow Object Detection API TFRecord

 

转到工作区目录(cd workspace/test),然后执行make gen-tford,它将在Annotation文件夹中生成TFRecord格式的数据集。 TFRecord

 

$ make gen-tfrecord
python ../../scripts/preprocessing/generate_tfrecord.py \
        -x images/train \
        -l annotations/label_map.pbtxt \
        -o annotations/train.record
Successfully created the TFRecord file: annotations/train.record
python ../../scripts/preprocessing/generate_tfrecord.py \
        -x images/val \
        -l annotations/label_map.pbtxt \
        -o annotations/val.record
Successfully created the TFRecord file: annotations/val.record

 

模范训练

 

下载预先训练好的模型

 

从Model Zoo中选择合适的模型,下载并解压缩,然后将其放入工作区/测试/预先训练的模型中。 Model Zoo

 

如果选择SSD MobileNet V2 FPNLite 320×320,可以执行以下命令自动下载和解压缩

 

$ make dl-model

 

目录结构如下:

 

└── test
    └── pre-trained-models
        └── ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8
            ├── checkpoint
            ├── pipeline.config
            └── saved_model

 

配置培训管道

 

在Models目录中创建相应的模型文件夹,例如:ssd_mobilenet_v2_fpnlite_320x320,复制pre-trained-models/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8/pipeline.config.

 

└── test
    ├── models
    │   └── ssd_mobilenet_v2_fpnlite_320x320
    │       └── pipeline.config
    └── pre-trained-models

 

其中,Pipeline.config需要根据项目进行修改,具体如下

 

model {
  ssd {
    num_classes: 3 # Modify to the number of objects that need to be identified.
    ......
}
train_config {
  batch_size: 8 # Here you need to adjust the size according to your own computer performance
  ......
  optimizer {
    momentum_optimizer {
      learning_rate {
        cosine_decay_learning_rate {
          learning_rate_base: 0.07999999821186066
          total_steps: 10000 # Modify to the total number of steps you want to train
          warmup_learning_rate: 0.026666000485420227
          warmup_steps: 1000
        }
      }
      momentum_optimizer_value: 0.8999999761581421
    }
    use_moving_average: false
  }
  fine_tune_checkpoint: "pre-trained-models/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8/checkpoint/ckpt-0" # Modify the path to the pre-trained model
  num_steps: 10000 # Modify to the total number of steps you want to train
  startup_delay_steps: 0.0
  replicas_to_aggregate: 8
  max_number_of_boxes: 100
  unpad_groundtruth_tensors: false
  fine_tune_checkpoint_type: "detection" # Here needs to be modified to detection, because we are doing object detection
  fine_tune_checkpoint_version: V2
}
train_input_reader {
  label_map_path: "annotations/label_map.pbtxt" # Modify to the annotations path
  tf_record_input_reader {
    input_path: "annotations/train.record" # Modify the path to the training set
  }
}
eval_config {
  metrics_set: "coco_detection_metrics"
  use_moving_averages: false
}
eval_input_reader {
  label_map_path: "annotations/label_map.pbtxt" # Modify to the annotations path
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "annotations/val.record" # Modify the path to the evaluation set
  }
}

 

培训模式

 

$ make train

 

模型导出和转换

保存的模型

$进行导出

TFLite模型

$make export-lite

转换TFLite模型

$make Convert-Lite

量化TFLite模型

$make Convert-Quant-Lite

 

测试

 

在执行make export以导出模型之后,将测试图像放在image/test文件夹中,然后执行python test_images.py将带注释的图像输出到image/test_Annotated。

摘要

 

本文通过案例介绍了物体检测的全过程,希望能帮助您快速掌握培养自定义物体检测器的能力。

 

案例的代码和数据集已放置在https://github.com/CatchZeng/object-detection-api.中 https://github.com/CatchZeng/object-detection-api

 

下面的文章将向您介绍对象检测的原理、流行的对象检测网络和图像分割。这篇文章就到这里,下次见。

 

参考文献

https://github.com/tensorflow/models/tree/master/research/object_detection
https://arxiv.org/pdf/1905.05055.pdf
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注