Python BGE 模型轉換為 ONNX 給 Java 呼叫

bonelee發表於2024-06-29

最近在做 RAG,因為涉及 embedding 計算,用到了 BAAI 的 BGE 小模型。但是該模型是給 Python 呼叫的,需要轉換為 ONNX 格式才能給 Java 使用,所以有了下面的探索:

python程式碼:

import torch
from transformers import AutoTokenizer, AutoModel
from FlagEmbedding import FlagModel

# Initialise the models. `flag_model` supplies the underlying transformer that
# is exported below; `flag_model2` is a second instance used only through its
# public `encode()` API to produce reference embeddings.
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model = FlagModel(model_name_or_path)
flag_model2 = FlagModel(model_name_or_path)

# Put the model into evaluation mode before tracing it.
flag_model.model.eval()

# Build a dummy input that drives the export trace.
dummy_input_text = "This is a sample text for embedding calculation."
embedding = flag_model2.encode(dummy_input_text)
print("embedding shape:", embedding.shape)

inputs = flag_model.tokenizer(dummy_input_text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)

# Move the inputs onto the same device as the model.
inputs = {k: v.to(flag_model.device) for k, v in inputs.items()}

# Export the model to ONNX.
onnx_model_path = "flag_model.onnx"
# Declare both the batch and the sequence dimension as dynamic. Callers that
# always pad to 128 tokens keep working, and shorter or longer sequences are
# no longer rejected by the exported graph.
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    'output': {0: 'batch_size', 1: 'sequence_length'},
}
torch.onnx.export(
    flag_model.model,
    (inputs['input_ids'], inputs['attention_mask']),
    onnx_model_path,
    input_names=['input_ids', 'attention_mask'],
    output_names=['output'],
    dynamic_axes=dynamic_axes,
)

print(f"Model has been converted to ONNX and saved to {onnx_model_path}")

import torch
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from FlagEmbedding import FlagModel

# Initialise the model and the tokenizer (both loaded from the same local
# model directory).
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model = FlagModel(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Put the model into evaluation mode.
flag_model.model.eval()

# ONNX model inference
# Sessions are cached per model path so the ONNX file is loaded once instead
# of on every call.
_ORT_SESSIONS = {}


def onnx_inference(text, pooling_method='cls', normalize_embeddings=True):
    """Embed ``text`` with the exported ONNX model.

    The text is tokenized with the module-level ``tokenizer`` (padded or
    truncated to 128 tokens), run through ``flag_model.onnx`` and the first
    model output is pooled into one vector per input row.

    Args:
        text: Input forwarded unchanged to ``tokenizer``.
        pooling_method: ``'cls'`` takes the vector at sequence position 0;
            ``'mean'`` averages the positions whose attention mask is set.
        normalize_embeddings: When true, each vector is divided by its L2
            norm.

    Returns:
        A 2-D NumPy array with one pooled vector per input row.

    Raises:
        ValueError: If ``pooling_method`` is neither ``'cls'`` nor ``'mean'``.
    """
    # Validate first: previously an unknown method fell through both branches
    # and failed later with an UnboundLocalError on ``embeddings``.
    if pooling_method not in ('cls', 'mean'):
        raise ValueError(f"Unsupported pooling method: {pooling_method}")

    model_path = "flag_model.onnx"
    ort_session = _ORT_SESSIONS.get(model_path)
    if ort_session is None:
        ort_session = ort.InferenceSession(model_path)
        _ORT_SESSIONS[model_path] = ort_session

    inputs = tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)
    input_ids = inputs['input_ids'].cpu().numpy()
    attention_mask = inputs['attention_mask'].cpu().numpy()
    ort_inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
    ort_outs = ort_session.run(None, ort_inputs)
    last_hidden_state = ort_outs[0]

    # Apply pooling
    if pooling_method == 'cls':
        print("cls pooling method")
        embeddings = last_hidden_state[:, 0, :]
    else:
        print("mean pooling mode")
        # Cast the integer mask to the hidden-state dtype so the product is
        # not silently promoted to float64.
        mask = np.expand_dims(attention_mask, axis=-1).astype(last_hidden_state.dtype)
        s = np.sum(last_hidden_state * mask, axis=1)
        d = np.sum(mask, axis=1)
        embeddings = s / d

    # Normalize embeddings if required
    if normalize_embeddings:
        print("normalize embeddings")
        norm = np.linalg.norm(embeddings, axis=-1, keepdims=True)
        embeddings = embeddings / norm

    return embeddings


# Input texts
texts = [
    "This is a sample text for embedding calculation.",
    "Another example text to test the model.",
    "Yet another text to ensure consistency.",
    "Testing with different lengths and contents.",
    "Final text to verify the ONNX model accuracy.",
    "中文資料測試",
    "隨便說點什麼吧!反正也只是測試用!",
    "你好!jone!"
]
# Compare the reference embeddings with the ONNX results, text by text.
for text in texts:
    # reshape(1, -1) lets NumPy infer the hidden size, so the comparison is
    # not tied to a 768-dimensional model.
    original_embedding = flag_model2.encode(text).reshape(1, -1)
    onnx_embedding = onnx_inference(text)
    print("shape compare:", original_embedding.shape, onnx_embedding.shape)
    max_difference = np.max(np.abs(original_embedding - onnx_embedding))
    print(f"Text: {text}")
    print(f"Max Difference: {max_difference}")
    print("-" * 50)


