TensorFlow模型部署到伺服器---TensorFlow2.0

PRINT王哲發表於2021-08-01

前言

​ 當一個TensorFlow模型訓練出來的時候,為了投入到實際應用,所以就需要部署到伺服器上。由於我本次所做的專案是一個javaweb的影像識別專案。所有我就想去尋找一下java呼叫TensorFlow訓練模型的辦法。

image-20210801132207199

由於TensorFlow很久沒更新的緣故,網上的部落格大都是18/19年的,並且是基於TensorFlow1.0的,對於現在使用的TensorFlow2.0不太友好。

下面我簡述一下TensorFlow1.0時期的方法:

1.動態模型生成不便

需要將訓練的.h5模型轉換成.pb模型,並且需要自己定義.pb模型的輸入輸出引數。(pb模型是一種基於動態圖的模型)

pb的生成程式碼冗長、而且對初學者真滴不太友好

image-20210801132825083

相比之下.h5模型的生成程式碼就一行

image-20210801133250449

此外,這個生成pb模型的程式碼是否能照搬使用,還是一個問題,並且還可能報一些奇奇怪怪的錯誤。

2.maven導包不便

查閱資料發現java上的TensorFlow的jar包都是TensorFlow1.0的

image-20210801144751826

現狀:

image-20210801144856860

並且maven官網上的TensorFlow2.0的api已經改名成了tensorflow-core-api,並且網上相關方面的教程十分難找。由於網上都是匯入的1.0的包,自己匯入2.0的包之後,詳細的呼叫教程可以說是沒有。從上面也可以看出來TensorFlow對java的呼叫也不怎麼重視了。所以這又給學習的途中徒增了很多困難。

全新思路

思路一

用java直接呼叫訓練好的模型很困難,那麼我們想辦法讓java呼叫python指令碼,讓python指令碼去呼叫.h5模型會不會更簡單呢?

程式碼如下

package com.guard.service;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;

public class api_service {

    public String recognize(String path){
        //此處的path是圖片路徑
        Process proc;
        String res = null;
        try {
            System.out.println("接受到的引數"+path);
            String[] cmd = new String[] { "python", "E:\\machine_learning\\predict.py", path};
            proc = Runtime.getRuntime().exec(cmd);
            BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream()));
            String line = null;
            while ((line = in.readLine()) != null) {
                System.out.println(line);
                res = line;
            }
            in.close();
            proc.waitFor();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(res+">>>>>>>>>>>");
        return res;
    }
}

但是我們可以看出,這個其實是用java在win上跑了這樣一個指令

image-20210801183239294

雖然這個確實是一個好辦法,但是這個路徑引數需要事先知道伺服器上的路徑,並且在協作開發的時候,每個人的路徑和環境就不同,雖然該方法能用,但是我認為還不夠好。

思路二

我們可以直接用python的flask框架,直接生成一個api介面,就可以遠端直接呼叫TensorFlow訓練好的模型進行結果預測。

image-20210801184122679

image-20210801184226699

個人認為,這種方法相較於用java呼叫命令列,這種方法還是更加直觀的

並且flask僅僅需要加個@app.route的註解就能實現,可謂是十分方便

下面是模型呼叫程式碼

model.py

import glob
import sys
import os
import cv2
import numpy as np
import tensorflow as tf
import image_processing

def model_ues(path):
    # 縮放圖片大小為100*100
    w = 100
    h = 100

    # 測試影像的地址 (改為自己的)

    # path_test = "resource/test24.jpg"
    api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda"
    path_test = image_processing.download_img(path,api_token)

    # 建立儲存影像的空列表
    imgs = []
    img = cv2.imread(path_test)
    img = cv2.resize(img, (w, h))
    # 將每張經過處理的影像資料儲存在之前建立的imgs空列表當中
    imgs.append(img)
    imgs = np.asarray(imgs, np.float32)
    # print("shape of data:",imgs.shape)

    # 匯入模型
    model = tf.keras.models.load_model(r"resource/rice_0.93.h5")
    # 建立影像標籤列表
    rice_dict = {0: 'Rice blast', 1: 'Rice fleck',
             2: 'Rice koji disease', 3: 'Sheath blight'}

    # 將影像匯入模型進行預測
    prediction = model.predict_classes(imgs)
    # prediction = np.argmax(model.predict(imgs), axis=-1)


    # 繪製預測影像
    for i in range(np.size(prediction)):
        # 列印每張影像的預測結果
        print(rice_dict[prediction[i]])
    return rice_dict[prediction[0]]

