Use of org.tribuo.interop.ExternalTrainerProvenance in project Tribuo by Oracle.
In class ONNXExternalModel, method createOnnxModel.
/**
 * Builds an {@code ONNXExternalModel} from an ONNX model file on disk.
 * <p>
 * The file is read fully into memory. A short-lived onnx-runtime session (opened with the
 * default session options) is created on those bytes only to copy the model metadata into the
 * run provenance. That metadata comprises the producer, domain, description, graph name,
 * version, and every custom metadata entry except
 * {@link ONNXExportable#PROVENANCE_METADATA_FIELD}. The same bytes and the caller's session
 * options are then handed to the returned model.
 * <p>
 * Throws {@link IllegalArgumentException} if the file cannot be read, or if onnx-runtime fails
 * to load the model or read its metadata.
 *
 * @param factory The output factory to use.
 * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids.
 * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids.
 * @param featureTransformer The transformation function for the features.
 * @param outputTransformer The transformation function for the outputs.
 * @param opts The session options for the ONNX model.
 * @param path The model path.
 * @param inputName The name of the input node.
 * @param <T> The type of the output.
 * @return An ONNXExternalModel ready to score new inputs.
 * @throws OrtException If the onnx-runtime native library call failed.
 */
public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions opts, Path path, String inputName) throws OrtException {
    try {
        final byte[] onnxBytes = Files.readAllBytes(path);
        final URL modelLocation = path.toUri().toURL();
        final ImmutableFeatureMap tribuoFeatures = ExternalModel.createFeatureMap(featureMapping.keySet());
        final ImmutableOutputInfo<T> tribuoOutputs = ExternalModel.createOutputInfo(factory, outputMapping);
        final OffsetDateTime creationTime = OffsetDateTime.now();
        final ExternalTrainerProvenance trainerProv = new ExternalTrainerProvenance(modelLocation);
        final DatasetProvenance dataProv = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());

        final HashMap<String, Provenance> instanceProv = new HashMap<>();
        instanceProv.put("input-name", new StringProvenance("input-name", inputName));

        // A throwaway session used only to harvest metadata; the returned model gets the raw
        // bytes plus the caller's options separately.
        try (OrtEnvironment ortEnv = OrtEnvironment.getEnvironment();
             OrtSession metadataSession = ortEnv.createSession(onnxBytes)) {
            final OnnxModelMetadata meta = metadataSession.getMetadata();
            instanceProv.put("model-producer", new StringProvenance("model-producer", meta.getProducerName()));
            instanceProv.put("model-domain", new StringProvenance("model-domain", meta.getDomain()));
            instanceProv.put("model-description", new StringProvenance("model-description", meta.getDescription()));
            instanceProv.put("model-graphname", new StringProvenance("model-graphname", meta.getGraphName()));
            instanceProv.put("model-version", new LongProvenance("model-version", meta.getVersion()));
            for (Map.Entry<String, String> custom : meta.getCustomMetadata().entrySet()) {
                final String customKey = custom.getKey();
                // NOTE(review): presumably skipped because this field holds Tribuo's own
                // serialized provenance rather than free-form metadata — confirm.
                if (customKey.equals(ONNXExportable.PROVENANCE_METADATA_FIELD)) {
                    continue;
                }
                final String provKey = "model-metadata-" + customKey;
                instanceProv.put(provKey, new StringProvenance(provKey, custom.getValue()));
            }
        } catch (OrtException e) {
            throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
        }

        final ModelProvenance modelProv = new ModelProvenance(ONNXExternalModel.class.getName(), creationTime, dataProv, trainerProv, instanceProv);
        return new ONNXExternalModel<>("external-model", modelProv, tribuoFeatures, tribuoOutputs, featureMapping, onnxBytes, opts, inputName, featureTransformer, outputTransformer);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
Use of org.tribuo.interop.ExternalTrainerProvenance in project Tribuo by Oracle.
In class TensorFlowSavedModelExternalModel, method createTensorflowModel.
/**
 * Wraps the TensorFlow {@code SavedModelBundle} found in {@code bundleDirectory} as a Tribuo
 * {@link ExternalModel}.
 * <p>
 * The bundle directory's URL is recorded as the trainer provenance. The dataset provenance is
 * the placeholder {@code "unknown-external-data"}, because the training data is not available.
 * <p>
 * Throws {@link IllegalArgumentException} if the model bundle could not be loaded.
 * @param factory The output factory.
 * @param featureMapping The feature mapping between Tribuo's names and the TF integer ids.
 * @param outputMapping The output mapping between Tribuo's names and the TF integer ids.
 * @param outputName The name of the output tensor.
 * @param featureConverter The feature transformation function.
 * @param outputConverter The output transformation function.
 * @param bundleDirectory The path to load the saved model bundle from.
 * @param <T> The type of the output.
 * @return The TF model wrapped in a Tribuo {@link ExternalModel}.
 */
public static <T extends Output<T>> TensorFlowSavedModelExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String bundleDirectory) {
    try {
        final URL bundleUrl = Paths.get(bundleDirectory).toUri().toURL();
        final ImmutableFeatureMap tribuoFeatures = ExternalModel.createFeatureMap(featureMapping.keySet());
        final ImmutableOutputInfo<T> tribuoOutputs = ExternalModel.createOutputInfo(factory, outputMapping);
        final OffsetDateTime creationTime = OffsetDateTime.now();
        final ExternalTrainerProvenance trainerProv = new ExternalTrainerProvenance(bundleUrl);
        final DatasetProvenance dataProv = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        final ModelProvenance modelProv = new ModelProvenance(TensorFlowSavedModelExternalModel.class.getName(), creationTime, dataProv, trainerProv);
        return new TensorFlowSavedModelExternalModel<>(
                "tf-saved-model-bundle",
                modelProv,
                tribuoFeatures,
                tribuoOutputs,
                featureMapping,
                bundleDirectory,
                outputName,
                featureConverter,
                outputConverter);
    } catch (IOException | TensorFlowException loadFailure) {
        throw new IllegalArgumentException("Unable to load model from path " + bundleDirectory, loadFailure);
    }
}
Use of org.tribuo.interop.ExternalTrainerProvenance in project Tribuo by Oracle.
In class OCIModel, method createOCIModel.
/**
 * Creates an {@code OCIModel} by wrapping an OCI DS Model Deployment endpoint.
 * <p>
 * Uses the endpointURL as the value to hash for the trainer provenance.
 * <p>
 * The endpoint URL is split on {@code '/'} and must contain at least a non-empty third
 * segment and a non-empty fourth segment:
 * <ul>
 *   <li>The third segment is the authority (the part after {@code scheme://}). It becomes the
 *   deployment domain, which is always given an {@code https://} scheme.</li>
 *   <li>The fourth segment is used as the model deployment id.</li>
 * </ul>
 * For example, {@code https://<host>/<model-deployment-id>/predict}. Any further segments are
 * ignored.
 *
 * @param factory The output factory to use.
 * @param featureMapping The feature mapping between Tribuo names and model integer ids.
 * @param outputMapping The output mapping between Tribuo outputs and model integer ids.
 * @param configFile The OCI configuration file, if null use the default file.
 * @param profileName The profile name in the OCI configuration file, if null uses the default profile.
 * @param endpointURL The endpoint URL.
 * @param outputConverter The converter for the specified output type.
 * @param <T> The output type.
 * @return An OCIModel ready to score new inputs.
 * @throws IllegalArgumentException If the endpoint URL lacks a host or model deployment id
 *                                  segment, or if the configuration could not be loaded.
 */
