djl載入模型

KIKIcoo發表於2024-04-01

愚蠢的我以為載入pytorch模型還需要將模型結構在Java中寫出來,因此各種寫模型,還問gpt怎麼把並行的結構用djl寫出來,然後又問如何讓djl載入模型,結果,我是個傻子。

人家早就把東西都弄好了,幾行程式碼就搞定了。

 Criteria<NDList, NDList>criteria = Criteria.builder()
        .setTypes(NDList.class, NDList.class)
        .optModelPath(modelDir)
        .optEngine("PyTorch")
        .build();
    ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
    Predictor<NDList, NDList> predictor = model.newPredictor();

這裡面,NDList是多個NDArray的封裝,類似於封裝了多個Tensor,optEngine是設定你的模型是從pytorch還是從TensorFlow來的。

optModelPath是指出pytorch生成的scriptTensor--》pt檔案或者壓縮包--》zip檔案。

實際上,pt檔案裡面已經有模型的結構了,壓根不需要自己寫,直接呼叫OK了。

ModelZoo也是直接載入criteria就把模型整好了。

最後結果,我是個傻子。



相關文章