使用pmml跨平臺部署機器學習模型Demo——房價預測

歸去_來兮發表於2021-11-21

  基於房價資料,在Python中訓練一個線性迴歸模型,再於JavaWeb應用中載入該模型,實現房價預測功能。


一、 訓練、儲存模型

工具:PyCharm 2017、Python 3.9、sklearn2pmml 0.76.1。

1.訓練資料house_price.csv

No,square_feet,price
1,150,6450
2,200,7450
3,250,8450
4,300,9450
5,350,11450
6,400,15450
7,600,18450

2.訓練、儲存模型

import sklearn2pmml as pmml
from sklearn2pmml import PMMLPipeline
from sklearn import linear_model as lm
import os
import pandas as pd

def save_model(data, model_path):
    """Fit a linear regression of ``price`` on ``square_feet`` and export it as PMML.

    Args:
        data: DataFrame (read from house_price.csv by the caller) that contains
            the ``square_feet`` feature column and the ``price`` target column.
        model_path: destination path of the generated PMML file.
    """
    features = data[["square_feet"]]
    target = data["price"]
    # The step name "regression" only labels the estimator inside the pipeline.
    regression_pipeline = PMMLPipeline([("regression", lm.LinearRegression())])
    regression_pipeline.fit(features, target)
    # with_repr=True additionally embeds the pipeline's Python repr in the PMML file.
    pmml.sklearn2pmml(regression_pipeline, model_path, with_repr=True)

if __name__ == "__main__":
    # Resolve both the input CSV and the output model relative to this script's
    # directory, so the script behaves the same regardless of the working directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    data = pd.read_csv(os.path.join(base_dir, "house_price.csv"))
    model_path = os.path.join(base_dir, "my_liner_model.pmml")
    save_model(data, model_path)
    print("模型儲存完成。")

二、JavaWeb應用開發

工具:IntelliJ IDEA-2018、jdk-14.0.2、Tomcat-9.0.37。


建立maven專案,加入依賴項

    <dependencies>
        <!-- JPMML evaluator: parses the PMML file and runs predictions.
             This 1.4.x line is used together with the PMML-4_3 xmlns edit described in the article. -->
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.4.15</version>
        </dependency>
        <!-- JAXB was removed from the JDK in Java 11, so it must be supplied explicitly
             for PMML (un)marshalling when running on JDK 14. -->
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-core</artifactId>
            <version>2.2.11</version>
        </dependency>
        <dependency>
            <groupId>javax.xml</groupId>
            <artifactId>jaxb-api</artifactId>
            <version>2.1</version>
        </dependency>
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-impl</artifactId>
            <version>2.2.11</version>
        </dependency>
        <dependency>
            <groupId>javax.servlet</groupId>
            <artifactId>javax.servlet-api</artifactId>
            <version>3.0.1</version>
            <!-- Tomcat provides the Servlet API at runtime; it must not be packaged into the WAR. -->
            <scope>provided</scope>
        </dependency>
    </dependencies>

專案結構為


介面——index.jsp

<%@ page contentType="text/html;charset=UTF-8" language="java" %>
<%-- Demo page: posts the house size to PredictServlet and shows the "price" request attribute it forwards back. --%>
<html>
<head>
    <title>使用pmml跨平臺部署機器學習模型Demo</title>
</head>
<body>
<h2>使用pmml跨平臺部署機器學習模型Demo——房價預測</h2>
<%-- Prefix the context path so the form still reaches the servlet when the app is not deployed at "/". --%>
<form name="form" method="post" action="${pageContext.request.contextPath}/PredictServlet">
    <label>房子英尺數(整數):</label>
    <input type="text" name="feet" required>
    <button type="submit">預測房價</button>
</form>
<div>
    <label>預測價格為:</label>
    <%-- Empty until PredictServlet forwards back with a result. --%>
    ${price}
</div>

</body>
</html>

Servlet類——PredictServlet.java

package servlet;

import service.PredictService;
import service.imp.PredictServiceImp;

import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

@WebServlet("/PredictServlet")
public class PredictServlet extends HttpServlet {
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        PredictService predictService = new PredictServiceImp();

        String feet_str = request.getParameter("feet"); //獲取前端傳來的值
        int feet = Integer.parseInt(feet_str);

        double price = predictService.getPredictedPrice(feet); //預測

        //請求轉發,返回結果
        request.setAttribute("price", price);
        request.getRequestDispatcher("/index.jsp").forward(request, response);
    }

    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        this.doPost(request, response);
    }
}


Service介面——PredictService.java

package service;

/**
 * Contract for the house-price prediction service used by the web layer.
 */
public interface PredictService {
    /**
     * Returns the predicted price for a house of the given size.
     *
     * @param feet house size as entered in the form (the model's square_feet input)
     * @return the predicted price
     */
    double getPredictedPrice(int feet);
}


Service實現類——PredictServiceImp.java

package service.imp;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import service.PredictService;

import java.io.FileInputStream;
import java.io.InputStream;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

public class PredictServiceImp implements PredictService {
    public double getPredictedPrice(int feet) {
        String model_path = "D:\\my_liner_model.pmml"; //pmml模型檔案存放路徑
        Evaluator model = loadModel(model_path); //載入模型
        Object r = predict(model, feet); //預測
        double result = Double.parseDouble(String.format("%.2f", r)); //格式化
        return result;
    }

    private static Evaluator loadModel(String model_path){
        PMML pmml = new PMML(); //定義PMML物件
        InputStream inputStream; //定義輸入流
        try {
            inputStream = new FileInputStream(model_path); //輸入流接到磁碟上的模型檔案
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream); //將輸入流解析為PMML物件
        }catch (Exception e){
            e.printStackTrace();
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); //例項化一個模型構造工廠
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml); //將PMML物件構造為Evaluator模型物件

        return evaluator;
    }

    private static Object predict(Evaluator evaluator, int feet){
        Map<String, Integer> data = new HashMap<String, Integer>(); //定義測試資料Map,存入各元自變數
        data.put("square_feet", feet); //鍵"square_feet"為自變數的名稱,應與訓練資料中的自變數名稱一致

        List<InputField> inputFieldList = evaluator.getInputFields(); //得到模型各元自變數的屬性列表
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFieldList) { //遍歷各元自變數的屬性列表
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue()); //取出該元變數的值
            FieldValue inputFieldValue = inputField.prepare(rawValue); //將值加入該元自變數屬性中
            arguments.put(inputFieldName, inputFieldValue); //變數名和變數值的對加入LinkedHashMap
        }
        Map<FieldName, ?> result = evaluator.evaluate(arguments); //進行預測
        List<TargetField> targetFieldList = evaluator.getTargetFields(); //得到模型各元因變數的屬性列表
        FieldName targetFieldName = targetFieldList.get(0).getName(); //第一元因變數名稱
        Object targetFieldValue = result.get(targetFieldName); //由因變數名稱得到值

        return targetFieldValue;
    }
}


三、執行測試

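  將Python中訓練得到的PMML模型檔案my_liner_model.pmml置於D盤根目錄下(即程式碼中的路徑D:\my_liner_model.pmml),並將檔案中的xmlns=".../PMML-4_4"修改為xmlns=".../PMML-4_3"。這是因為sklearn2pmml輸出的是PMML 4.4格式,而本專案使用的pmml-evaluator 1.4.x最高只支援PMML 4.3(改用支援PMML 4.4的1.5.x以上版本則無需修改)。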


啟動Tomcat執行專案後,在瀏覽器中訪問http://localhost/(若未將Tomcat埠號改為80,請使用預設的http://localhost:8080/),即可進入頁面。


輸入房子英尺數,點選‘預測房價’按鈕,展示出預測價格


打包下載:
https://download.csdn.net/download/Albert201605/45648664


End.

相關文章