Search in sources:

Example 1 with LongProvenance

Use of com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance in project tribuo by oracle.

In the class ONNXExternalModel, method createOnnxModel:

/**
 * Creates an {@code ONNXExternalModel} by loading the model from disk.
 * <p>
 * The model file is read into memory once. A temporary ONNX Runtime session is then opened
 * on those bytes to read the model metadata. It uses the supplied session options, or the
 * defaults if {@code opts} is null. The metadata is recorded in the model's run provenance:
 * <ul>
 *     <li>the input name;</li>
 *     <li>the producer, domain, description, graph name and version;</li>
 *     <li>every custom metadata entry except {@link ONNXExportable#PROVENANCE_METADATA_FIELD}.
 *     Each of these is stored under the key {@code "model-metadata-" + key}.</li>
 * </ul>
 *
 * @param factory            The output factory to use.
 * @param featureMapping     The feature mapping between Tribuo names and ONNX integer ids.
 * @param outputMapping      The output mapping between Tribuo outputs and ONNX integer ids.
 * @param featureTransformer The transformation function for the features.
 * @param outputTransformer  The transformation function for the outputs.
 * @param opts               The session options for the ONNX model. They are also used when
 *                           opening the model to read its metadata.
 * @param path               The model path.
 * @param inputName          The name of the input node.
 * @param <T>                The type of the output.
 * @return An ONNXExternalModel ready to score new inputs.
 * @throws OrtException If the onnx-runtime native library call failed while constructing the model.
 * @throws IllegalArgumentException If the model file could not be read, or if the model could not
 *                                  be opened to read its metadata.
 */
public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions opts, Path path, String inputName) throws OrtException {
    try {
        byte[] modelArray = Files.readAllBytes(path);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        HashMap<String, Provenance> runProvenance = new HashMap<>();
        runProvenance.put("input-name", new StringProvenance("input-name", inputName));
        // Open the metadata session with the caller's options. A model that needs them (e.g. one
        // relying on a custom-op library registered on opts) otherwise fails to load here, even
        // though it would load fine for scoring.
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession session = opts == null ? env.createSession(modelArray) : env.createSession(modelArray, opts)) {
            OnnxModelMetadata metadata = session.getMetadata();
            runProvenance.put("model-producer", new StringProvenance("model-producer", metadata.getProducerName()));
            runProvenance.put("model-domain", new StringProvenance("model-domain", metadata.getDomain()));
            runProvenance.put("model-description", new StringProvenance("model-description", metadata.getDescription()));
            runProvenance.put("model-graphname", new StringProvenance("model-graphname", metadata.getGraphName()));
            runProvenance.put("model-version", new LongProvenance("model-version", metadata.getVersion()));
            for (Map.Entry<String, String> e : metadata.getCustomMetadata().entrySet()) {
                // Skip the Tribuo provenance metadata field; every other custom entry is copied
                // under a "model-metadata-" prefix.
                if (!e.getKey().equals(ONNXExportable.PROVENANCE_METADATA_FIELD)) {
                    String keyName = "model-metadata-" + e.getKey();
                    runProvenance.put(keyName, new StringProvenance(keyName, e.getValue()));
                }
            }
        } catch (OrtException e) {
            throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
        }
        ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(), now, datasetProvenance, trainerProvenance, runProvenance);
        return new ONNXExternalModel<>("external-model", provenance, featureMap, outputInfo, featureMapping, modelArray, opts, inputName, featureTransformer, outputTransformer);
    } catch (IOException e) {
        // Covers both Files.readAllBytes and toURL (MalformedURLException is an IOException).
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
Also used: ModelProvenance (org.tribuo.provenance.ModelProvenance), DatasetProvenance (org.tribuo.provenance.DatasetProvenance), ExternalTrainerProvenance (org.tribuo.interop.ExternalTrainerProvenance), Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance), StringProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance), LongProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance), ExternalDatasetProvenance (org.tribuo.interop.ExternalDatasetProvenance), HashMap (java.util.HashMap), OrtException (ai.onnxruntime.OrtException), URL (java.net.URL), OrtEnvironment (ai.onnxruntime.OrtEnvironment), OnnxModelMetadata (ai.onnxruntime.OnnxModelMetadata), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), IOException (java.io.IOException), OrtSession (ai.onnxruntime.OrtSession), OffsetDateTime (java.time.OffsetDateTime), Map (java.util.Map)

Aggregations

OnnxModelMetadata (ai.onnxruntime.OnnxModelMetadata): 1, OrtEnvironment (ai.onnxruntime.OrtEnvironment): 1, OrtException (ai.onnxruntime.OrtException): 1, OrtSession (ai.onnxruntime.OrtSession): 1, Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance): 1, LongProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance): 1, StringProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance): 1, IOException (java.io.IOException): 1, URL (java.net.URL): 1, OffsetDateTime (java.time.OffsetDateTime): 1, HashMap (java.util.HashMap): 1, Map (java.util.Map): 1, ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 1, ExternalDatasetProvenance (org.tribuo.interop.ExternalDatasetProvenance): 1, ExternalTrainerProvenance (org.tribuo.interop.ExternalTrainerProvenance): 1, DatasetProvenance (org.tribuo.provenance.DatasetProvenance): 1, ModelProvenance (org.tribuo.provenance.ModelProvenance): 1