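Overview
Deep Java Library (DJL) is a deep learning library for Java released by AWS in 2019. It supports training and inference with MXNet, PyTorch and TensorFlow models, although training support varies by engine. DJL is not tied to any single deep learning framework, so the same code can run on different engines.
Following the official tutorial, this post shows how to build an object-detection demo that reads a local image, loads a pretrained model from the official Model Zoo, runs inference, and saves the image with the detection results to disk. References: the official SSD inference tutorial, the Maven dependency configuration, DJL's Maven BOM configuration, and DJL's version/dependency compatibility notes.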
Project Setup
Create a Maven project
JDK 1.8 or later
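If you want the demo to fail fast on an older JVM, a minimal runtime check could look like the sketch below. It assumes only the standard java.specification.version system property, which is "1.8" on Java 8 and "9", "11", "17" and so on from Java 9 onward. The class name JavaVersionCheck is hypothetical. Note that this checks the JVM you run on; the compile-time language level is set separately in the pom.

```java
/**
 * Hypothetical helper: verify at runtime that the JVM is Java 8 or later.
 * Assumes java.specification.version is "1.x" up to Java 8 and a plain major number afterwards.
 */
public final class JavaVersionCheck {

    private JavaVersionCheck() {}

    /** Parses "1.8" -> 8, "11" -> 11, "17" -> 17. */
    public static int majorVersion(String specVersion) {
        String[] parts = specVersion.split("\\.");
        int first = Integer.parseInt(parts[0]);
        // Java 8 and earlier use the "1.x" scheme; Java 9+ report the major version directly
        if (first == 1 && parts.length > 1) {
            return Integer.parseInt(parts[1]);
        }
        return first;
    }

    public static void main(String[] args) {
        String spec = System.getProperty("java.specification.version");
        int major = majorVersion(spec);
        if (major < 8) {
            System.err.println("Java " + spec + " detected; this demo needs JDK 1.8 or later.");
            System.exit(1);
        }
        System.out.println("Java " + spec + " (major version " + major + ") is OK.");
    }
}
```

Parsing the property rather than comparing strings avoids treating "11" as smaller than "1.8".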
Project structure
Dependencies
Add DJL itself and the other packages the demo depends on.
```xml
<build>
    <plugins>
        <!-- Compile with JDK 8 (source/target level 8) -->
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-compiler-plugin</artifactId>
            <configuration>
                <source>8</source>
                <target>8</target>
            </configuration>
        </plugin>
    </plugins>
</build>

<!-- Manage the versions of all DJL artifacts through the BOM -->
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>bom</artifactId>
            <version>0.9.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

<dependencies>
    <dependency>
        <groupId>commons-cli</groupId>
        <artifactId>commons-cli</artifactId>
        <version>1.4</version>
    </dependency>
    <dependency>
        <groupId>com.google.code.gson</groupId>
        <artifactId>gson</artifactId>
        <version>2.8.5</version>
    </dependency>
    <!-- Logging -->
    <dependency>
        <groupId>org.apache.logging.log4j</groupId>
        <artifactId>log4j-slf4j-impl</artifactId>
        <version>2.12.1</version>
    </dependency>
    <!-- Required for any use of DJL -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
    </dependency>
    <!-- Engine-specific artifacts: Apache MXNet engine implementation.
         To switch frameworks, replace mxnet with pytorch in these three artifacts. -->
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-engine</artifactId>
    </dependency>
    <!-- Engine-specific artifacts: Apache MXNet native library -->
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-native-auto</artifactId>
        <scope>runtime</scope>
    </dependency>
    <!-- Engine-specific artifacts: a ModelZoo containing models exported from Apache MXNet -->
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-model-zoo</artifactId>
    </dependency>
</dependencies>
```
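Logging configuration
There are no fixed requirements here; configure logging to suit your needs. This post reuses a simple log4j2.xml found online, placed on the classpath (for example under src/main/resources).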
```xml
<?xml version="1.0" encoding="UTF-8"?>
<Configuration status="WARN">
    <Properties>
        <Property name="log_level" value="info" />
        <Property name="log_dir" value="log" />
        <Property name="log_pattern" value="[%d{yyyy-MM-dd HH:mm:ss.SSS}] [%p] - [%t] %logger - %m%n" />
        <Property name="file_name" value="test" />
        <Property name="every_file_size" value="100 MB" />
    </Properties>
    <Appenders>
        <Console name="Console" target="SYSTEM_OUT">
            <PatternLayout pattern="${log_pattern}" />
        </Console>
        <!-- Main log file, rolled by size and by day -->
        <RollingFile name="RollingFile" fileName="${log_dir}/${file_name}.log"
                     filePattern="${log_dir}/$${date:yyyy-MM}/${file_name}-%d{yyyy-MM-dd}-%i.log">
            <ThresholdFilter level="DEBUG" onMatch="ACCEPT" onMismatch="DENY" />
            <PatternLayout pattern="${log_pattern}" />
            <Policies>
                <SizeBasedTriggeringPolicy size="${every_file_size}" />
                <TimeBasedTriggeringPolicy modulate="true" interval="1" />
            </Policies>
            <DefaultRolloverStrategy max="20" />
        </RollingFile>
        <!-- WARN and ERROR messages only -->
        <RollingFile name="RollingFileErr" fileName="${log_dir}/${file_name}-warnerr.log"
                     filePattern="${log_dir}/$${date:yyyy-MM}/${file_name}-%d{yyyy-MM-dd}-warnerr-%i.log">
            <ThresholdFilter level="WARN" onMatch="ACCEPT" onMismatch="DENY" />
            <PatternLayout pattern="${log_pattern}" />
            <Policies>
                <SizeBasedTriggeringPolicy size="${every_file_size}" />
                <TimeBasedTriggeringPolicy modulate="true" interval="1" />
            </Policies>
        </RollingFile>
    </Appenders>
    <Loggers>
        <Root level="${log_level}">
            <AppenderRef ref="Console" />
            <AppenderRef ref="RollingFile" />
            <AppenderRef ref="RollingFileErr" />
        </Root>
    </Loggers>
</Configuration>
```
Application code
Adapted from: https://github.com/awslabs/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java
```java
package org.town;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An example of inference using an object detection model.
 */
public final class ObjectDetection {

    private static final Logger logger = LoggerFactory.getLogger(ObjectDetection.class);

    private ObjectDetection() {}

    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        DetectedObjects detection = ObjectDetection.predict();
        logger.info("{}", detection);
    }

    public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
        // Local input image, resolved relative to the working directory
        Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
        Image img = ImageFactory.getInstance().fromFile(imageFile);

        // Choose a backbone that the active engine's model zoo provides
        String backbone;
        if ("TensorFlow".equals(Engine.getInstance().getEngineName())) {
            backbone = "mobilenet_v2";
        } else {
            backbone = "resnet50";
        }

        // Describe the model to load: object detection, Image in, DetectedObjects out
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilter("backbone", backbone)
                        .optProgress(new ProgressBar())
                        .build();

        // Load the pretrained model (downloaded on first use) and run inference
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
                Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects detection = predictor.predict(img);
            saveBoundingBoxImage(img, detection);
            return detection;
        }
    }

    private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
            throws IOException {
        Path outputDir = Paths.get("build/output");
        Files.createDirectories(outputDir);

        // Make image copy with alpha channel because original image was jpg
        Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
        newImage.drawBoundingBoxes(detection);

        Path imagePath = outputDir.resolve("detected-dog_bike_car.png");
        // OpenJDK can't save jpg with alpha channel, so write png and close the stream afterwards
        try (OutputStream os = Files.newOutputStream(imagePath)) {
            newImage.save(os, "png");
        }
        logger.info("Detected objects image has been saved in: {}", imagePath);
    }
}
```
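The comments in saveBoundingBoxImage explain why the JPEG is first copied into an ARGB image and then written as PNG. The sketch below shows the same idea with plain JDK classes (BufferedImage, Graphics2D, ImageIO) as a rough analogue only; it is not how DJL implements drawBoundingBoxes or save. The class BoxToPng and the box coordinates are hypothetical, and the box is given in pixels.

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import javax.imageio.ImageIO;

/**
 * Hypothetical plain-JDK analogue of duplicate(TYPE_INT_ARGB) + drawBoundingBoxes + save(..., "png").
 * Not DJL's implementation; box coordinates are illustrative pixel values.
 */
public final class BoxToPng {

    private BoxToPng() {}

    /**
     * Copies src into a new ARGB image, outlines the rectangle (x, y, w, h) in red,
     * and writes the result as PNG. Returns false if no PNG writer was found.
     */
    public static boolean drawBoxAndSavePng(
            BufferedImage src, int x, int y, int w, int h, OutputStream out) throws IOException {
        BufferedImage argb =
                new BufferedImage(src.getWidth(), src.getHeight(), BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = argb.createGraphics();
        try {
            g.drawImage(src, 0, 0, null); // copy the original pixels
            g.setColor(Color.RED);
            g.drawRect(x, y, w, h); // 1-pixel outline from (x, y) to (x + w, y + h)
        } finally {
            g.dispose();
        }
        // PNG keeps the alpha channel, so it is a safe target format for an ARGB image
        return ImageIO.write(argb, "png", out);
    }

    public static void main(String[] args) throws IOException {
        // An opaque RGB image stands in for the decoded JPEG
        BufferedImage rgb = new BufferedImage(64, 48, BufferedImage.TYPE_INT_RGB);
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        boolean written = drawBoxAndSavePng(rgb, 10, 8, 30, 20, buf);
        System.out.println("written=" + written + ", bytes=" + buf.size());
    }
}
```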
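Output
Troubleshooting notes
Issue 1: the compile language level was too low (Java 7), causing "Static interface method calls are not supported at language level '7'".
Fix: explicitly set the compiler source and target level to 8 in the pom (the maven-compiler-plugin configuration).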
Issue 2: Exception in thread "main" ai.djl.repository.zoo.ModelNotFoundException: No matching model with specified Input/Output type found.
Fix: the pom did not include the ai.djl.mxnet:mxnet-model-zoo dependency; adding it resolves the error.
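Because Issue 2 comes down to a missing artifact, a quick classpath check can help. The sketch below assumes that DJL discovers engines and model zoos through java.util.ServiceLoader, registered under the interfaces ai.djl.engine.EngineProvider and ai.djl.repository.zoo.ZooProvider; treat those interface names as an assumption for your DJL version. It lists the registered implementations, so an empty model-zoo list points to a missing *-model-zoo dependency. The class ProviderCheck is hypothetical, and the interfaces are looked up by name so it compiles without DJL.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.ServiceLoader;

/**
 * Hypothetical diagnostic: list the ServiceLoader providers registered for an SPI interface.
 * Assumed DJL interfaces: ai.djl.engine.EngineProvider and ai.djl.repository.zoo.ZooProvider.
 */
public final class ProviderCheck {

    private ProviderCheck() {}

    /**
     * Returns the implementation class names registered for the given interface,
     * or an empty list if the interface itself is not on the classpath.
     */
    public static List<String> providers(String spiInterfaceName) {
        List<String> names = new ArrayList<>();
        Class<?> spi;
        try {
            spi = Class.forName(spiInterfaceName);
        } catch (ClassNotFoundException e) {
            return names; // the jar that defines the interface is missing
        }
        // Iterating instantiates each provider listed under META-INF/services
        ServiceLoader<?> loader = ServiceLoader.load(spi);
        for (Object provider : loader) {
            names.add(provider.getClass().getName());
        }
        return names;
    }

    public static void main(String[] args) {
        // Engine providers come from the *-engine artifacts
        System.out.println("Engines: " + providers("ai.djl.engine.EngineProvider"));
        // Model zoo providers come from the *-model-zoo artifacts
        System.out.println("Model zoos: " + providers("ai.djl.repository.zoo.ZooProvider"));
    }
}
```

Run it with the same classpath as the demo; if the engine list is non-empty but the model-zoo list is empty, the zoo dependency is the likely gap.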
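Issue 3: errors after switching to the PyTorch engine.
How the engine was switched (the three MXNet dependencies were replaced with the following):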
```xml
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
</dependency>
```
Error message:
```
[2021-01-01 23:49:12.420] [WARN] - [main] ai.djl.engine.Engine - Failed to load engine from: ai.djl.pytorch.engine.PtEngineProvider
ai.djl.engine.EngineException: Failed to load PyTorch native library
    at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:56) ~[pytorch-engine-0.9.0.jar:?]
    at ai.djl.pytorch.engine.PtEngineProvider.getEngine(PtEngineProvider.java:27) ~[pytorch-engine-0.9.0.jar:?]
    at ai.djl.engine.Engine.initEngine(Engine.java:59) [api-0.9.0.jar:?]
    at ai.djl.engine.Engine.<clinit>(Engine.java:49) [api-0.9.0.jar:?]
    at org.town.ObjectDetection.predict(ObjectDetection.java:45) [classes/:?]
    at org.town.ObjectDetection.main(ObjectDetection.java:36) [classes/:?]
Caused by: java.lang.UnsatisfiedLinkError: C:\Users\steel\.djl.ai\pytorch\1.7.0-cpu-win-x86_64\asmjit.dll: Can't find dependent libraries
    at java.lang.ClassLoader$NativeLibrary.load(Native Method) ~[?:1.8.0_251]
    at java.lang.ClassLoader.loadLibrary0(ClassLoader.java:1934) ~[?:1.8.0_251]
    at java.lang.ClassLoader.loadLibrary(ClassLoader.java:1817) ~[?:1.8.0_251]
    at java.lang.Runtime.load0(Runtime.java:809) ~[?:1.8.0_251]
    at java.lang.System.load(System.java:1086) ~[?:1.8.0_251]
    at java.util.stream.ForEachOps$ForEachOp$OfRef.accept(ForEachOps.java:184) ~[?:1.8.0_251]
    at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193) ~[?:1.8.0_251]
    at java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:175) ~[?:1.8.0_251]
    at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193) ~[?:1.8.0_251]
    at java.util.Iterator.forEachRemaining(Iterator.java:116) ~[?:1.8.0_251]
    at java.util.Spliterators$IteratorSpliterator.forEachRemaining(Spliterators.java:1801) ~[?:1.8.0_251]
    at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:482) ~[?:1.8.0_251]
    at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:472) ~[?:1.8.0_251]
    at java.util.stream.ForEachOps$ForEachOp.evaluateSequential(ForEachOps.java:151) ~[?:1.8.0_251]
    at java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateSequential(ForEachOps.java:174) ~[?:1.8.0_251]
    at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234) ~[?:1.8.0_251]
    at java.util.stream.ReferencePipeline.forEach(ReferencePipeline.java:418) ~[?:1.8.0_251]
    at ai.djl.pytorch.jni.LibUtils.loadWinDependencies(LibUtils.java:119) ~[pytorch-engine-0.9.0.jar:?]
    at ai.djl.pytorch.jni.LibUtils.loadLibrary(LibUtils.java:75) ~[pytorch-engine-0.9.0.jar:?]
    at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:44) ~[pytorch-engine-0.9.0.jar:?]
    ... 5 more
Exception in thread "main" ai.djl.engine.EngineException: No deep learning engine found. Please refer to https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md for more details.
    at ai.djl.engine.Engine.getInstance(Engine.java:119)
    at org.town.ObjectDetection.predict(ObjectDetection.java:45)
    at org.town.ObjectDetection.main(ObjectDetection.java:36)
Caused by: ai.djl.engine.EngineException: Failed to load PyTorch native library
    at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:56)
    at ai.djl.pytorch.engine.PtEngineProvider.getEngine(PtEngineProvider.java:27)
    at ai.djl.engine.Engine.initEngine(Engine.java:59)
    at ai.djl.engine.Engine.<clinit>(Engine.java:49)
    ... 2 more
Caused by: java.lang.UnsatisfiedLinkError: C:\Users\steel\.djl.ai\pytorch\1.7.0-cpu-win-x86_64\asmjit.dll: Can't find dependent libraries
    at java.lang.ClassLoader$NativeLibrary.load(Native Method)
    at java.lang.ClassLoader.loadLibrary0(ClassLoader.java:1934)
    at java.lang.ClassLoader.loadLibrary(ClassLoader.java:1817)
    at java.lang.Runtime.load0(Runtime.java:809)
    at java.lang.System.load(System.java:1086)
    at java.util.stream.ForEachOps$ForEachOp$OfRef.accept(ForEachOps.java:184)
    at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193)
    at java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:175)
    at java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:193)
    at java.util.Iterator.forEachRemaining(Iterator.java:116)
    at java.util.Spliterators$IteratorSpliterator.forEachRemaining(Spliterators.java:1801)
    at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:482)
    at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:472)
    at java.util.stream.ForEachOps$ForEachOp.evaluateSequential(ForEachOps.java:151)
    at java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateSequential(ForEachOps.java:174)
    at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
    at java.util.stream.ReferencePipeline.forEach(ReferencePipeline.java:418)
    at ai.djl.pytorch.jni.LibUtils.loadWinDependencies(LibUtils.java:119)
    at ai.djl.pytorch.jni.LibUtils.loadLibrary(LibUtils.java:75)
    at ai.djl.pytorch.engine.PtEngine.newInstance(PtEngine.java:44)
    ... 5 more
```
Solution: not found yet.