Press "Enter" to skip to content

使用pmml实现跨平台部署机器学习模型

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

一、概述

 

对于由Python训练的机器学习模型,通常有pickle和pmml两种部署方式,pickle方式用于在python环境中的部署,pmml方式用于跨平台(如Java环境)的部署,本文叙述的是pmml的跨平台部署方式。

 

PMML(Predictive Model Markup Language,预测模型标记语言)是一种基于XML描述来存储机器学习模型的标准语言。如,对在Python环境中由sklearn训练得到的模型,通过sklearn2pmml模块可将它完整地保存为一个pmml格式的文件,再在其他平台(如java)中加载该文件进行使用,从而实现模型的跨平台部署。

二、实现步骤

 

1.训练环境中安装生成pmml文件的工具。

 

如在Python环境中安装sklearn2pmml模块(pip install sklearn2pmml)。

 

2.训练模型。

 

3.将模型保存为pmml文件。

 

4.部署环境中导入依赖的工具包。

 

如在Java环境中导入pmml-evaluator、pmml-evaluator-extension(特殊情况下另加)、jaxb-core、jaxb-api、jaxb-impl等jar包。

 

5.开发应用,加载、使用模型。

 

注:对sklearn2pmml生成的pmml模型文件,在java中加载使用时,需将文件中的命名空间属性xmlns=”…/PMML-4_4″改为xmlns=”…/PMML-4_3″,以适应低版本的jar包对它的解析。

 

三、示例

 

在python中使用sklearn训练一个线性回归模型,并在java环境中部署使用。

 

工具:PyCharm-2017、Python-39、sklearn2pmml-0.76.1;IntelliJ IDEA-2018、jdk-14.0.2。

 

1.训练数据集 training_data.csv

2.训练、保存模型

 

import sklearn2pmml as pmml
from sklearn2pmml import PMMLPipeline
from sklearn import linear_model as lm
import os
import pandas as pd
def save_model(data, model_path):
    pipeline = PMMLPipeline([("regression", lm.LinearRegression())]) #定义模型,放入pipeline管道
    pipeline.fit(data[["x"]], data["y"]) #训练模型,由数据中第一行的名称确定自变量和因变量
    pmml.sklearn2pmml(pipeline, model_path, with_repr=True) #保存模型
if __name__ == "__main__":
    data = pd.read_csv("training_data.csv")
    model_path = model_path = os.path.dirname(os.path.abspath(__file__)) + "/my_example_model.pmml"
    save_model(data, model_path)
    print("模型保存完成。")

 

3.将pmml文件的xmlns属性修改为PMML-4_3

4.java程序中加载、使用模型

 

(1)创建maven项目,将pmml模型文件拷贝至项目根目录下。

 

(2)加入依赖包

 

<dependencies>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.4.15</version>
        </dependency>
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-core</artifactId>
            <version>2.2.11</version>
        </dependency>
        <dependency>
            <groupId>javax.xml</groupId>
            <artifactId>jaxb-api</artifactId>
            <version>2.1</version>
        </dependency>
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-impl</artifactId>
            <version>2.2.11</version>
        </dependency>
    </dependencies>

 

(3)java程序加载模型完成预测

 

public class MLPmmlDeploy {
    public static void main(String[] args) {
        String model_path = "./my_example_model.pmml"; //模型路径
        int x = 20; //测试的自变量值
        Evaluator model = loadModel(model_path); //加载模型
        Object r = predict(model, x); //预测
        Double result = Double.parseDouble(r.toString());
        System.out.println("预测的结果为:" + result);
    }
    private static Evaluator loadModel(String model_path){
        PMML pmml = new PMML(); //定义PMML对象
        InputStream inputStream; //定义输入流
        try {
            inputStream = new FileInputStream(model_path); //输入流接到磁盘上的模型文件
            pmml = PMMLUtil.unmarshal(inputStream); //将输入流解析为PMML对象
        }catch (Exception e){
            e.printStackTrace();
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); //实例化一个模型构造工厂
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml); //将PMML对象构造为Evaluator模型对象
        return evaluator;
    }
    private static Object predict(Evaluator evaluator, int x){
        Map<String, Integer> data = new HashMap<String, Integer>(); //定义测试数据Map,存入各元自变量
        data.put("x", x); //键"x"为自变量的名称,应与训练数据中的自变量名称一致
        List<InputField> inputFieldList = evaluator.getInputFields(); //得到模型各元自变量的属性列表
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFieldList) { //遍历各元自变量的属性列表
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue()); //取出该元变量的值
            FieldValue inputFieldValue = inputField.prepare(rawValue); //将值加入该元自变量属性中
            arguments.put(inputFieldName, inputFieldValue); //变量名和变量值的对加入LinkedHashMap
        }
        Map<FieldName, ?> results = evaluator.evaluate(arguments); //进行预测
        List<TargetField> targetFieldList = evaluator.getTargetFields(); //得到模型各元因变量的属性列表
        FieldName targetFieldName = targetFieldList.get(0).getName(); //第一元因变量名称
        Object targetFieldValue = results.get(targetFieldName); //由因变量名称得到值
        return targetFieldValue;
    }
}

 

 

示例下载:

https://download.csdn.net/download/Albert201605/45645889

End.

 

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注