
Example 1 with ExternalTrainerProvenance

Use of org.tribuo.interop.ExternalTrainerProvenance in the Tribuo project by Oracle.

From class ONNXExternalModel, method createOnnxModel.

/**
 * Creates an {@code ONNXExternalModel} by loading the model from disk.
 *
 * @param factory            The output factory to use.
 * @param featureMapping     The feature mapping between Tribuo names and ONNX integer ids.
 * @param outputMapping      The output mapping between Tribuo outputs and ONNX integer ids.
 * @param featureTransformer The transformation function for the features.
 * @param outputTransformer  The transformation function for the outputs.
 * @param opts               The session options for the ONNX model.
 * @param path               The model path.
 * @param inputName          The name of the input node.
 * @param <T>                The type of the output.
 * @return An ONNXExternalModel ready to score new inputs.
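 * @throws OrtException If the onnx-runtime native library call failed.
 * @throws IllegalArgumentException If the model file cannot be read or its metadata cannot be loaded.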
 */
public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions opts, Path path, String inputName) throws OrtException {
    try {
        byte[] modelArray = Files.readAllBytes(path);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        HashMap<String, Provenance> runProvenance = new HashMap<>();
        runProvenance.put("input-name", new StringProvenance("input-name", inputName));
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession session = env.createSession(modelArray)) {
            OnnxModelMetadata metadata = session.getMetadata();
            runProvenance.put("model-producer", new StringProvenance("model-producer", metadata.getProducerName()));
            runProvenance.put("model-domain", new StringProvenance("model-domain", metadata.getDomain()));
            runProvenance.put("model-description", new StringProvenance("model-description", metadata.getDescription()));
            runProvenance.put("model-graphname", new StringProvenance("model-graphname", metadata.getGraphName()));
            runProvenance.put("model-version", new LongProvenance("model-version", metadata.getVersion()));
            for (Map.Entry<String, String> e : metadata.getCustomMetadata().entrySet()) {
                if (!e.getKey().equals(ONNXExportable.PROVENANCE_METADATA_FIELD)) {
                    String keyName = "model-metadata-" + e.getKey();
                    runProvenance.put(keyName, new StringProvenance(keyName, e.getValue()));
                }
            }
        } catch (OrtException e) {
            throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
        }
        ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(), now, datasetProvenance, trainerProvenance, runProvenance);
        return new ONNXExternalModel<>("external-model", provenance, featureMap, outputInfo, featureMapping, modelArray, opts, inputName, featureTransformer, outputTransformer);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
Also used : ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) StringProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) LongProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance) ExternalDatasetProvenance(org.tribuo.interop.ExternalDatasetProvenance) HashMap(java.util.HashMap) OrtException(ai.onnxruntime.OrtException) URL(java.net.URL) OrtEnvironment(ai.onnxruntime.OrtEnvironment) OnnxModelMetadata(ai.onnxruntime.OnnxModelMetadata) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) IOException(java.io.IOException) OrtSession(ai.onnxruntime.OrtSession) OffsetDateTime(java.time.OffsetDateTime) Map(java.util.Map)
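The metadata loop in Example 1 copies every custom ONNX metadata entry into the run provenance under a `model-metadata-` prefix, skipping the one entry reserved for Tribuo's own serialized provenance. The sketch below isolates just that key-mapping step with plain `String` maps so it runs without onnxruntime or Tribuo on the classpath. The class `MetadataKeySketch` and method `toRunProvenanceKeys` are hypothetical, and the reserved key is passed in as a parameter rather than read from `ONNXExportable.PROVENANCE_METADATA_FIELD`.

```java
import java.util.Map;
import java.util.TreeMap;

/**
 * Hypothetical sketch of the custom-metadata loop in createOnnxModel,
 * using plain strings in place of OLCUT StringProvenance objects.
 */
final class MetadataKeySketch {
    private MetadataKeySketch() {}

    /**
     * Copies each custom metadata entry to a "model-metadata-" prefixed key,
     * skipping the key assumed to hold Tribuo's serialized provenance.
     *
     * @param customMetadata The custom metadata read from the ONNX model.
     * @param reservedKey    The metadata key to skip.
     * @return A sorted map from prefixed key to metadata value.
     */
    static Map<String, String> toRunProvenanceKeys(Map<String, String> customMetadata, String reservedKey) {
        Map<String, String> out = new TreeMap<>();
        for (Map.Entry<String, String> e : customMetadata.entrySet()) {
            if (!e.getKey().equals(reservedKey)) {
                out.put("model-metadata-" + e.getKey(), e.getValue());
            }
        }
        return out;
    }
}
```

Prefixing the keys keeps user-supplied metadata from colliding with the fixed entries such as `model-producer` or `model-version`.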

Example 2 with ExternalTrainerProvenance

Use of org.tribuo.interop.ExternalTrainerProvenance in the Tribuo project by Oracle.

From class TensorFlowSavedModelExternalModel, method createTensorflowModel.

/**
 * Creates a {@code TensorFlowSavedModelExternalModel} by loading a {@code SavedModelBundle}.
 * <p>
 * Throws {@link IllegalArgumentException} if the model bundle could not be loaded.
 * @param factory The output factory.
 * @param featureMapping The feature mapping between Tribuo's names and the TF integer ids.
 * @param outputMapping The output mapping between Tribuo's outputs and the TF integer ids.
 * @param outputName The name of the output tensor.
 * @param featureConverter The feature transformation function.
 * @param outputConverter The output transformation function.
 * @param bundleDirectory The path to load the saved model bundle from.
 * @param <T> The type of the output.
 * @return The TF model wrapped in a Tribuo {@link ExternalModel}.
 */
public static <T extends Output<T>> TensorFlowSavedModelExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String bundleDirectory) {
    try {
        Path path = Paths.get(bundleDirectory);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        ModelProvenance provenance = new ModelProvenance(TensorFlowSavedModelExternalModel.class.getName(), now, datasetProvenance, trainerProvenance);
        return new TensorFlowSavedModelExternalModel<>("tf-saved-model-bundle", provenance, featureMap, outputInfo, featureMapping, bundleDirectory, outputName, featureConverter, outputConverter);
    } catch (IOException | TensorFlowException e) {
        throw new IllegalArgumentException("Unable to load model from path " + bundleDirectory, e);
    }
}
Also used : Path(java.nio.file.Path) ExternalDatasetProvenance(org.tribuo.interop.ExternalDatasetProvenance) ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) IOException(java.io.IOException) URL(java.net.URL) OffsetDateTime(java.time.OffsetDateTime) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) TensorFlowException(org.tensorflow.exceptions.TensorFlowException)
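Each factory takes a `featureMapping` from Tribuo feature names to the integer ids the external model expects. One plausible way a converter can use that mapping, assuming the model consumes a single dense vector and the ids run from 0 to n-1, is to write each known feature's value at its mapped index and drop names the mapping does not contain. `DenseFeatureSketch` and `toDense` are hypothetical illustrations, not Tribuo's `FeatureConverter` or `ExampleTransformer` implementations.

```java
import java.util.Map;

/** Hypothetical sketch: turns named feature values into a dense vector via a feature mapping. */
final class DenseFeatureSketch {
    private DenseFeatureSketch() {}

    /**
     * @param featureMapping Tribuo feature name to external-model integer id (assumed 0..n-1).
     * @param features       Feature name to value for a single example.
     * @return A dense array of length featureMapping.size(); unmapped features are dropped.
     */
    static float[] toDense(Map<String, Integer> featureMapping, Map<String, Double> features) {
        float[] dense = new float[featureMapping.size()];
        for (Map.Entry<String, Double> f : features.entrySet()) {
            Integer id = featureMapping.get(f.getKey());
            if (id != null) { // skip features the external model has no input slot for
                dense[id] = f.getValue().floatValue();
            }
        }
        return dense;
    }
}
```

Positions with no feature present stay at 0.0, which is the usual convention for sparse inputs.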

Example 3 with ExternalTrainerProvenance

Use of org.tribuo.interop.ExternalTrainerProvenance in the Tribuo project by Oracle.

From class OCIModel, method createOCIModel.

/**
 * Creates an {@code OCIModel} by wrapping an OCI DS Model Deployment endpoint.
 * <p>
 * Uses the endpointURL as the value to hash for the trainer provenance.
 *
 * @param factory         The output factory to use.
 * @param featureMapping  The feature mapping between Tribuo names and model integer ids.
 * @param outputMapping   The output mapping between Tribuo outputs and model integer ids.
 * @param configFile      The OCI configuration file; if null, the default file is used.
 * @param profileName     The profile name in the OCI configuration file; if null, the default profile is used.
 * @param endpointURL     The endpoint URL.
 * @param outputConverter The converter for the specified output type.
 * @param <T> The output type.
 * @return An OCIModel ready to score new inputs.
 */
public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, Path configFile, String profileName, String endpointURL, OCIOutputConverter<T> outputConverter) {
    try {
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(endpointURL.getBytes(StandardCharsets.UTF_8));
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        // "https://<host>/<deployment-id>/..." splits into ["https:", "", "<host>", "<deployment-id>", ...]
        String[] endpoint = endpointURL.split("/");
        String domain = "https://" + endpoint[2] + "/";
        String modelDeploymentId = endpoint[3];
        HashMap<String, Provenance> runProvenance = new HashMap<>();
        runProvenance.put("configFile", new FileProvenance("configFile", configFile));
        runProvenance.put("endpointURL", new StringProvenance("endpointURL", endpointURL));
        runProvenance.put("modelDeploymentId", new StringProvenance("modelDeploymentId", modelDeploymentId));
        ModelProvenance provenance = new ModelProvenance(OCIModel.class.getName(), now, datasetProvenance, trainerProvenance, runProvenance);
        return new OCIModel<T>("oci-ds-model", provenance, featureMap, outputInfo, featureMapping, configFile, profileName, domain, modelDeploymentId, outputConverter);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load configuration from path " + configFile, e);
    }
}
Also used : ExternalDatasetProvenance(org.tribuo.interop.ExternalDatasetProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) FileProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.FileProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) StringProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) HashMap(java.util.HashMap) IOException(java.io.IOException) OffsetDateTime(java.time.OffsetDateTime) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap)
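`createOCIModel` derives the domain and deployment id from fixed positions in `endpointURL.split("/")`, so an endpoint with fewer than four `/`-separated parts (for example `https://example.com`) fails with an `ArrayIndexOutOfBoundsException` rather than a descriptive error. The sketch below is a hypothetical, more defensive variant built on `java.net.URI`; it yields the same domain and id for a well-formed URL and keeps the original's hard-coded `https` scheme. `EndpointSketch` and the `example.com` URL are placeholders, not part of Tribuo or OCI.

```java
import java.net.URI;

/** Hypothetical sketch of the endpoint parsing in createOCIModel, with explicit validation. */
final class EndpointSketch {
    private EndpointSketch() {}

    /** The two pieces createOCIModel derives from the endpoint URL. */
    static final class Endpoint {
        final String domain;
        final String modelDeploymentId;

        Endpoint(String domain, String modelDeploymentId) {
            this.domain = domain;
            this.modelDeploymentId = modelDeploymentId;
        }
    }

    static Endpoint parse(String endpointURL) {
        URI uri = URI.create(endpointURL); // throws IllegalArgumentException on malformed syntax
        String authority = uri.getRawAuthority(); // host, plus port if present
        String path = uri.getRawPath();
        // "/<deployment-id>/predict".split("/") gives ["", "<deployment-id>", "predict"]
        String[] segments = path == null ? new String[0] : path.split("/");
        if (authority == null || segments.length < 2 || segments[1].isEmpty()) {
            throw new IllegalArgumentException("Expected https://<host>/<deployment-id>/..., found " + endpointURL);
        }
        // Like the original, the domain is always rebuilt with an https scheme.
        return new Endpoint("https://" + authority + "/", segments[1]);
    }
}
```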

Example 4 with ExternalTrainerProvenance

Use of org.tribuo.interop.ExternalTrainerProvenance in the Tribuo project by Oracle.

From class TensorFlowFrozenExternalModel, method createTensorflowModel.

/**
 * Creates a {@code TensorFlowFrozenExternalModel} by loading a frozen graph.
 * @param factory The output factory.
 * @param featureMapping The feature mapping between Tribuo's names and the TF integer ids.
 * @param outputMapping The output mapping between Tribuo's outputs and the TF integer ids.
 * @param inputName The name of the input placeholder.
 * @param outputName The name of the output tensor.
 * @param featureConverter The feature transformation function.
 * @param outputConverter The output transformation function.
 * @param filename The filename to load the graph from.
 * @param <T> The type of the output.
 * @return The TF model wrapped in a Tribuo ExternalModel.
 */
public static <T extends Output<T>> TensorFlowFrozenExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, String inputName, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String filename) {
    try {
        Path path = Paths.get(filename);
        byte[] model = Files.readAllBytes(path);
        Graph graph = new Graph();
        graph.importGraphDef(GraphDef.parseFrom(model));
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        ModelProvenance provenance = new ModelProvenance(TensorFlowFrozenExternalModel.class.getName(), now, datasetProvenance, trainerProvenance);
        return new TensorFlowFrozenExternalModel<>("tf-frozen-graph", provenance, featureMap, outputInfo, featureMapping, graph, inputName, outputName, featureConverter, outputConverter);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + filename, e);
    }
}
Also used : Path(java.nio.file.Path) ExternalDatasetProvenance(org.tribuo.interop.ExternalDatasetProvenance) ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) IOException(java.io.IOException) URL(java.net.URL) Graph(org.tensorflow.Graph) OffsetDateTime(java.time.OffsetDateTime) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap)
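The file-based factories share one pattern: read the model bytes, turn the `Path` into a `file:` URL that becomes the trainer provenance location, and rethrow any checked `IOException` (including the `MalformedURLException` that `toURL()` can raise) as an `IllegalArgumentException` naming the path. The sketch below shows only that step with the standard library; `ModelFileSketch` and its `Loaded` holder are hypothetical names, and the framework-specific parsing (ONNX session, TF `GraphDef`) is deliberately left out.

```java
import java.io.IOException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical sketch of the load-and-locate step shared by the file-based factories. */
final class ModelFileSketch {
    private ModelFileSketch() {}

    /** The raw model bytes plus the file: URL recorded as the provenance location. */
    static final class Loaded {
        final byte[] bytes;
        final URL location;

        Loaded(byte[] bytes, URL location) {
            this.bytes = bytes;
            this.location = location;
        }
    }

    static Loaded load(String filename) {
        Path path = Paths.get(filename);
        try {
            byte[] bytes = Files.readAllBytes(path);
            URL location = path.toUri().toURL(); // a file: URL for the provenance record
            return new Loaded(bytes, location);
        } catch (IOException e) {
            // Same convention as the factories: surface I/O failures as unchecked exceptions.
            throw new IllegalArgumentException("Unable to load model from path " + filename, e);
        }
    }
}
```

Wrapping in `IllegalArgumentException` keeps the factory signatures free of checked exceptions, at the cost of callers needing to know to catch it.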

Example 5 with ExternalTrainerProvenance

Use of org.tribuo.interop.ExternalTrainerProvenance in the Tribuo project by Oracle.

From class XGBoostExternalModel, method createXGBoostModel.

/**
 * Creates an {@code XGBoostExternalModel} from the supplied model on disk.
 * @param factory The output factory to use.
 * @param featureMapping The feature mapping between Tribuo names and XGBoost integer ids.
 * @param outputMapping The output mapping between Tribuo outputs and XGBoost integer ids.
 * @param outputFunc The XGBoostOutputConverter function for the output type.
 * @param path The path to the model on disk.
 * @param <T> The type of the output.
 * @return An XGBoostExternalModel ready to score new inputs.
 */
public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, XGBoostOutputConverter<T> outputFunc, Path path) {
    // try-with-resources closes the model stream once XGBoost has read it.
    try (InputStream modelStream = Files.newInputStream(path)) {
        Booster model = XGBoost.loadModel(modelStream);
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(path.toUri().toURL());
        return createXGBoostModel(factory, featureMapping, outputMapping, outputFunc, model, trainerProvenance, Collections.emptyMap());
    } catch (XGBoostError | IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
Also used : ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) Booster(ml.dmlc.xgboost4j.java.Booster) XGBoostError(ml.dmlc.xgboost4j.java.XGBoostError) IOException(java.io.IOException)
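Every example builds an `ExternalTrainerProvenance` from either a file location (`path.toUri().toURL()`) or, in the OCI case, the UTF-8 bytes of the endpoint URL, which that Javadoc describes as "the value to hash". The sketch below shows what such a fingerprint could look like, assuming a SHA-256 digest rendered as lowercase hex; the algorithm, the encoding, and the `ResourceHashSketch` class are all assumptions for illustration, not a description of Tribuo's internals.

```java
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Hypothetical sketch of fingerprinting a model resource, assuming SHA-256 and lowercase hex. */
final class ResourceHashSketch {
    private ResourceHashSketch() {}

    /** Hashes raw bytes, e.g. the UTF-8 bytes of an endpoint URL. */
    static String sha256Hex(byte[] data) {
        return toHex(newSha256().digest(data));
    }

    /** Streams the resource at the URL (e.g. a file: URL) through the digest. */
    static String sha256Hex(URL location) throws IOException {
        MessageDigest md = newSha256();
        try (InputStream is = location.openStream()) {
            byte[] buf = new byte[8192];
            int n;
            while ((n = is.read(buf)) != -1) {
                md.update(buf, 0, n);
            }
        }
        return toHex(md.digest());
    }

    private static MessageDigest newSha256() {
        try {
            return MessageDigest.getInstance("SHA-256");
        } catch (NoSuchAlgorithmException e) {
            // Every Java SE platform is required to provide SHA-256.
            throw new IllegalStateException(e);
        }
    }

    private static String toHex(byte[] digest) {
        StringBuilder sb = new StringBuilder(digest.length * 2);
        for (byte b : digest) {
            sb.append(String.format("%02x", b));
        }
        return sb.toString();
    }
}
```

Streaming the resource in fixed-size chunks avoids holding a large model file in memory just to fingerprint it.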

Aggregations

ExternalTrainerProvenance (org.tribuo.interop.ExternalTrainerProvenance) 6
IOException (java.io.IOException) 5
OffsetDateTime (java.time.OffsetDateTime) 4
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap) 4
ExternalDatasetProvenance (org.tribuo.interop.ExternalDatasetProvenance) 4
DatasetProvenance (org.tribuo.provenance.DatasetProvenance) 4
ModelProvenance (org.tribuo.provenance.ModelProvenance) 4
URL (java.net.URL) 3
Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance) 2
StringProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) 2
Path (java.nio.file.Path) 2
HashMap (java.util.HashMap) 2
Booster (ml.dmlc.xgboost4j.java.Booster) 2
XGBoostError (ml.dmlc.xgboost4j.java.XGBoostError) 2
OnnxModelMetadata (ai.onnxruntime.OnnxModelMetadata) 1
OrtEnvironment (ai.onnxruntime.OrtEnvironment) 1
OrtException (ai.onnxruntime.OrtException) 1
OrtSession (ai.onnxruntime.OrtSession) 1
FileProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.FileProvenance) 1
LongProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance) 1