# Dump one reference embedding per line as a bracketed, comma-separated list.
with open("D:\\data\\python_bge.txt", "w", encoding="utf-8") as out_file:
    for sentence in texts:
        vector = flag_model2.encode(sentence)
        row = ", ".join(str(component) for component in vector)
        out_file.write(f"[{row}]\n")

  

python輸出:

D:\Python\Python312\python.exe D:\source\pythonProject\demo_onnx.py 
embedding shape: (768,)
Model has been converted to ONNX and saved to flag_model.onnx
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: This is a sample text for embedding calculation.
Max Difference: 8.940696716308594e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Another example text to test the model.
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Yet another text to ensure consistency.
Max Difference: 2.384185791015625e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Testing with different lengths and contents.
Max Difference: 5.960464477539062e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Final text to verify the ONNX model accuracy.
Max Difference: 4.76837158203125e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 中文資料測試
Max Difference: 3.5762786865234375e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 隨便說點什麼吧!反正也只是測試用!
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 你好!jone!
Max Difference: 2.1047890186309814e-07

  

Java呼叫BGE onnx程式碼:

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.*;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.nio.file.Paths;
import java.util.Map;
import java.util.HashMap;
import java.nio.LongBuffer;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.io.IOException;


/**
 * Runs the exported BGE ONNX model from Java and writes one embedding per
 * input text to a file, one bracketed comma-separated vector per line.
 */
public class App2 {
    /** Fixed sequence length fed to the model; inputs are padded or truncated to it. */
    private static final int MAX_LENGTH = 128;
    private static String poolingMethod = "cls";
    private static boolean normalizeEmbeddings = true;

    /**
     * Returns a copy of {@code array} that is exactly {@code length} long.
     * Shorter inputs are right-padded with zeros. Longer inputs are truncated,
     * but the final element is kept in the last slot so that a trailing
     * special token is not cut off.
     * NOTE(review): assumes the encoding ends with a terminal special token
     * (e.g. [SEP]) - confirm for the tokenizer in use.
     *
     * @param array  source values
     * @param length required output length
     * @return a new array of {@code length} elements
     */
    private static long[] padArray(long[] array, int length) {
        long[] paddedArray = new long[length];
        System.arraycopy(array, 0, paddedArray, 0, Math.min(array.length, length));
        if (array.length > length && length > 0) {
            paddedArray[length - 1] = array[array.length - 1];
        }
        return paddedArray;
    }