為了實現圖片外連結受,下面是圖片下載指令碼

image_processing.py

# coding: utf8
import requests
import random

def download_img(img_url, api_token):
    print (img_url)
    header = {"Authorization": "Bearer " + api_token} # 設定http header,視情況加需要的條目,這裡的token是用來鑑權的一種方式
    r = requests.get(img_url, headers=header, stream=True)
    print(r.status_code) # 返回狀態碼
    file_img = 'resource/img.png'

    # file_img = 'resource/'
    print(file_img)
    if r.status_code == 200:
        open(file_img, 'wb').write(r.content) # 將內容寫入圖片
        print("done")
    del r

    return file_img
# if __name__ == '__main__':
#     # 下載要的圖片
#     img_url = "https://z3.ax1x.com/2021/07/27/W5l6Qe.png"
#     api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda"
#     download_img(img_url, api_token)

主程式指令碼

app.py

from flask import Flask,render_template, url_for, request, json,jsonify
import model
app = Flask(__name__)

#設定編碼
app.config['JSON_AS_ASCII'] = False

@app.route('/test')
def hello_world():

    return "hello world"

@app.route('/predict', methods=['GET', 'POST'])
def form_data():
    my_path = request.form['path']
    print(my_path)
    str = model.model_ues(my_path)
    print("http://127.0.0.1:5000/predict")
    return jsonify({'result':str,'msg':'200'})

if __name__ == '__main__':
    app.run()

資料解析

雖然我們能夠通過postman進行測試接受到回傳的結果,但是我們要怎麼用java實現呢??

1.使用postman生成大致程式碼框架(postman生成的程式碼可能不能直接執行)

image-20210801185332234

這裡我選用的是java-okhttp的方法,但其實使用Unirest寫出來的程式碼更加簡潔易懂。

public class Get_result {

    public  String getResult(String path) throws IOException {
//        String path = "https://i.loli.net/2021/07/29/badDNR2OCironUf.jpg";
        OkHttpClient client = new OkHttpClient().newBuilder()
                .build();
        MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
        RequestBody body = RequestBody.create(mediaType, "path="+path);
        Request request = new Request.Builder()
                .url("http://127.0.0.1:8000/predict")
                .method("POST", body)
                .addHeader("Content-Type", "application/x-www-form-urlencoded")
                .build();
        Response response = client.newCall(request).execute();
        String result = response.body().string();
        System.out.println(result);
            }
}

{
  "msg": "200",
  "result": "Rice fleck"
}

獲取到json資料之後,就需要對json資料進行解析

java上的解析原理是,先按照json編寫一個類,之後用Gson對接受到的資料按照這個類進行規範化

(這裡可以用GsonFormatPlus外掛來自動生成這個實體類)

//Rice_result.java---為該json的實體類
package com.guard.tool;

import lombok.Data;
import lombok.NoArgsConstructor;

@NoArgsConstructor
@Data
public class Rice_result {
    private String msg;
    private String result;

}

下面是資料解析程式碼(和上面的okhttp獲取json資料的程式碼連起來看)

//json資料解析
        Gson gson = new Gson();
        java.lang.reflect.Type type = new TypeToken<Rice_result>(){}.getType();
        Rice_result rice_result = gson.fromJson(result, type);
        System.out.println(rice_result);
        if("200".equals(rice_result.getMsg())){
//            System.out.println(rice_result.getResult());
            return Rice_result.convertdata(rice_result.getResult());
        }else {
//            System.out.println("獲取結果出錯!!");
            return "獲取結果出錯!!";
        }

