
DJL Object Detection Demo

Overview

 

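Deep Java Library (DJL) is a Java deep learning library released by AWS in 2019. It currently supports training and inference for MXNet, PyTorch, and TensorFlow models. DJL is not tied to any particular deep learning framework, so the same code can run on different frameworks.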

 

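Following the official tutorial, this post shows how to build an object detection demo that reads a local image, loads a pretrained model from the official Model Zoo, runs inference, and writes the image annotated with the detection results to local disk. References: the official SSD inference tutorial, Maven dependency configuration, DJL's Maven BOM configuration, and DJL version/dependency compatibility.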

 

Project Setup

 

Create a Maven project

 

JDK>=1.8
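Because the demo requires JDK 1.8 or later, it can help to confirm which Java version actually runs the project. A minimal stdlib sketch (the class JavaVersionCheck is hypothetical, not part of DJL) that reads the java.specification.version system property:

```java
public final class JavaVersionCheck {

    private JavaVersionCheck() {}

    /** Returns the major Java version, handling both "1.8"-style and "11"-style specification versions. */
    static int majorVersion(String specVersion) {
        // Java 8 and earlier report "1.x"; Java 9 and later report only the major number
        if (specVersion.startsWith("1.")) {
            return Integer.parseInt(specVersion.substring(2));
        }
        return Integer.parseInt(specVersion);
    }

    public static void main(String[] args) {
        String spec = System.getProperty("java.specification.version");
        int major = majorVersion(spec);
        if (major >= 8) {
            System.out.println("Java " + spec + " meets the JDK >= 1.8 requirement");
        } else {
            System.out.println("Java " + spec + " is too old; JDK 1.8 or later is required");
        }
    }
}
```

Java 8 reports its specification version as 1.8, while Java 9 and later report a bare number such as 11, which is why both forms are handled.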

 

Project structure

 

 

Dependencies

 

Add DJL itself and the other packages it depends on to pom.xml:

 

    <build>
        <plugins>
            <!-- Have Maven compile with JDK 8 (language level 8) -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
    <!-- Manage dependency versions in one place by importing DJL's BOM -->
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.9.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>commons-cli</groupId>
            <artifactId>commons-cli</artifactId>
            <version>1.4</version>
        </dependency>
        <dependency>
            <groupId>com.google.code.gson</groupId>
            <artifactId>gson</artifactId>
            <version>2.8.5</version>
        </dependency>
        <!-- Logging dependency (Log4j 2 binding for SLF4J) -->
        <dependency>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-slf4j-impl</artifactId>
            <version>2.12.1</version>
        </dependency>
        <!-- Core DJL API, required by every DJL project -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <!-- Engine-specific dependencies: each deep learning framework needs its own packages.
             To switch frameworks, change mxnet to pytorch.
             Apache MXNet engine implementation -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
        </dependency>
        <!-- Engine-specific: Apache MXNet native library -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <scope>runtime</scope>
        </dependency>
        <!-- Engine-specific: a ModelZoo containing models exported from Apache MXNet -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
        </dependency>
    </dependencies>

 

Logging configuration file

 

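There are no fixed requirements for this part; configure it as needed. This demo reuses a simple configuration file found online, log4j2.xml, placed on the classpath (for example under src/main/resources):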

 

<?xml version="1.0" encoding="UTF-8"?>
<Configuration status="WARN">
    <Properties>
        <Property name="log_level" value="info" />
        <Property name="log_dir" value="log" />
        <Property name="log_pattern"
                  value="[%d{yyyy-MM-dd HH:mm:ss.SSS}] [%p] - [%t] %logger - %m%n" />
        <Property name="file_name" value="test" />
        <Property name="every_file_size" value="100 MB" />
    </Properties>
    <Appenders>
        <Console name="Console" target="SYSTEM_OUT">
            <PatternLayout pattern="${log_pattern}" />
        </Console>
        <RollingFile name="RollingFile"
                     fileName="${log_dir}/${file_name}.log"
                     filePattern="${log_dir}/$${date:yyyy-MM}/${file_name}-%d{yyyy-MM-dd}-%i.log">
            <ThresholdFilter level="DEBUG" onMatch="ACCEPT"
                             onMismatch="DENY" />
            <PatternLayout pattern="${log_pattern}" />
            <Policies>
                <SizeBasedTriggeringPolicy
                        size="${every_file_size}" />
                <TimeBasedTriggeringPolicy modulate="true"
                                           interval="1" />
            </Policies>
            <DefaultRolloverStrategy max="20" />
        </RollingFile>
        <RollingFile name="RollingFileErr"
                     fileName="${log_dir}/${file_name}-warnerr.log"
                     filePattern="${log_dir}/$${date:yyyy-MM}/${file_name}-%d{yyyy-MM-dd}-warnerr-%i.log">
            <ThresholdFilter level="WARN" onMatch="ACCEPT"
                             onMismatch="DENY" />
            <PatternLayout pattern="${log_pattern}" />
            <Policies>
                <SizeBasedTriggeringPolicy
                        size="${every_file_size}" />
                <TimeBasedTriggeringPolicy modulate="true"
                                           interval="1" />
            </Policies>
        </RollingFile>
    </Appenders>
    <Loggers>
        <Root level="${log_level}">
            <AppenderRef ref="Console" />
            <AppenderRef ref="RollingFile" />
            <AppenderRef ref="RollingFileErr" />
        </Root>
    </Loggers>
</Configuration>
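The RollingFile appenders combine a size policy with a daily time policy. As a rough illustration of how the filePattern expands, here is a sketch that only mimics Log4j's substitution with java.time (RolloverNameSketch and rolledFileName are hypothetical names, not Log4j API). In the pattern, $${date:yyyy-MM} is escaped so the monthly folder is resolved at rollover time rather than once at startup, %d{yyyy-MM-dd} is the date of the rolled file, and %i is the counter that increases each time the 100 MB limit is reached within the same day (capped by DefaultRolloverStrategy max="20"):

```java
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;

public final class RolloverNameSketch {

    private static final DateTimeFormatter MONTH = DateTimeFormatter.ofPattern("yyyy-MM");
    private static final DateTimeFormatter DAY = DateTimeFormatter.ofPattern("yyyy-MM-dd");

    private RolloverNameSketch() {}

    /**
     * Mimics ${log_dir}/$${date:yyyy-MM}/${file_name}-%d{yyyy-MM-dd}-%i.log
     * for a given rollover date and index (an illustration, not Log4j's own code).
     */
    static String rolledFileName(String logDir, String fileName, LocalDate date, int index) {
        return logDir + "/" + date.format(MONTH) + "/"
                + fileName + "-" + date.format(DAY) + "-" + index + ".log";
    }

    public static void main(String[] args) {
        // With log_dir=log and file_name=test, the first file rolled on 2021-01-01
        System.out.println(rolledFileName("log", "test", LocalDate.of(2021, 1, 1), 1));
    }
}
```

Running main prints log/2021-01/test-2021-01-01-1.log, which matches the folder and file names the RollingFile appender is expected to produce under the log directory.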

 

Application code

 

Source: https://github.com/awslabs/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java

 

package org.town;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An example of inference using an object detection model.
 */
public final class ObjectDetection {

    private static final Logger logger = LoggerFactory.getLogger(ObjectDetection.class);

    private ObjectDetection() {}

    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        DetectedObjects detection = ObjectDetection.predict();
        logger.info("{}", detection);
    }

    public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
        // Input image, resolved relative to the working directory (usually the project root)
        Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
        Image img = ImageFactory.getInstance().fromFile(imageFile);

        // Pick the model backbone according to the engine DJL found on the classpath
        String backbone;
        if ("TensorFlow".equals(Engine.getInstance().getEngineName())) {
            backbone = "mobilenet_v2";
        } else {
            backbone = "resnet50";
        }

        // Describe the wanted model; DJL searches the model zoos on the classpath for a match
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilter("backbone", backbone)
                        .optProgress(new ProgressBar())
                        .build();

        // Model and predictor hold native resources, so close both with try-with-resources
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
            try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
                DetectedObjects detection = predictor.predict(img);
                saveBoundingBoxImage(img, detection);
                return detection;
            }
        }
    }

    private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
            throws IOException {
        Path outputDir = Paths.get("build/output");
        Files.createDirectories(outputDir);

        // Make image copy with alpha channel because original image was jpg
        Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
        newImage.drawBoundingBoxes(detection);

        Path imagePath = outputDir.resolve("detected-dog_bike_car.png");
        // OpenJDK can't save jpg with alpha channel, so save as png.
        // Close the stream so the file is flushed and the handle released.
        try (OutputStream os = Files.newOutputStream(imagePath)) {
            newImage.save(os, "png");
        }
        logger.info("Detected objects image has been saved in: {}", imagePath);
    }
}
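drawBoundingBoxes and save hide the drawing and encoding details. As a rough stdlib-only illustration of that step (not DJL's implementation; BoxDrawingSketch, toPixels and drawAndEncode are hypothetical names), the sketch below assumes the detector reports each box as fractions of the image width and height, converts one box to pixels, draws it on an ARGB copy with java.awt, and encodes the result as PNG, which keeps the alpha channel:

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import javax.imageio.ImageIO;

public final class BoxDrawingSketch {

    private BoxDrawingSketch() {}

    /** Converts a box given as fractions of the image size (an assumption) into pixels {x, y, w, h}. */
    static int[] toPixels(double x, double y, double w, double h, int imgW, int imgH) {
        return new int[] {
            (int) Math.round(x * imgW), (int) Math.round(y * imgH),
            (int) Math.round(w * imgW), (int) Math.round(h * imgH)
        };
    }

    /** Copies src into an ARGB buffer, draws one red rectangle outline and encodes it as PNG. */
    static byte[] drawAndEncode(BufferedImage src, int[] box) throws IOException {
        BufferedImage argb =
                new BufferedImage(src.getWidth(), src.getHeight(), BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = argb.createGraphics();
        try {
            g.drawImage(src, 0, 0, null);
            g.setColor(Color.RED);
            g.drawRect(box[0], box[1], box[2], box[3]);
        } finally {
            g.dispose();
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        // ImageIO.write returns false when no writer supports the format
        if (!ImageIO.write(argb, "png", out)) {
            throw new IOException("No PNG writer available");
        }
        return out.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        BufferedImage rgb = new BufferedImage(200, 100, BufferedImage.TYPE_INT_RGB);
        int[] box = toPixels(0.25, 0.2, 0.5, 0.6, rgb.getWidth(), rgb.getHeight());
        byte[] png = drawAndEncode(rgb, box);
        System.out.println("box=" + java.util.Arrays.toString(box) + ", png bytes=" + png.length);
    }
}
```

Writing to a byte array keeps the sketch free of file-system side effects; swapping the ByteArrayOutputStream for Files.newOutputStream inside a try-with-resources block mirrors what saveBoundingBoxImage does on disk.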

 

Output

 

 

Issues encountered

 

Issue 1: The project is compiled with a JDK language level that is too low, causing: Static interface method calls are not supported at language level '7'

 

 

Solution: explicitly configure the pom to compile with JDK 8 (the source and target settings of maven-compiler-plugin).
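For reference, a static interface method call is a call such as Comparator.reverseOrder(), where the static method is declared on an interface. This form only exists from Java 8 on, so it is rejected with the error above when the language level is 7. A minimal stdlib example (the class StaticInterfaceCallDemo is hypothetical):

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public final class StaticInterfaceCallDemo {

    private StaticInterfaceCallDemo() {}

    /** Returns a sorted copy of the input in reverse natural order. */
    static List<String> sortedDescending(List<String> input) {
        String[] copy = input.toArray(new String[0]);
        // Comparator.reverseOrder() is a static method declared on the Comparator interface (Java 8+);
        // compiling this line at language level 7 fails
        Arrays.sort(copy, Comparator.<String>reverseOrder());
        return Arrays.asList(copy);
    }

    public static void main(String[] args) {
        System.out.println(sortedDescending(Arrays.asList("bike", "car", "dog")));
    }
}
```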

 

Issue 2: Exception in thread "main" ai.djl.repository.zoo.ModelNotFoundException: No matching model with specified Input/Output type found.

 

 

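Solution: the pom did not include the ai.djl.mxnet:mxnet-model-zoo dependency, so no model zoo on the classpath could supply a matching model. Adding the dependency fixes it.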
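Errors like this usually come down to which DJL jars are actually on the runtime classpath. A small stdlib diagnostic can check that without triggering any native library loading, because Class.forName is called with initialize=false. The class names MxEngineProvider, MxModelZoo and PtModelZoo are assumptions based on DJL's package layout; only ai.djl.pytorch.engine.PtEngineProvider appears verbatim in the Issue 3 log. ClasspathCheck itself is a hypothetical helper:

```java
import java.util.LinkedHashMap;
import java.util.Map;

public final class ClasspathCheck {

    private ClasspathCheck() {}

    /** Returns true if the named class can be found, without running its static initializers. */
    static boolean isPresent(String className) {
        try {
            Class.forName(className, false, ClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            return false;
        }
    }

    public static void main(String[] args) {
        // Assumed DJL class names; adjust them to the DJL version in use
        Map<String, String> checks = new LinkedHashMap<>();
        checks.put("MXNet engine", "ai.djl.mxnet.engine.MxEngineProvider");
        checks.put("MXNet model zoo", "ai.djl.mxnet.zoo.MxModelZoo");
        checks.put("PyTorch engine", "ai.djl.pytorch.engine.PtEngineProvider");
        checks.put("PyTorch model zoo", "ai.djl.pytorch.zoo.PtModelZoo");
        for (Map.Entry<String, String> e : checks.entrySet()) {
            System.out.println(e.getKey() + ": " + (isPresent(e.getValue()) ? "found" : "missing"));
        }
    }
}
```

Running it from the same module (so it sees the same classpath as ObjectDetection) shows at a glance whether an engine or its model zoo is missing.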

 

Issue 3: Errors after switching to the PyTorch engine.

 

How the engine was switched (the PyTorch counterparts of the three MXNet dependencies):

 

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
</dependency>

 

Error message:

 

[2021-01-01 23:49:12.420] [WARN] - [main] ai.djl.engine.Engine - Failed to load engine from: ai.djl.pytorch.engine.PtEngineProvider
ai.djl.engine.EngineException: Failed to load PyTorch native library
at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:56) ~[pytorch-engine-0.9.0.jar:?]
at ai.djl.pytorch.engine.PtEngineProvider.getEngine(PtEngineProvider.java:27) ~[pytorch-engine-0.9.0.jar:?]
at ai.djl.engine.Engine.initEngine(Engine.java:59) [api-0.9.0.jar:?]
at ai.djl.engine.Engine.<clinit>(Engine.java:49) [api-0.9.0.jar:?]
at org.town.ObjectDetection.predict(ObjectDetection.java:45) [classes/:?]
at org.town.ObjectDetection.main(ObjectDetection.java:36) [classes/:?]
Caused by: java.lang.UnsatisfiedLinkError: C:\Users\steel\.djl.ai\pytorch\1.7.0-cpu-win-x86_64\asmjit.dll: Can't find dependent libraries
at java.lang.ClassLoader$NativeLibrary.load(Native Method) ~[?:1.8.0_251]
at java.lang.ClassLoader.loadLibrary0(ClassLoader.java:1934) ~[?:1.8.0_251]
at java.lang.ClassLoader.loadLibrary(ClassLoader.java:1817) ~[?:1.8.0_251]
at java.lang.Runtime.load0(Runtime.java:809) ~[?:1.8.0_251]
at java.lang.System.load(System.java:1086) ~[?:1.8.0_251]
at java.util.stream.ForEachOps$ForEachOp$OfRef.accept(ForEachOps.java:184) ~[?:1.8.0_251]
at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193) ~[?:1.8.0_251]
at java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:175) ~[?:1.8.0_251]
at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193) ~[?:1.8.0_251]
at java.util.Iterator.forEachRemaining(Iterator.java:116) ~[?:1.8.0_251]
at java.util.Spliterators$IteratorSpliterator.forEachRemaining(Spliterators.java:1801) ~[?:1.8.0_251]
at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:482) ~[?:1.8.0_251]
at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:472) ~[?:1.8.0_251]
at java.util.stream.ForEachOps$ForEachOp.evaluateSequential(ForEachOps.java:151) ~[?:1.8.0_251]
at java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateSequential(ForEachOps.java:174) ~[?:1.8.0_251]
at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234) ~[?:1.8.0_251]
at java.util.stream.ReferencePipeline.forEach(ReferencePipeline.java:418) ~[?:1.8.0_251]
at ai.djl.pytorch.jni.LibUtils.loadWinDependencies(LibUtils.java:119) ~[pytorch-engine-0.9.0.jar:?]
at ai.djl.pytorch.jni.LibUtils.loadLibrary(LibUtils.java:75) ~[pytorch-engine-0.9.0.jar:?]
at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:44) ~[pytorch-engine-0.9.0.jar:?]
... 5 more
Exception in thread "main" ai.djl.engine.EngineException: No deep learning engine found.
Please refer to https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md for more details.
at ai.djl.engine.Engine.getInstance(Engine.java:119)
at org.town.ObjectDetection.predict(ObjectDetection.java:45)
at org.town.ObjectDetection.main(ObjectDetection.java:36)
Caused by: ai.djl.engine.EngineException: Failed to load PyTorch native library
at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:56)
at ai.djl.pytorch.engine.PtEngineProvider.getEngine(PtEngineProvider.java:27)
at ai.djl.engine.Engine.initEngine(Engine.java:59)
at ai.djl.engine.Engine.<clinit>(Engine.java:49)
... 2 more
Caused by: java.lang.UnsatisfiedLinkError: C:\Users\steel\.djl.ai\pytorch\1.7.0-cpu-win-x86_64\asmjit.dll: Can't find dependent libraries
at java.lang.ClassLoader$NativeLibrary.load(Native Method)
at java.lang.ClassLoader.loadLibrary0(ClassLoader.java:1934)
at java.lang.ClassLoader.loadLibrary(ClassLoader.java:1817)
at java.lang.Runtime.load0(Runtime.java:809)
at java.lang.System.load(System.java:1086)
at java.util.stream.ForEachOps$ForEachOp$OfRef.accept(ForEachOps.java:184)
at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193)
at java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:175)
at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193)
at java.util.Iterator.forEachRemaining(Iterator.java:116)
at java.util.Spliterators$IteratorSpliterator.forEachRemaining(Spliterators.java:1801)
at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:482)
at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:472)
at java.util.stream.ForEachOps$ForEachOp.evaluateSequential(ForEachOps.java:151)
at java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateSequential(ForEachOps.java:174)
at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
at java.util.stream.ReferencePipeline.forEach(ReferencePipeline.java:418)
at ai.djl.pytorch.jni.LibUtils.loadWinDependencies(LibUtils.java:119)
at ai.djl.pytorch.jni.LibUtils.loadLibrary(LibUtils.java:75)
at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:44)
... 5 more

 

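Solution: not found. (The message "Can't find dependent libraries" means Windows could not locate a DLL that asmjit.dll itself depends on; asmjit.dll itself was found.)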