    /**
     * Tokenizes {@code text}, runs the ONNX session and pools the first model
     * output into a single vector.
     *
     * @param env       ONNX Runtime environment used to create the input tensors
     * @param session   session holding the exported model
     * @param tokenizer tokenizer producing ids and attention mask
     * @param text      text to embed
     * @return the pooled (and, if enabled, L2-normalised) embedding
     * @throws OrtException             if tensor creation or inference fails
     * @throws IllegalArgumentException if the configured pooling method is unknown
     */
    private static float[] encode(OrtEnvironment env, OrtSession session, HuggingFaceTokenizer tokenizer, String text) throws OrtException {
        Encoding enc = tokenizer.encode(text);

        int batchSize = 1;
        long[] inputIdsShape = new long[]{batchSize, MAX_LENGTH};
        long[] attentionMaskShape = new long[]{batchSize, MAX_LENGTH};
        // Make sure both arrays are exactly MAX_LENGTH long.
        long[] inputIdsData = padArray(enc.getIds(), MAX_LENGTH);
        long[] attentionMaskData = padArray(enc.getAttentionMask(), MAX_LENGTH);

        float[][][] lastHiddenState;
        // try-with-resources releases the native tensors and the result even
        // when inference throws.
        try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIdsData), inputIdsShape);
             OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMaskData), attentionMaskShape)) {

            // Build the input map.
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input_ids", inputIdsTensor);
            inputs.put("attention_mask", attentionMaskTensor);

            // Run inference and copy the first output into a Java array
            // before the result is closed.
            try (OrtSession.Result result = session.run(inputs)) {
                lastHiddenState = (float[][][]) result.get(0).getValue();
            }
        }

        float[] embeddings;
        if ("cls".equals(poolingMethod)) {
            embeddings = lastHiddenState[0][0];
        } else if ("mean".equals(poolingMethod)) {
            int sequenceLength = lastHiddenState[0].length;
            int hiddenSize = lastHiddenState[0][0].length;
            float[] sum = new float[hiddenSize];
            int count = 0;

            // Only positions whose attention mask is 1 contribute to the mean.
            for (int i = 0; i < sequenceLength; i++) {
                if (attentionMaskData[i] == 1) {
                    for (int j = 0; j < hiddenSize; j++) {
                        sum[j] += lastHiddenState[0][i][j];
                    }
                    count++;
                }
            }

            float[] mean = new float[hiddenSize];
            for (int j = 0; j < hiddenSize; j++) {
                mean[j] = sum[j] / count;
            }
            embeddings = mean;
        } else {
            throw new IllegalArgumentException("Unsupported pooling method: " + poolingMethod);
        }

        if (normalizeEmbeddings) {
            float norm = 0;
            for (float v : embeddings) {
                norm += v * v;
            }
            norm = (float) Math.sqrt(norm);
            // Skip the division for an all-zero vector to avoid producing NaNs.
            if (norm > 0) {
                for (int i = 0; i < embeddings.length; i++) {
                    embeddings[i] /= norm;
                }
            }
        }

        return embeddings;
    }


    /**
     * Loads the model and tokenizer, embeds the sample texts and writes the
     * vectors to the output file.
     *
     * @param args unused
     * @throws OrtException if the model cannot be loaded or inference fails
     * @throws IOException  if the tokenizer or the output file cannot be accessed
     */
    public static void main(String[] args) throws OrtException, IOException {
        // Path of the exported ONNX model.
        String modelPath = "D:\\source\\pythonProject\\flag_model.onnx";

        String[] texts = {
                "This is a sample text for embedding calculation.",
                "Another example text to test the model.",
                "Yet another text to ensure consistency.",
                "Testing with different lengths and contents.",
                "Final text to verify the ONNX model accuracy.",
                "中文資料測試",
                "隨便說點什麼吧!反正也只是測試用!",
                "你好!jone!"
        };

        // Resources are closed in reverse order of declaration, even when an
        // exception is thrown part-way through.
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
             OrtSession session = env.createSession(modelPath, opts);
             HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get("D:\\source\\pythonProject\\onnx\\tokenizer.json"));
             BufferedWriter writer = new BufferedWriter(new FileWriter("D:\\data\\java_bge.txt"))) {
            for (String text : texts) {
                float[] clsEmbedding = encode(env, session, tokenizer, text);

                writer.write("[");
                for (int i = 0; i < clsEmbedding.length; i++) {
                    writer.write(String.valueOf(clsEmbedding[i]));
                    if (i < clsEmbedding.length - 1) {
                        writer.write(", ");
                    }
                }
                writer.write("]\n");
            }
        }
    }
}

  

最後,我們比較下二者的輸出差異:

def compare_java_python_result(java_path=r"D:\data\java_bge.txt", python_path=r"D:\data\python_bge.txt"):
    """Compare the embeddings written by the Java and the Python programs.

    Each file holds one JSON list of numbers per line. For every pair of
    corresponding lines the largest absolute element-wise difference is
    printed.

    Args:
        java_path: File with the embeddings produced by the Java program.
        python_path: File with the embeddings produced by the Python program.

    Returns:
        A list with the maximum absolute difference of each embedding pair,
        in file order.

    Raises:
        ValueError: If the two files do not yield arrays of the same shape.
    """
    import numpy as np
    import json

    # Read the embedding vectors from a file, one JSON list per line.
    def load_embeddings(filename):
        with open(filename, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline at the end of the
            # file), which json.loads would reject.
            return np.array([json.loads(line) for line in f if line.strip()])

    # Load the embeddings of both files.
    java_embeddings = load_embeddings(java_path)
    python_embeddings = load_embeddings(python_path)

    # Both files must contain the same number of equally sized vectors.
    if java_embeddings.shape != python_embeddings.shape:
        raise ValueError("The number of embeddings in the two files does not match.")

    # Compare each pair of vectors and record the largest difference.
    max_differences = []
    for i, (java_embedding, python_embedding) in enumerate(zip(java_embeddings, python_embeddings), start=1):
        max_difference = np.max(np.abs(java_embedding - python_embedding))

        # Report the result.
        print(f"Embedding {i}:")
        print(f"Max Difference: {max_difference}")
        print("-" * 50)
        max_differences.append(float(max_difference))

    return max_differences


# Run the Java-vs-Python comparison when the script is executed.
compare_java_python_result()

  

輸出結果:

--------------------------------------------------
Embedding 1:
Max Difference: 7.499999999938112e-07
--------------------------------------------------
Embedding 2:
Max Difference: 5.600000000383076e-07
--------------------------------------------------
Embedding 3:
Max Difference: 1.8700000000565487e-07
--------------------------------------------------
Embedding 4:
Max Difference: 4.500000000406956e-07
--------------------------------------------------
Embedding 5:
Max Difference: 6.000000000172534e-07
--------------------------------------------------
Embedding 6:
Max Difference: 5.699999999775329e-07
--------------------------------------------------
Embedding 7:
Max Difference: 5.700000000885552e-07
--------------------------------------------------
Embedding 8:
Max Difference: 2.1039999999975662e-07
--------------------------------------------------

  

又是一個充滿收穫的一天!!!歐耶!!!

相關文章