Use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project Tribuo by Oracle.
Class StripProvenance, method cleanEnsembleProvenance.
/**
 * Builds a replacement ensemble provenance from {@code old}, blanking out the sections
 * named in {@code opt.removeProvenances}.
 * <p>
 * Dataset and trainer provenance are replaced by their empty variants when {@code ALL}
 * or the matching section is requested. When instance provenance is removed, the instance
 * map starts empty and the training time becomes {@link OffsetDateTime#MIN}. Otherwise both
 * are copied from {@code old}. If {@code opt.storeHash} is set, {@code provenanceHash} is
 * written into the instance map under the key {@code "original-provenance-hash"}.
 *
 * @param old The old ensemble provenance.
 * @param memberProvenance The new member provenances.
 * @param provenanceHash The old ensemble provenance hash.
 * @param opt The program options.
 * @return The new ensemble provenance with the requested fields removed.
 */
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    // Requesting ALL strips every section, so compute it once.
    final boolean stripAll = opt.removeProvenances.contains(ALL);

    final DatasetProvenance cleanedDataset = (stripAll || opt.removeProvenances.contains(DATASET))
            ? new EmptyDatasetProvenance()
            : old.getDatasetProvenance();

    final TrainerProvenance cleanedTrainer = (stripAll || opt.removeProvenances.contains(TRAINER))
            ? new EmptyTrainerProvenance()
            : old.getTrainerProvenance();

    // Both branches build a fresh mutable map, so the hash can be added below
    // without modifying the old provenance.
    final Map<String, Provenance> cleanedInstance;
    final OffsetDateTime trainingTime;
    if (stripAll || opt.removeProvenances.contains(INSTANCE)) {
        cleanedInstance = new HashMap<>();
        trainingTime = OffsetDateTime.MIN;
    } else {
        cleanedInstance = new HashMap<>(old.getInstanceProvenance().getMap());
        trainingTime = old.getTrainingTime();
    }

    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        cleanedInstance.put("original-provenance-hash", new HashProvenance(opt.hashType, "original-provenance-hash", provenanceHash));
    }

    return new EnsembleModelProvenance(old.getClassName(), trainingTime, cleanedDataset, cleanedTrainer, cleanedInstance, memberProvenance);
}
Use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project Tribuo by Oracle.
Class CSVLoaderTest, method testLoadNoHeader.
/**
 * Checks how {@link CSVLoader} handles a caller-supplied header.
 * <p>
 * Supplying a header for a CSV file that already has a header row currently fails with a
 * {@link NumberFormatException}. Supplying a header for a file without one loads correctly,
 * and the header is recorded in the data source provenance under {@code "headers"}.
 */
@Test
public void testLoadNoHeader() throws IOException {
    URL headerCsv = CSVLoader.class.getResource("/org/tribuo/data/csv/test.csv");
    CSVLoader<MockOutput> csvLoader = new CSVLoader<>(new MockOutputFactory());
    String[] suppliedHeader = {"A", "B", "C", "D", "RESPONSE"};

    // Passing a header to loader.load for a CSV that already has a header row currently
    // raises an error, because CSVIterator does not skip the first line in this case.
    // TODO do we want this behavior?
    assertThrows(NumberFormatException.class,
            () -> csvLoader.load(Paths.get(headerCsv.toURI()), "RESPONSE", suppliedHeader));
    assertThrows(NumberFormatException.class,
            () -> csvLoader.load(Paths.get(headerCsv.toURI()), Collections.singleton("RESPONSE"), suppliedHeader));

    // The CSV file has no header row, so the caller supplies one.
    URL noHeaderCsv = CSVLoader.class.getResource("/org/tribuo/data/csv/test-noheader.csv");
    DataSource<MockOutput> singleResponseSource = csvLoader.loadDataSource(noHeaderCsv, "RESPONSE", suppliedHeader);

    // The supplied header must be persisted in the source's provenance.
    CSVDataSource.CSVDataSourceProvenance sourceProvenance =
            (CSVDataSource.CSVDataSourceProvenance) singleResponseSource.getProvenance();
    Provenance recordedHeaders = sourceProvenance.getConfiguredParameters().get("headers");
    assertTrue(recordedHeaders instanceof ListProvenance);
    @SuppressWarnings("unchecked")
    ListProvenance<StringProvenance> headerList = (ListProvenance<StringProvenance>) recordedHeaders;
    assertEquals(suppliedHeader.length, headerList.getList().size());
    assertEquals(Arrays.asList(suppliedHeader),
            headerList.getList().stream().map(StringProvenance::getValue).collect(Collectors.toList()));

    // Both the single-response and the response-set overloads must load the expected data.
    checkDataTestCsv(singleResponseSource);
    checkDataTestCsv(csvLoader.loadDataSource(noHeaderCsv, Collections.singleton("RESPONSE"), suppliedHeader));
}
Use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project Tribuo by Oracle.
Class ONNXExternalModel, method createOnnxModel.
/**
 * Creates an {@code ONNXExternalModel} by loading the model from disk.
 * <p>
 * The model bytes are read from {@code path}, and a temporary ONNX session is opened to
 * copy the model's metadata (producer, domain, description, graph name, version and
 * custom metadata) into the instance provenance. The custom entry stored under
 * {@link ONNXExportable#PROVENANCE_METADATA_FIELD} is skipped.
 *
 * @param factory The output factory to use.
 * @param featureMapping The feature mapping between Tribuo names and ONNX integer ids.
 * @param outputMapping The output mapping between Tribuo outputs and ONNX integer ids.
 * @param featureTransformer The transformation function for the features.
 * @param outputTransformer The transformation function for the outputs.
 * @param opts The session options for the ONNX model.
 * @param path The model path.
 * @param inputName The name of the input node.
 * @param <T> The type of the output.
 * @return An ONNXExternalModel ready to score new inputs.
 * @throws OrtException If an onnx-runtime native library call outside the metadata read failed.
 * @throws IllegalArgumentException If the file could not be read, or the ONNX metadata could not be loaded.
 */