這樣的話就可以進行json資料的解析了。

圖鏈製作

由於需要使用java傳送post請求給flask的預測埠,那麼就需要把本地上傳的資料做成圖鏈,把圖鏈作為資料傳給flask的預測埠,從而來接收結果。

由於前端js的知識大多遺忘,這裡就選用了用java來傳送一個post請求,獲得回傳的資訊。

這裡我使用的是sm.ms的圖床(該圖床無需登入,且速度快,算得上是一個好的選擇)

//sm.ms的使用方法,建議看官方文件
package com.guard.tool;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import okhttp3.*;

import java.io.File;
import java.io.IOException;


public class CloudUpload {

  public String toUrl(String path) throws IOException {

//    String file_path = "E:/machine_learning/test8.jpg";

    String file_path = path;
    OkHttpClient client = new OkHttpClient().newBuilder()
            .build();
    MediaType mediaType = MediaType.parse("multipart/form-data");
    RequestBody body = new MultipartBody.Builder().setType(MultipartBody.FORM)
            .addFormDataPart("smfile",file_path,
                    RequestBody.create(MediaType.parse("application/octet-stream"),
                            new File(file_path)))
            .addFormDataPart("format","json")
            .build();
    Request request = new Request.Builder()
            .url("https://sm.ms/api/v2/upload")
            .method("POST", body)
            .addHeader("Content-Type", "multipart/form-data")
            .addHeader("Authorization", "TlxzRSaVJj0o7HFZOd9sgdf4Jl60RA00")
   //這裡的user-agent和Cookie需要自己開啟網站,到網站的頁面去拿取
            .addHeader("user-agent","Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36")
            .addHeader("Cookie", "SMMSrememberme=42417%3A10e8e9cb5281082b493fdee73381aeb2dca0bd3d; PHPSESSID=1gjog2em3ogof23vrqi79vd41m; SM_FC=runWNk3mPIiL8mzl%2FrlEfzM940LRKjLm182cm2qDrm4%3D")
            .build();
    Response response = client.newCall(request).execute();
    String result = response.body().string();
    System.out.println(result);
//    String result = response.body().string();

    Gson gson = new Gson();
    java.lang.reflect.Type type = new TypeToken<Image_data>(){}.getType();
    Image_data imge_data = gson.fromJson(result, type);
    System.out.println(imge_data);
    if (imge_data.getSuccess()){
      System.out.println(imge_data.getData().getUrl());
      return imge_data.getData().getUrl();
    }
    else{
      System.out.println("圖片已經上傳過一次!!");
      System.out.println(imge_data.getImages());
      return imge_data.getImages();
    }
  }
}

回傳的json結果--這個就需要使用上面的外掛來進行處理

{
    "success": true,
    "code": "success",
    "message": "Upload success.",
    "data": {
        "file_id": 0,
        "width": 192,
        "height": 454,
        "filename": "test25.jpg",
        "storename": "xICPNzFsfth5uJk.png",
        "size": 124993,
        "path": "/2021/08/01/xICPNzFsfth5uJk.png",
        "hash": "2exIdQGvBru46RKMyNjg3DhCTO",
        "url": "https://i.loli.net/2021/08/01/xICPNzFsfth5uJk.png",
        "delete": "https://sm.ms/delete/2exIdQGvBru46RKMyNjg3DhCTO",
        "page": "https://sm.ms/image/xICPNzFsfth5uJk"
    },
    "RequestId": "9BFE9DEB-8370-44C8-A8AF-AAB2DB753A18"
}

總結

以上就是我這次在小組編寫<基於CNN影像分類的水稻病蟲害識別>這個專案中的收穫。在此記錄下學習路上踩過的一些坑和一些解決方法。

相關文章