public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, Path configFile, String profileName, String endpointURL, OCIOutputConverter<T> outputConverter) {
    if (endpointURL == null) {
        throw new NullPointerException("endpointURL must not be null");
    }
    // Validate the endpoint shape before doing any other work. Previously a short URL escaped
    // as an ArrayIndexOutOfBoundsException, and an empty segment silently produced a bad
    // domain or deployment id.
    String[] endpoint = endpointURL.split("/");
    if (endpoint.length < 4 || endpoint[2].isEmpty() || endpoint[3].isEmpty()) {
        throw new IllegalArgumentException("Invalid endpoint URL '" + endpointURL
                + "', expected the form https://<host>/<model-deployment-id>/...");
    }
    String domain = "https://" + endpoint[2] + "/";
    String modelDeploymentId = endpoint[3];
    try {
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(endpointURL.getBytes(StandardCharsets.UTF_8));
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        HashMap<String, Provenance> runProvenance = new HashMap<>();
        // NOTE(review): configFile is documented as nullable; confirm FileProvenance accepts a
        // null Path, otherwise the "use the default file" path fails here.
        runProvenance.put("configFile", new FileProvenance("configFile", configFile));
        runProvenance.put("endpointURL", new StringProvenance("endpointURL", endpointURL));
        runProvenance.put("modelDeploymentId", new StringProvenance("modelDeploymentId", modelDeploymentId));
        ModelProvenance provenance = new ModelProvenance(OCIModel.class.getName(), now, datasetProvenance, trainerProvenance, runProvenance);
        return new OCIModel<T>("oci-ds-model", provenance, featureMap, outputInfo, featureMapping, configFile, profileName, domain, modelDeploymentId, outputConverter);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load configuration from path " + configFile, e);
    }
}
Use of org.tribuo.interop.ExternalTrainerProvenance in project Tribuo by Oracle.
In class TensorFlowFrozenExternalModel, method createTensorflowModel.
/**
 * Creates a TensorflowFrozenExternalModel by loading in a frozen graph.
 * <p>
 * The file is parsed as a {@code GraphDef} and imported into a fresh {@link Graph}, which is
 * handed to the returned model. If the import or model construction fails, the graph is
 * closed before the exception propagates, so its native resources are not leaked.
 * <p>
 * Throws {@link IllegalArgumentException} if the file could not be read or parsed.
 * @param factory The output factory.
 * @param featureMapping The feature mapping between Tribuo's names and the TF integer ids.
 * @param outputMapping The output mapping between Tribuo's names and the TF integer ids.
 * @param inputName The name of the input placeholder.
 * @param outputName The name of the output tensor.
 * @param featureConverter The feature transformation function.
 * @param outputConverter The output transformation function.
 * @param filename The filename to load the graph from.
 * @param <T> The type of the output.
 * @return The TF model wrapped in a Tribuo ExternalModel.
 */