public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions opts, Path path, String inputName) throws OrtException {
    try {
        final byte[] modelBytes = Files.readAllBytes(path);
        final URL modelLocation = path.toUri().toURL();
        final ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        final ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        final OffsetDateTime creationTime = OffsetDateTime.now();

        // External models have no Tribuo training record. The trainer and dataset provenance are
        // therefore synthesized from the file location and the sizes of the supplied mappings.
        final ExternalTrainerProvenance trainerProv = new ExternalTrainerProvenance(modelLocation);
        final DatasetProvenance datasetProv = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());

        final Map<String, Provenance> instanceProv = new HashMap<>();
        instanceProv.put("input-name", new StringProvenance("input-name", inputName));

        // A short-lived session used only to read the ONNX model metadata.
        try (OrtEnvironment ortEnv = OrtEnvironment.getEnvironment();
             OrtSession metadataSession = ortEnv.createSession(modelBytes)) {
            final OnnxModelMetadata onnxMetadata = metadataSession.getMetadata();
            instanceProv.put("model-producer", new StringProvenance("model-producer", onnxMetadata.getProducerName()));
            instanceProv.put("model-domain", new StringProvenance("model-domain", onnxMetadata.getDomain()));
            instanceProv.put("model-description", new StringProvenance("model-description", onnxMetadata.getDescription()));
            instanceProv.put("model-graphname", new StringProvenance("model-graphname", onnxMetadata.getGraphName()));
            instanceProv.put("model-version", new LongProvenance("model-version", onnxMetadata.getVersion()));
            // Copy custom metadata under a prefix. The serialized Tribuo provenance entry is
            // excluded; getTribuoProvenance() reads that field separately.
            onnxMetadata.getCustomMetadata().forEach((key, value) -> {
                if (!key.equals(ONNXExportable.PROVENANCE_METADATA_FIELD)) {
                    String provKey = "model-metadata-" + key;
                    instanceProv.put(provKey, new StringProvenance(provKey, value));
                }
            });
        } catch (OrtException e) {
            throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
        }

        final ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(), creationTime, datasetProv, trainerProv, instanceProv);
        return new ONNXExternalModel<>("external-model", provenance, featureMap, outputInfo, featureMapping, modelBytes, opts, inputName, featureTransformer, outputTransformer);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
Use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project Tribuo by Oracle.
Class SkeletalTrainerProvenance, method extractProvenanceInfo.
/**
 * Splits a provenance map into the fields every {@code SkeletalTrainerProvenance} requires
 * and the remaining configured parameters.
 * <p>
 * The class name and host short name are removed and returned as strings. The train
 * invocation count, is-sequence flag and Tribuo version are removed, type-checked and
 * returned as instance values. Every entry left over is returned as a configured parameter.
 * The supplied map is copied first and is not modified.
 *
 * @param map The provenance map to unpack.
 * @return The class name, host short name, configured parameters and instance values.
 * @throws ProvenanceException If a required field is missing, or an instance value has the wrong type.
 */
protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
    // Work on a copy: whatever is not removed below becomes the configured parameters.
    Map<String, Provenance> configuredParameters = new HashMap<>(map);

    String className = removeRequiredField(configuredParameters, ObjectProvenance.CLASS_NAME, "class name").toString();
    String hostTypeStringName = removeRequiredField(configuredParameters, SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME, "host type short name").toString();

    Map<String, PrimitiveProvenance<?>> instanceValues = new HashMap<>();
    instanceValues.put(TrainerProvenance.TRAIN_INVOCATION_COUNT,
            removeRequiredPrimitiveField(configuredParameters, TrainerProvenance.TRAIN_INVOCATION_COUNT, "invocation count", IntProvenance.class, className));
    instanceValues.put(TrainerProvenance.IS_SEQUENCE,
            removeRequiredPrimitiveField(configuredParameters, TrainerProvenance.IS_SEQUENCE, "is-sequence", BooleanProvenance.class, className));
    instanceValues.put(TrainerProvenance.TRIBUO_VERSION_STRING,
            removeRequiredPrimitiveField(configuredParameters, TrainerProvenance.TRIBUO_VERSION_STRING, "Tribuo version", StringProvenance.class, className));

    return new ExtractedInfo(className, hostTypeStringName, configuredParameters, instanceValues);
}

/**
 * Removes and returns the entry for {@code key}, throwing if the key is absent.
 * <p>
 * Mirrors the original containsKey/remove sequence: a key that is present but mapped to
 * {@code null} yields {@code null} rather than an exception.
 *
 * @param params The mutable map to remove from.
 * @param key The key to remove.
 * @param description Human-readable field name used in the error message.
 * @return The removed value.
 * @throws ProvenanceException If {@code params} does not contain {@code key}.
 */
private static Provenance removeRequiredField(Map<String, Provenance> params, String key, String description) {
    if (!params.containsKey(key)) {
        throw new ProvenanceException("Failed to find " + description + " when constructing SkeletalTrainerProvenance");
    }
    return params.remove(key);
}

/**
 * Removes the entry for {@code key} and checks that it is an instance of {@code type}.
 *
 * @param params The mutable map to remove from.
 * @param key The key to remove.
 * @param description Human-readable field name used in the missing-field error message.
 * @param type The required primitive provenance type.
 * @param className The class name being constructed, used in the wrong-type error message.
 * @param <T> The required primitive provenance type.
 * @return The removed value, cast to {@code type}.
 * @throws ProvenanceException If the key is absent, or its value is not an instance of {@code type} (including {@code null}).
 */
private static <T extends PrimitiveProvenance<?>> T removeRequiredPrimitiveField(Map<String, Provenance> params, String key, String description, Class<T> type, String className) {
    Provenance value = removeRequiredField(params, key, description);
    if (!type.isInstance(value)) {
        throw new ProvenanceException(key + " was not of type " + type.getSimpleName() + " in class " + className);
    }
    return type.cast(value);
}
Use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project Tribuo by Oracle.
Class ONNXExternalModel, method getTribuoProvenance.
/**
 * Returns the Tribuo model provenance embedded in the ONNX file, if the model was
 * exported from Tribuo.
 * <p>
 * Tribuo's ONNX export stores the model provenance in the ONNX metadata field
 * {@link ONNXExportable#PROVENANCE_METADATA_FIELD}; this method reads and deserializes it.
 * <p>
 * This differs from {@link Model#getProvenance()}, which describes the ONNX file itself and
 * when this {@code ONNXExternalModel} was created. It does not replace that provenance,
 * because instantiating the embedded provenance may require classes that are absent from
 * the classpath at deployment time.
 * <p>
 * Failures are not thrown. A warning is logged and {@link Optional#empty()} is returned when:
 * <ul>
 * <li>the session metadata cannot be read;</li>
 * <li>the stored value cannot be deserialized;</li>
 * <li>the stored value is not a {@link ModelProvenance}.</li>
 * </ul>
 *
 * @return The model provenance from the original Tribuo training run, if it exists, and
 * {@link Optional#empty()} otherwise.
 */
public Optional<ModelProvenance> getTribuoProvenance() {
    try {
        OnnxModelMetadata sessionMetadata = session.getMetadata();
        Optional<String> serialized = sessionMetadata.getCustomMetadataValue(ONNXExportable.PROVENANCE_METADATA_FIELD);
        // No embedded provenance: the model was not exported by Tribuo.
        if (!serialized.isPresent()) {
            return Optional.empty();
        }
        Provenance decoded = ONNXExportable.SERIALIZER.deserializeAndUnmarshal(serialized.get());
        if (decoded instanceof ModelProvenance) {
            return Optional.of((ModelProvenance) decoded);
        }
        logger.log(Level.WARNING, "Found invalid provenance object, " + decoded.toString());
        return Optional.empty();
    } catch (OrtException e) {
        logger.log(Level.WARNING, "ORTException when reading session metadata", e);
        return Optional.empty();
    } catch (ProvenanceSerializationException e) {
        logger.log(Level.WARNING, "Failed to parse provenance from value.", e);
        return Optional.empty();
    }
}
Aggregations