
Example 1 with OrtException

Use of ai.onnxruntime.OrtException in project tribuo by oracle.

From the class ONNXExternalModel, method createOnnxModel.

/**
 * Creates an {@code ONNXExternalModel} by loading the model from disk.
 *
 * @param factory            The output factory to use.
 * @param featureMapping     The feature mapping between Tribuo names and ONNX integer ids.
 * @param outputMapping      The output mapping between Tribuo outputs and ONNX integer ids.
 * @param featureTransformer The transformation function for the features.
 * @param outputTransformer  The transformation function for the outputs.
 * @param opts               The session options for the ONNX model.
 * @param path               The model path.
 * @param inputName          The name of the input node.
 * @param <T>                The type of the output.
 * @return An ONNXExternalModel ready to score new inputs.
 * @throws OrtException If the onnx-runtime native library call failed.
 */
public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions opts, Path path, String inputName) throws OrtException {
    try {
        byte[] modelArray = Files.readAllBytes(path);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        HashMap<String, Provenance> runProvenance = new HashMap<>();
        runProvenance.put("input-name", new StringProvenance("input-name", inputName));
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession session = env.createSession(modelArray)) {
            OnnxModelMetadata metadata = session.getMetadata();
            runProvenance.put("model-producer", new StringProvenance("model-producer", metadata.getProducerName()));
            runProvenance.put("model-domain", new StringProvenance("model-domain", metadata.getDomain()));
            runProvenance.put("model-description", new StringProvenance("model-description", metadata.getDescription()));
            runProvenance.put("model-graphname", new StringProvenance("model-graphname", metadata.getGraphName()));
            runProvenance.put("model-version", new LongProvenance("model-version", metadata.getVersion()));
            for (Map.Entry<String, String> e : metadata.getCustomMetadata().entrySet()) {
                if (!e.getKey().equals(ONNXExportable.PROVENANCE_METADATA_FIELD)) {
                    String keyName = "model-metadata-" + e.getKey();
                    runProvenance.put(keyName, new StringProvenance(keyName, e.getValue()));
                }
            }
        } catch (OrtException e) {
            throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
        }
        ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(), now, datasetProvenance, trainerProvenance, runProvenance);
        return new ONNXExternalModel<>("external-model", provenance, featureMap, outputInfo, featureMapping, modelArray, opts, inputName, featureTransformer, outputTransformer);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
Also used : ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) StringProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) LongProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance) ExternalDatasetProvenance(org.tribuo.interop.ExternalDatasetProvenance) HashMap(java.util.HashMap) OrtException(ai.onnxruntime.OrtException) URL(java.net.URL) OrtEnvironment(ai.onnxruntime.OrtEnvironment) OnnxModelMetadata(ai.onnxruntime.OnnxModelMetadata) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) IOException(java.io.IOException) OrtSession(ai.onnxruntime.OrtSession) OffsetDateTime(java.time.OffsetDateTime) Map(java.util.Map)
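The metadata loop in createOnnxModel copies every custom metadata entry except the reserved provenance field into the run provenance, prefixing each key with "model-metadata-". Below is a minimal JDK-only sketch of that key mapping, assuming plain strings in place of Tribuo's StringProvenance objects. The class MetadataKeys and the reserved key used in the test are hypothetical; the real code compares against ONNXExportable.PROVENANCE_METADATA_FIELD, whose value is not shown here.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper mirroring the loop over metadata.getCustomMetadata():
// plain strings stand in for Tribuo's StringProvenance values.
final class MetadataKeys {
    private MetadataKeys() {}

    static Map<String, String> toRunProvenance(Map<String, String> customMetadata, String reservedKey) {
        Map<String, String> result = new LinkedHashMap<>();
        for (Map.Entry<String, String> e : customMetadata.entrySet()) {
            // Skip the field that carries serialized provenance; copy everything else
            if (!e.getKey().equals(reservedKey)) {
                result.put("model-metadata-" + e.getKey(), e.getValue());
            }
        }
        return Collections.unmodifiableMap(result);
    }
}
```

The prefix keeps user-supplied metadata keys from colliding with the fixed keys such as "model-producer" or "model-version" that are written into the same map.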

Example 2 with OrtException

Use of ai.onnxruntime.OrtException in project MasterProject by Alexsogge.

From the class HandWashDetection, method initModel.

public void initModel() {
    SharedPreferences configs = context.getSharedPreferences(context.getString(R.string.configs), Context.MODE_PRIVATE);
    String currentModelName = configs.getString(context.getApplicationContext().getString(R.string.val_current_tf_model), "base_model.tflite");
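    String[] currentModelNameParts = currentModelName.split("\\.(?=[^\\.]+$)");
    // A name without an extension yields a single part, so guard the index
    String modelExtension = currentModelNameParts.length > 1 ? currentModelNameParts[1] : "";
    Log.d("pred", "try load model " + currentModelName + " with extension " + modelExtension);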
    if (modelExtension.equals("tflite")) {
        useONNXModel = false;
        try {
            loadTFModel();
        } catch (IOException e) {
            e.printStackTrace();
            makeToast(context.getString(R.string.toast_couldnt_load_tf));
        }
    } else {
        try {
            loadORTModel();
            useONNXModel = true;
        } catch (OrtException | IOException e) {
            e.printStackTrace();
        }
    }
    initialized = true;
    makeToast(this.context.getString(R.string.toast_use_dl_tf) + loadedModelName);
    labelList = new ArrayList<>();
    labelList.add("0");
    labelList.add("1");
}
Also used : SharedPreferences(android.content.SharedPreferences) IOException(java.io.IOException) OrtException(ai.onnxruntime.OrtException)
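initModel chooses the runtime from the file extension by splitting the model name at its last dot with the regex \\.(?=[^\\.]+$). That pattern only matches a dot followed by dot-free text up to the end of the string, so "model.v2.onnx" splits into "model.v2" and "onnx". If the stored name has no extension, the split yields a single element, and reading index 1 unguarded throws ArrayIndexOutOfBoundsException. The sketch below compares the regex split with a lastIndexOf-based variant that returns an empty string instead; ModelExtensions and its methods are hypothetical names.

```java
// Hypothetical helpers comparing two ways to take a file name's extension.
final class ModelExtensions {
    private ModelExtensions() {}

    /** Regex split as used in initModel: splits only at a dot followed by dot-free text to the end. */
    static String[] splitAtLastDot(String fileName) {
        return fileName.split("\\.(?=[^\\.]+$)");
    }

    /** Variant that returns "" when there is no extension instead of failing on a later index. */
    static String extensionOf(String fileName) {
        int dot = fileName.lastIndexOf('.');
        return (dot < 0 || dot == fileName.length() - 1) ? "" : fileName.substring(dot + 1);
    }
}
```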

Example 3 with OrtException

Use of ai.onnxruntime.OrtException in project MasterProject by Alexsogge.

From the class HandWashDetection, method runSample.

private void runSample() throws OrtException {
    Log.d("pred", "Run model");
    // Dummy all-zero input with shape [1, 900]
    float[][] sourceArray = new float[1][900];
    // OnnxTensor and OrtSession.Result hold native memory, so close both via try-with-resources
    try (OnnxTensor tensor = OnnxTensor.createTensor(env, sourceArray)) {
        try (OrtSession.Result output = session.run(Collections.singletonMap("input", tensor))) {
            Log.d("pred", "Result size: " + output.size());
            Log.d("pred", "Result info: " + output.get(0).getInfo().toString());
            float[][] values = (float[][]) output.get(0).getValue();
            Log.d("pred", "Result values: " + values[0][0] + " | " + values[0][1]);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
Also used : OrtSession(ai.onnxruntime.OrtSession) OnnxTensor(ai.onnxruntime.OnnxTensor) OrtException(ai.onnxruntime.OrtException) JSONException(org.json.JSONException) IOException(java.io.IOException) FileNotFoundException(java.io.FileNotFoundException)
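OnnxTensor and OrtSession.Result both implement AutoCloseable and hold native memory that is released by close(), which is why they are best declared in a try-with-resources header: they are then closed automatically, in reverse order of declaration, even when the body throws. The JDK-only sketch below demonstrates that ordering with a hypothetical NativeHandle class standing in for the ONNX Runtime types; it does not call ONNX Runtime itself.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a native-backed resource such as OnnxTensor or OrtSession.Result.
final class NativeHandle implements AutoCloseable {
    private final String name;
    private final List<String> closeLog;

    NativeHandle(String name, List<String> closeLog) {
        this.name = name;
        this.closeLog = closeLog;
    }

    @Override
    public void close() {
        // Record each close so the order can be inspected afterwards
        closeLog.add(name);
    }

    /** Opens a "tensor" then a "result", optionally fails in the body, and returns the event log. */
    static List<String> runAndRecord(boolean failInBody) {
        List<String> log = new ArrayList<>();
        try (NativeHandle tensor = new NativeHandle("tensor", log);
             NativeHandle result = new NativeHandle("result", log)) {
            if (failInBody) {
                throw new IllegalStateException("simulated inference failure");
            }
        } catch (IllegalStateException e) {
            // Both handles have already been closed when this catch block runs
            log.add("caught");
        }
        return log;
    }
}
```

The result handle is closed before the tensor it was computed from, matching the order in which the real objects would be released.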

Example 4 with OrtException

Use of ai.onnxruntime.OrtException in project MasterProject by Alexsogge.

From the class HandWashDetection, method RunONNXInference.

public float[][] RunONNXInference(float[][][] window) {
    // Input buffer of shape [1, frameSize, 6]; assumes requiredSensorsDimensions sums to 6
    float[][][] frameBuffer = new float[1][frameSize][6];
    // Sensors the model requires but that are not active are filled with dummy zeros
    for (int x = 0; x < frameSize; x++) {
        int axesIndex = 0;
        for (int i = 0; i < requiredSensors.length; i++) {
            int activeSensorIndex = getActiveSensorIndexOfType(requiredSensors[i]);
            // index == -1 means no real values: insert a dummy, else copy from the window
            for (int axes = 0; axes < requiredSensorsDimensions[i]; axes++) {
                if (activeSensorIndex == -1) {
                    frameBuffer[0][x][axesIndex] = 0;
                } else {
                    frameBuffer[0][x][axesIndex] = window[activeSensorIndex][x][axes];
                }
                axesIndex++;
            }
        }
    }
    // All-zero fallback returned if inference fails
    float[][] output = new float[1][2];
    // Close the tensor and the result to release their native memory
    try (OnnxTensor tensor = OnnxTensor.createTensor(env, frameBuffer);
         OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) {
        // getValue() copies the data into a Java array, so it stays valid after close
        output = (float[][]) result.get(0).getValue();
    } catch (OrtException e) {
        e.printStackTrace();
    }
    return output;
}
Also used : OrtSession(ai.onnxruntime.OrtSession) OnnxTensor(ai.onnxruntime.OnnxTensor) OrtException(ai.onnxruntime.OrtException)
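The loop in RunONNXInference lays out one row per time step, concatenating the axes of each required sensor and writing zeros for sensors that are not active. The sketch below restates that layout as a pure function so it can be tested off-device. It assumes the active-sensor indices are passed in as an array (standing in for getActiveSensorIndexOfType) and sizes each row from the dimension counts instead of hard-coding 6; FrameBuilder and buildFrame are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical pure-function version of the frame-buffer loop in RunONNXInference.
final class FrameBuilder {
    private FrameBuilder() {}

    /**
     * @param window      sensor data indexed as [sensor][timeStep][axis]
     * @param activeIndex for each required sensor, its index in window, or -1 if not active
     * @param dims        number of axes contributed by each required sensor
     * @param frameSize   number of time steps to copy
     * @return a [1][frameSize][sum(dims)] buffer, zero for inactive sensors
     */
    static float[][][] buildFrame(float[][][] window, int[] activeIndex, int[] dims, int frameSize) {
        int totalAxes = Arrays.stream(dims).sum();
        // New float arrays are zero-initialized, so inactive sensors need no explicit write
        float[][][] frame = new float[1][frameSize][totalAxes];
        for (int t = 0; t < frameSize; t++) {
            int column = 0;
            for (int s = 0; s < dims.length; s++) {
                for (int axis = 0; axis < dims[s]; axis++) {
                    if (activeIndex[s] != -1) {
                        frame[0][t][column] = window[activeIndex[s]][t][axis];
                    }
                    column++;
                }
            }
        }
        return frame;
    }
}
```

If the dimension counts do not add up to the width the model was exported with, the resulting tensor shape will not match the model's input, so a size check before calling session.run is worth considering.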

Example 5 with OrtException

Use of ai.onnxruntime.OrtException in project djl by deepjavalibrary.

From the class OrtModel, method load.

/**
 * {@inheritDoc}
 */
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
    setModelDir(modelPath);
    if (block != null) {
        throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
    }
    Path modelFile = findModelFile(prefix);
    if (modelFile == null) {
        modelFile = findModelFile(modelDir.toFile().getName());
        if (modelFile == null) {
            throw new FileNotFoundException(".onnx file not found in: " + modelPath);
        }
    }
    try {
        SessionOptions ortOptions = getSessionOptions(options);
        OrtSession session = env.createSession(modelFile.toString(), ortOptions);
        block = new OrtSymbolBlock(session, (OrtNDManager) manager);
    } catch (OrtException e) {
        throw new MalformedModelException("ONNX Model cannot be loaded", e);
    }
}
Also used : Path(java.nio.file.Path) OrtSession(ai.onnxruntime.OrtSession) SessionOptions(ai.onnxruntime.OrtSession.SessionOptions) FileNotFoundException(java.io.FileNotFoundException) MalformedModelException(ai.djl.MalformedModelException) OrtException(ai.onnxruntime.OrtException)
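load first looks for a model file named after prefix and, if none is found, retries with the model directory's own name before giving up with FileNotFoundException. A minimal JDK-only sketch of that two-step lookup follows. It assumes a candidate is simply "<name>.onnx" directly inside the directory; DJL's real findModelFile may check other candidates, and ModelFileLocator and locate are hypothetical names.

```java
import java.io.FileNotFoundException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical sketch of the prefix-then-directory-name fallback in OrtModel.load.
final class ModelFileLocator {
    private ModelFileLocator() {}

    /** Assumption: a candidate file is "<name>.onnx" directly inside modelDir. */
    private static Path find(Path modelDir, String name) {
        if (name == null) {
            return null;
        }
        Path candidate = modelDir.resolve(name + ".onnx");
        return Files.isRegularFile(candidate) ? candidate : null;
    }

    static Path locate(Path modelDir, String prefix) throws FileNotFoundException {
        Path modelFile = find(modelDir, prefix);
        if (modelFile == null) {
            // Fall back to the directory's own name, e.g. resnet/resnet.onnx
            modelFile = find(modelDir, modelDir.getFileName().toString());
            if (modelFile == null) {
                throw new FileNotFoundException(".onnx file not found in: " + modelDir);
            }
        }
        return modelFile;
    }
}
```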

Aggregations

OrtException (ai.onnxruntime.OrtException): 19
OnnxTensor (ai.onnxruntime.OnnxTensor): 10
OrtSession (ai.onnxruntime.OrtSession): 9
IOException (java.io.IOException): 5
HashMap (java.util.HashMap): 5
ArrayList (java.util.ArrayList): 4
EngineException (ai.djl.engine.EngineException): 3
OnnxValue (ai.onnxruntime.OnnxValue): 3
OrtEnvironment (ai.onnxruntime.OrtEnvironment): 3
MalformedModelException (ai.djl.MalformedModelException): 2
OnnxModelMetadata (ai.onnxruntime.OnnxModelMetadata): 2
SessionOptions (ai.onnxruntime.OrtSession.SessionOptions): 2
Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance): 2
LongProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance): 2
StringProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance): 2
FileNotFoundException (java.io.FileNotFoundException): 2
Path (java.nio.file.Path): 2
Map (java.util.Map): 2
ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 2
Test (org.junit.jupiter.api.Test): 2