public static <T extends Output<T>> TensorFlowFrozenExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, String inputName, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String filename) {
    try {
        Path path = Paths.get(filename);
        byte[] model = Files.readAllBytes(path);
        // Parse before allocating the Graph so a corrupt file cannot leak a native graph.
        GraphDef graphDef = GraphDef.parseFrom(model);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        ModelProvenance provenance = new ModelProvenance(TensorFlowFrozenExternalModel.class.getName(), now, datasetProvenance, trainerProvenance);

        // The graph is owned by the returned model on success. It must be released here only
        // if importing it or constructing the model fails.
        Graph graph = new Graph();
        try {
            graph.importGraphDef(graphDef);
            return new TensorFlowFrozenExternalModel<>("tf-frozen-graph", provenance, featureMap, outputInfo, featureMapping, graph, inputName, outputName, featureConverter, outputConverter);
        } catch (RuntimeException e) {
            graph.close();
            throw e;
        }
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + filename, e);
    }
}
Use of org.tribuo.interop.ExternalTrainerProvenance in project Tribuo by Oracle.
In class XGBoostExternalModel, method createXGBoostModel.
/**
 * Creates an {@code XGBoostExternalModel} from the supplied model on disk.
 * <p>
 * The model file's URL is recorded as the trainer provenance. The input stream used to read the
 * model is always closed, even if XGBoost fails to parse it.
 * @param factory The output factory to use.
 * @param featureMapping The feature mapping between Tribuo names and XGBoost integer ids.
 * @param outputMapping The output mapping between Tribuo outputs and XGBoost integer ids.
 * @param outputFunc The XGBoostOutputConverter function for the output type.
 * @param path The path to the model on disk.
 * @param <T> The type of the output.
 * @return An XGBoostExternalModel ready to score new inputs.
 * @throws IllegalArgumentException If the file could not be read or XGBoost failed to load it.
 */
public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, XGBoostOutputConverter<T> outputFunc, Path path) {
    try {
        // Resolve the provenance URL first so a failure here cannot strand a loaded Booster.
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(path.toUri().toURL());
        Booster model;
        // This method opens the stream, so this method must close it; previously it leaked.
        try (java.io.InputStream modelStream = Files.newInputStream(path)) {
            model = XGBoost.loadModel(modelStream);
        }
        return createXGBoostModel(factory, featureMapping, outputMapping, outputFunc, model, trainerProvenance, Collections.emptyMap());
    } catch (XGBoostError | IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
Aggregations