Use of ai.onnxruntime.OrtException in project tribuo by oracle.
The class ONNXExternalModel, method createOnnxModel.
/**
* Creates an {@code ONNXExternalModel} by loading the model from disk.
*
* @param factory The output factory to use.
* @param featureMapping The feature mapping between Tribuo names and ONNX integer ids.
* @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids.
* @param featureTransformer The transformation function for the features.
* @param outputTransformer The transformation function for the outputs.
* @param opts The session options for the ONNX model.
* @param path The model path.
* @param inputName The name of the input node.
* @param <T> The type of the output.
* @return An ONNXExternalModel ready to score new inputs.
* @throws OrtException If the onnx-runtime native library call failed.
*/
public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions opts, Path path, String inputName) throws OrtException {
    try {
        byte[] modelArray = Files.readAllBytes(path);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        HashMap<String, Provenance> runProvenance = new HashMap<>();
        runProvenance.put("input-name", new StringProvenance("input-name", inputName));
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession session = env.createSession(modelArray)) {
            OnnxModelMetadata metadata = session.getMetadata();
            runProvenance.put("model-producer", new StringProvenance("model-producer", metadata.getProducerName()));
            runProvenance.put("model-domain", new StringProvenance("model-domain", metadata.getDomain()));
            runProvenance.put("model-description", new StringProvenance("model-description", metadata.getDescription()));
            runProvenance.put("model-graphname", new StringProvenance("model-graphname", metadata.getGraphName()));
            runProvenance.put("model-version", new LongProvenance("model-version", metadata.getVersion()));
            for (Map.Entry<String, String> e : metadata.getCustomMetadata().entrySet()) {
                if (!e.getKey().equals(ONNXExportable.PROVENANCE_METADATA_FIELD)) {
                    String keyName = "model-metadata-" + e.getKey();
                    runProvenance.put(keyName, new StringProvenance(keyName, e.getValue()));
                }
            }
        } catch (OrtException e) {
            throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
        }
        ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(), now, datasetProvenance, trainerProvenance, runProvenance);
        return new ONNXExternalModel<>("external-model", provenance, featureMap, outputInfo, featureMapping, modelArray, opts, inputName, featureTransformer, outputTransformer);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
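For context, a minimal sketch of how this factory method might be called for a binary Label classifier. The feature/output index mappings, the "float_input" node name, and the model path are illustrative assumptions about the exported model, not part of the snippet above; DenseTransformer and LabelTransformer are Tribuo's ONNX interop implementations of ExampleTransformer and OutputTransformer.

import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.interop.onnx.DenseTransformer;
import org.tribuo.interop.onnx.LabelTransformer;
import org.tribuo.interop.onnx.ONNXExternalModel;

public class LoadOnnxClassifier {
    public static ONNXExternalModel<Label> load(Path modelPath) throws OrtException {
        LabelFactory factory = new LabelFactory();
        // Tribuo feature names -> ONNX input column indices (assumed ordering).
        Map<String, Integer> featureMapping = new HashMap<>();
        featureMapping.put("petal-length", 0);
        featureMapping.put("petal-width", 1);
        // Tribuo output labels -> ONNX output indices (assumed two-class model).
        Map<Label, Integer> outputMapping = new HashMap<>();
        outputMapping.put(new Label("negative"), 0);
        outputMapping.put(new Label("positive"), 1);
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        // "float_input" is an assumption about the exported graph's input node name.
        return ONNXExternalModel.createOnnxModel(factory, featureMapping, outputMapping,
                new DenseTransformer(), new LabelTransformer(),
                opts, modelPath, "float_input");
    }
}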
Use of ai.onnxruntime.OrtException in project MasterProject by Alexsogge.
The class HandWashDetection, method initModel.
public void initModel() {
    SharedPreferences configs = context.getSharedPreferences(context.getString(R.string.configs), Context.MODE_PRIVATE);
    String currentModelName = configs.getString(context.getApplicationContext().getString(R.string.val_current_tf_model), "base_model.tflite");
    // split the file name into base name and extension at the last '.'
    String[] currentModelNameParts = currentModelName.split("\\.(?=[^\\.]+$)");
    String modelExtension = currentModelNameParts[1];
    Log.d("pred", "try to load model " + currentModelName + " with extension " + modelExtension);
    if (modelExtension.equals("tflite")) {
        useONNXModel = false;
        try {
            loadTFModel();
        } catch (IOException e) {
            e.printStackTrace();
            makeToast(context.getString(R.string.toast_couldnt_load_tf));
        }
    } else {
        try {
            loadORTModel();
            useONNXModel = true;
        } catch (OrtException | IOException e) {
            e.printStackTrace();
        }
    }
    initialized = true;
    makeToast(this.context.getString(R.string.toast_use_dl_tf) + loadedModelName);
    labelList = new ArrayList<String>();
    labelList.add("0");
    labelList.add("1");
}
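The loadORTModel() helper is not shown in this snippet. Below is a hypothetical reconstruction of what it might look like with the ONNX Runtime Java API: the file name, the class fields (env, session, loadedModelName, context), and the storage location are all assumptions.

// Hypothetical sketch of loadORTModel(); not the project's actual implementation.
// Assumes the .onnx file sits in the app's files directory under loadedModelName,
// and that env/session are class fields used later by runSample()/RunONNXInference().
private OrtEnvironment env;
private OrtSession session;

private void loadORTModel() throws OrtException, IOException {
    File modelFile = new File(context.getFilesDir(), loadedModelName);
    // java.nio.file is available on Android API 26+.
    byte[] modelBytes = Files.readAllBytes(modelFile.toPath());
    env = OrtEnvironment.getEnvironment();
    // createSession parses the ONNX graph and prepares it for inference.
    session = env.createSession(modelBytes, new OrtSession.SessionOptions());
}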
Use of ai.onnxruntime.OrtException in project MasterProject by Alexsogge.
The class HandWashDetection, method runSample.
private void runSample() throws OrtException {
    Log.d("pred", "Run model");
    float[][] sourceArray = new float[1][900];
    OnnxTensor tensor = OnnxTensor.createTensor(env, sourceArray);
    try {
        OrtSession.Result output = session.run(Collections.singletonMap("input", tensor));
        Log.d("pred", "Result size: " + output.size());
        Log.d("pred", "Result info: " + output.get(0).getInfo().toString());
        float[][] values = (float[][]) output.get(0).getValue();
        Log.d("pred", "Result values: " + values[0][0] + " | " + values[0][1]);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
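OnnxTensor and OrtSession.Result both hold native memory and implement AutoCloseable, but the snippet never closes them. A sketch of the same call with try-with-resources, reusing the env and session fields assumed above:

// Variant of runSample() that releases the native buffers held by the tensor
// and the result; behaviour is otherwise the same as the snippet above.
private void runSampleClosed() throws OrtException {
    float[][] sourceArray = new float[1][900];
    try (OnnxTensor tensor = OnnxTensor.createTensor(env, sourceArray);
         OrtSession.Result output = session.run(Collections.singletonMap("input", tensor))) {
        float[][] values = (float[][]) output.get(0).getValue();
        Log.d("pred", "Result values: " + values[0][0] + " | " + values[0][1]);
    }
}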
Use of ai.onnxruntime.OrtException in project MasterProject by Alexsogge.
The class HandWashDetection, method RunONNXInference.
public float[][] RunONNXInference(float[][][] window) {
    // allocate the input buffer in the shape the model expects: [1][frameSize][6]
    float[][][] frameBuffer = new float[1][frameSize][6];
    int bufferPos = 0;
    // fill the frameBuffer; sensors that are not present get dummy values
    for (int x = 0; x < frameSize; x++) {
        int axesIndex = 0;
        for (int i = 0; i < requiredSensors.length; i++) {
            int activeSensorIndex = getActiveSensorIndexOfType(requiredSensors[i]);
            // if the index is -1 there are no actual values -> insert a dummy, otherwise copy from the window
            for (int axes = 0; axes < requiredSensorsDimensions[i]; axes++) {
                if (activeSensorIndex == -1) {
                    frameBuffer[0][x][axesIndex] = 0;
                } else {
                    frameBuffer[0][x][axesIndex] = window[activeSensorIndex][x][axes];
                }
                bufferPos++;
                axesIndex++;
            }
        }
    }
    float[][] output = new float[1][2];
    try {
        OnnxTensor tensor = OnnxTensor.createTensor(env, frameBuffer);
        OrtSession.Result result = session.run(Collections.singletonMap("input", tensor));
        output = (float[][]) result.get(0).getValue();
    } catch (OrtException e) {
        e.printStackTrace();
    }
    return output;
}
Use of ai.onnxruntime.OrtException in project djl by deepjavalibrary.
The class OrtModel, method load.
/** {@inheritDoc} */
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
    setModelDir(modelPath);
    if (block != null) {
        throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
    }
    Path modelFile = findModelFile(prefix);
    if (modelFile == null) {
        modelFile = findModelFile(modelDir.toFile().getName());
        if (modelFile == null) {
            throw new FileNotFoundException(".onnx file not found in: " + modelPath);
        }
    }
    try {
        SessionOptions ortOptions = getSessionOptions(options);
        OrtSession session = env.createSession(modelFile.toString(), ortOptions);
        block = new OrtSymbolBlock(session, (OrtNDManager) manager);
    } catch (OrtException e) {
        throw new MalformedModelException("ONNX Model cannot be loaded", e);
    }
}
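Applications normally reach this load method through DJL's model zoo API rather than calling it directly. A hedged sketch of that path, assuming a recent DJL version; the model file name and input shape are placeholders.

import java.nio.file.Paths;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.NoopTranslator;

public class OrtLoadExample {
    public static void main(String[] args) throws Exception {
        // Selecting the "OnnxRuntime" engine routes loading through OrtModel.load above.
        Criteria<NDList, NDList> criteria = Criteria.builder()
                .setTypes(NDList.class, NDList.class)
                .optModelPath(Paths.get("build/model/mnist.onnx")) // assumed path
                .optEngine("OnnxRuntime")
                .optTranslator(new NoopTranslator())
                .build();
        try (ZooModel<NDList, NDList> model = criteria.loadModel();
             Predictor<NDList, NDList> predictor = model.newPredictor();
             NDManager manager = NDManager.newBaseManager()) {
            // assumed 1x1x28x28 input; adjust to the actual graph input shape
            NDList input = new NDList(manager.ones(new Shape(1, 1, 28, 28)));
            NDList output = predictor.predict(input);
            System.out.println(output.singletonOrThrow());
        }
    }
}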