使用pmml實現跨平臺部署機器學習模型

歸去_來兮發表於2021-11-20

一、概述

  對於由Python訓練的機器學習模型,通常有pickle和pmml兩種部署方式,pickle方式用於在python環境中的部署,pmml方式用於跨平臺(如Java環境)的部署,本文敘述的是pmml的跨平臺部署方式。

  PMML(Predictive Model Markup Language,預測模型標記語言)是一種基於XML描述來儲存機器學習模型的標準語言。如,對在Python環境中由sklearn訓練得到的模型,通過sklearn2pmml模組可將它完整地儲存為一個pmml格式的檔案,再在其他平臺(如java)中載入該檔案進行使用,從而實現模型的跨平臺部署。


二、實現步驟

 1.訓練環境中安裝生成pmml檔案的工具。
  如在Python環境中安裝sklearn2pmml模組(pip install sklearn2pmml)。
 2.訓練模型。
 3.將模型儲存為pmml檔案。
 4.部署環境中匯入依賴的工具包。
  如在Java環境中匯入pmml-evaluator、pmml-evaluator-extension(特殊情況下另加)、jaxb-core、jaxb-api、jaxb-impl等jar包。
 5.開發應用,載入、使用模型。

:對sklearn2pmml生成的pmml模型檔案,在java中載入使用時,需將檔案中的名稱空間屬性xmlns=".../PMML-4_4"改為xmlns=".../PMML-4_3",以適應低版本的jar包對它的解析。


三、示例

  在python中使用sklearn訓練一個線性迴歸模型,並在java環境中部署使用。

工具:PyCharm-2017、Python-39、sklearn2pmml-0.76.1;IntelliJ IDEA-2018、jdk-14.0.2。

1.訓練資料集training_data.csv

x y
150 6450
200 7450
250 8450
300 9450
350 11450
400 15450
600 18450

2.訓練、儲存模型

import sklearn2pmml as pmml
from sklearn2pmml import PMMLPipeline
from sklearn import linear_model as lm
import os
import pandas as pd

def save_model(data, model_path):
    pipeline = PMMLPipeline([("regression", lm.LinearRegression())]) #定義模型,放入pipeline管道
    pipeline.fit(data[["x"]], data["y"]) #訓練模型,由資料中第一行的名稱確定自變數和因變數
    pmml.sklearn2pmml(pipeline, model_path, with_repr=True) #儲存模型

if __name__ == "__main__":
    data = pd.read_csv("training_data.csv")
    model_path = model_path = os.path.dirname(os.path.abspath(__file__)) + "/my_example_model.pmml"
    save_model(data, model_path)
    print("模型儲存完成。")

3.將pmml檔案的xmlns屬性修改為PMML-4_3


4.java程式中載入、使用模型
(1)建立maven專案,將pmml模型檔案拷貝至專案根目錄下。
(2)加入依賴包

<dependencies>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.4.15</version>
        </dependency>
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-core</artifactId>
            <version>2.2.11</version>
        </dependency>
        <dependency>
            <groupId>javax.xml</groupId>
            <artifactId>jaxb-api</artifactId>
            <version>2.1</version>
        </dependency>
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-impl</artifactId>
            <version>2.2.11</version>
        </dependency>
    </dependencies>

(3)java程式載入模型完成預測

public class MLPmmlDeploy {
    public static void main(String[] args) {

        String model_path = "./my_example_model.pmml"; //模型路徑
        int x = 700; //測試的自變數值

        Evaluator model = loadModel(model_path); //載入模型
        Object r = predict(model, x); //預測

        Double result = Double.parseDouble(r.toString());
        System.out.println("預測的結果為:" + result);
    }

    private static Evaluator loadModel(String model_path){
        PMML pmml = new PMML(); //定義PMML物件
        InputStream inputStream; //定義輸入流
        try {
            inputStream = new FileInputStream(model_path); //輸入流接到磁碟上的模型檔案
            pmml = PMMLUtil.unmarshal(inputStream); //將輸入流解析為PMML物件
        }catch (Exception e){
            e.printStackTrace();
        }

        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); //例項化一個模型構造工廠
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml); //將PMML物件構造為Evaluator模型物件

        return evaluator;
    }

    private static Object predict(Evaluator evaluator, int x){
        Map<String, Integer> data = new HashMap<String, Integer>(); //定義測試資料Map,存入各元自變數
        data.put("x", x); //鍵"x"為自變數的名稱,應與訓練資料中的自變數名稱一致
        List<InputField> inputFieldList = evaluator.getInputFields(); //得到模型各元自變數的屬性列表

        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFieldList) { //遍歷各元自變數的屬性列表
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue()); //取出該元變數的值
            FieldValue inputFieldValue = inputField.prepare(rawValue); //將值加入該元自變數屬性中
            arguments.put(inputFieldName, inputFieldValue); //變數名和變數值的對加入LinkedHashMap
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments); //進行預測
        List<TargetField> targetFieldList = evaluator.getTargetFields(); //得到模型各元因變數的屬性列表
        FieldName targetFieldName = targetFieldList.get(0).getName(); //第一元因變數名稱
        Object targetFieldValue = results.get(targetFieldName); //由因變數名稱得到值

        return targetFieldValue;
    }

}

示例下載:
https://download.csdn.net/download/Albert201605/45645889

End.


參考

  1. https://www.freesion.com/article/4628411548/
  2. https://www.cnblogs.com/pinard/p/9220199.html
  3. https://www.cnblogs.com/moonlightpoet/p/5533313.html

相關文章