在機器學習用於產品的時候,我們經常會遇到跨平臺的問題。比如我們用Python基於一系列的機器學習庫訓練了一個模型,但是有時候其他的產品和專案想把這個模型整合進去,但是這些產品很多隻支援某些特定的生產環境比如Java,為了上一個機器學習模型去大動干戈修改環境配置很不划算,此時我們就可以考慮用預測模型標記語言(Predictive Model Markup Language,以下簡稱PMML)來實現跨平臺的機器學習模型部署了。
1. PMML概述
PMML是資料探勘的一種通用的規範,它用統一的XML格式來描述我們生成的機器學習模型。這樣無論你的模型是sklearn,R還是Spark MLlib生成的,我們都可以將其轉化為標準的XML格式來儲存。當我們需要將這個PMML的模型用於部署的時候,可以使用目標環境的解析PMML模型的庫來載入模型,並做預測。
可以看出,要使用PMML,需要兩步的工作,第一塊是將離線訓練得到的模型轉化為PMML模型檔案,第二塊是將PMML模型檔案載入線上預測環境,進行預測。這兩塊都需要相關的庫支援。
2. PMML模型的生成和載入相關類庫
PMML模型的生成相關的庫需要看我們使用的離線訓練庫。如果我們使用的是sklearn,那麼可以使用sklearn2pmml這個python庫來做模型檔案的生成,這個庫安裝很簡單,使用"pip install sklearn2pmml"即可,相關的使用我們後面會有一個demo。如果使用的是Spark MLlib, 這個庫有一些模型已經自帶了儲存PMML模型的方法,可惜並不全。如果是R,則需要安裝包"XML"和“PMML”。此外,JAVA庫JPMML可以用來生成R,SparkMLlib,xgBoost,Sklearn的模型對應的PMML檔案。github地址是:https://github.com/jpmml/jpmml。
載入PMML模型需要目標環境支援PMML載入的庫,如果是JAVA,則可以用JPMML來載入PMML模型檔案。相關的使用我們後面會有一個demo。
3. PMML模型生成和載入示例
下面我們給一個示例,使用sklearn生成一個決策樹模型,用sklearn2pmml生成模型檔案,用JPMML載入模型檔案,並做預測。
完整程式碼參見我的github:https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/sklearn-jpmml
首先是用用sklearn生成一個決策樹模型,由於我們是需要儲存PMML檔案,所以最好把模型先放到一個Pipeline陣列裡面。這個陣列裡面除了我們的決策樹模型以外,還可以有歸一化,降維等預處理操作,這裡作為一個示例,我們Pipeline陣列裡面只有決策樹模型。程式碼如下:
import numpy as np import matplotlib.pyplot as plt %matplotlib inline import pandas as pd from sklearn import tree from sklearn2pmml.pipeline import PMMLPipeline from sklearn2pmml import sklearn2pmml import os os.environ["PATH"] += os.pathsep + 'C:/Program Files/Java/jdk1.8.0_171/bin' X=[[1,2,3,1],[2,4,1,5],[7,8,3,6],[4,8,4,7],[2,5,6,9]] y=[0,1,0,2,1] pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier(random_state=9))]); pipeline.fit(X,y) sklearn2pmml(pipeline, ".\demo.pmml", with_repr = True)
上面這段程式碼做了一個非常簡單的決策樹分類模型,只有5個訓練樣本,特徵有4個,輸出類別有3個。實際應用時,我們需要將模型調參完畢後才將其放入PMMLPipeline進行儲存。執行程式碼後,我們在當前目錄會得到一個PMML的XML檔案,可以直接開啟看,內容大概如下:
<?xml version="1.0" encoding="UTF-8" standalone="yes"?> <PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3"> <Header> <Application name="JPMML-SkLearn" version="1.5.3"/> <Timestamp>2018-06-24T05:47:17Z</Timestamp> </Header> <MiningBuildTask> <Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort=False, random_state=9, splitter='best'))])</Extension> </MiningBuildTask> <DataDictionary> <DataField name="y" optype="categorical" dataType="integer"> <Value value="0"/> <Value value="1"/> <Value value="2"/> </DataField> <DataField name="x3" optype="continuous" dataType="float"/> <DataField name="x4" optype="continuous" dataType="float"/> </DataDictionary> <TransformationDictionary> <DerivedField name="double(x3)" optype="continuous" dataType="double"> <FieldRef field="x3"/> </DerivedField> <DerivedField name="double(x4)" optype="continuous" dataType="double"> <FieldRef field="x4"/> </DerivedField> </TransformationDictionary> <TreeModel functionName="classification" missingValueStrategy="nullPrediction" splitCharacteristic="multiSplit"> <MiningSchema> <MiningField name="y" usageType="target"/> <MiningField name="x3"/> <MiningField name="x4"/> </MiningSchema> <Output> <OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/> <OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/> <OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/> </Output> <Node> <True/> <Node> <SimplePredicate field="double(x3)" operator="lessOrEqual" value="3.5"/> <Node score="1" recordCount="1.0"> <SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.0"/> <ScoreDistribution value="0" recordCount="0.0"/> <ScoreDistribution value="1" recordCount="1.0"/> <ScoreDistribution value="2" recordCount="0.0"/> </Node> <Node score="0" recordCount="2.0"> <True/> <ScoreDistribution value="0" recordCount="2.0"/> <ScoreDistribution value="1" recordCount="0.0"/> <ScoreDistribution value="2" recordCount="0.0"/> </Node> </Node> <Node score="2" recordCount="1.0"> <SimplePredicate field="double(x4)" operator="lessOrEqual" value="8.0"/> <ScoreDistribution value="0" recordCount="0.0"/> <ScoreDistribution value="1" recordCount="0.0"/> <ScoreDistribution value="2" recordCount="1.0"/> </Node> <Node score="1" recordCount="1.0"> <True/> <ScoreDistribution value="0" recordCount="0.0"/> <ScoreDistribution value="1" recordCount="1.0"/> <ScoreDistribution value="2" recordCount="0.0"/> </Node> </Node> </TreeModel> </PMML>
可以看到裡面就是決策樹模型的樹結構節點的各個引數,以及輸入值。我們的輸入被定義為x1-x4,輸出定義為y。
有了PMML模型檔案,我們就可以寫JAVA程式碼來讀取載入這個模型並做預測了。
我們建立一個Maven或者gradle工程,加入JPMML的依賴,這裡給出maven在pom.xml的依賴,gradle的結構是類似的。
<dependency> <groupId>org.jpmml</groupId> <artifactId>pmml-evaluator</artifactId> <version>1.4.1</version> </dependency> <dependency> <groupId>org.jpmml</groupId> <artifactId>pmml-evaluator-extension</artifactId> <version>1.4.1</version> </dependency>
接著就是讀取模型檔案並預測的程式碼了,具體程式碼如下:
import org.dmg.pmml.FieldName; import org.dmg.pmml.PMML; import org.jpmml.evaluator.*; import org.xml.sax.SAXException; import javax.xml.bind.JAXBException; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; /** * Created by 劉建平Pinard on 2018/6/24. */ public class PMMLDemo { private Evaluator loadPmml(){ PMML pmml = new PMML(); InputStream inputStream = null; try { inputStream = new FileInputStream("D:/demo.pmml"); } catch (IOException e) { e.printStackTrace(); } if(inputStream == null){ return null; } InputStream is = inputStream; try { pmml = org.jpmml.model.PMMLUtil.unmarshal(is); } catch (SAXException e1) { e1.printStackTrace(); } catch (JAXBException e1) { e1.printStackTrace(); }finally { //關閉輸入流 try { is.close(); } catch (IOException e) { e.printStackTrace(); } } ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml); pmml = null; return evaluator; } private int predict(Evaluator evaluator,int a, int b, int c, int d) { Map<String, Integer> data = new HashMap<String, Integer>(); data.put("x1", a); data.put("x2", b); data.put("x3", c); data.put("x4", d); List<InputField> inputFields = evaluator.getInputFields(); //過模型的原始特徵,從畫像中獲取資料,作為模型輸入 Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>(); for (InputField inputField : inputFields) { FieldName inputFieldName = inputField.getName(); Object rawValue = data.get(inputFieldName.getValue()); FieldValue inputFieldValue = inputField.prepare(rawValue); arguments.put(inputFieldName, inputFieldValue); } Map<FieldName, ?> results = evaluator.evaluate(arguments); List<TargetField> targetFields = evaluator.getTargetFields(); TargetField targetField = targetFields.get(0); FieldName targetFieldName = targetField.getName(); Object targetFieldValue = results.get(targetFieldName); System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue); int primitiveValue = -1; if (targetFieldValue instanceof Computable) { Computable computable = (Computable) targetFieldValue; primitiveValue = (Integer)computable.getResult(); } System.out.println(a + " " + b + " " + c + " " + d + ":" + primitiveValue); return primitiveValue; } public static void main(String args[]){ PMMLDemo demo = new PMMLDemo(); Evaluator model = demo.loadPmml(); demo.predict(model,1,8,99,1); demo.predict(model,111,89,9,11); } }
程式碼裡有兩個函式,第一個loadPmml是載入模型的,第二個predict是讀取預測樣本並返回預測值的。我的程式碼執行結果如下:
target: y value: {result=2, probability_entries=[0=0.0, 1=0.0, 2=1.0], entityId=5, confidence_entries=[]}
1 8 99 1:2
target: y value: {result=1, probability_entries=[0=0.0, 1=1.0, 2=0.0], entityId=6, confidence_entries=[]}
111 89 9 11:1
也就是樣本(1,8,99,1)被預測為類別2,而(111,89,9,11)被預測為類別1。
以上就是PMML生成和載入的一個示例,使用起來其實門檻並不高,也很簡單。
4. PMML總結與思考
PMML的確是跨平臺的利器,但是是不是就沒有缺點呢?肯定是有的!
第一個就是PMML為了滿足跨平臺,犧牲了很多平臺獨有的優化,所以很多時候我們用演算法庫自己的儲存模型的API得到的模型檔案,要比生成的PMML模型檔案小很多。同時PMML檔案載入速度也比演算法庫自己獨有格式的模型檔案載入慢很多。
第二個就是PMML載入得到的模型和演算法庫自己獨有的模型相比,預測會有一點點的偏差,當然這個偏差並不大。比如某一個樣本,用sklearn的決策樹模型預測為類別1,但是如果我們把這個決策樹落盤為一個PMML檔案,並用JAVA載入後,繼續預測剛才這個樣本,有較小的概率出現預測的結果不為類別1.
第三個就是對於超大模型,比如大規模的整合學習模型,比如xgboost, 隨機森林,或者tensorflow,生成的PMML檔案很容易得到幾個G,甚至上T,這時使用PMML檔案載入預測速度會非常慢,此時推薦為模型建立一個專有的環境,就沒有必要去考慮跨平臺了。
此外,對於TensorFlow,不推薦使用PMML的方式來跨平臺。可能的方法一是TensorFlow serving,自己搭建預測服務,但是會稍有些複雜。另一個方法就是將模型儲存為TensorFlow的模型檔案,並用TensorFlow獨有的JAVA庫載入來做預測。
我們在下一篇會討論用python+tensorflow訓練儲存模型,並用tensorflow的JAVA庫載入做預測的方法和例項。
(歡迎轉載,轉載請註明出處。歡迎溝通交流: liujianping-ok@163.com)