
Example 1 with Provenance

use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project tribuo by oracle.

the class StripProvenance method cleanEnsembleProvenance.

/**
 * Creates a new ensemble provenance with the requested information removed.
 * @param old The old ensemble provenance.
 * @param memberProvenance The new member provenances.
 * @param provenanceHash The old ensemble provenance hash.
 * @param opt The program options.
 * @return The new ensemble provenance with the requested fields removed.
 */
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    // Dataset provenance
    DatasetProvenance datasetProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(DATASET)) {
        datasetProvenance = new EmptyDatasetProvenance();
    } else {
        datasetProvenance = old.getDatasetProvenance();
    }
    // Trainer provenance
    TrainerProvenance trainerProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(TRAINER)) {
        trainerProvenance = new EmptyTrainerProvenance();
    } else {
        trainerProvenance = old.getTrainerProvenance();
    }
    // Instance provenance
    OffsetDateTime time;
    Map<String, Provenance> instanceProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(INSTANCE)) {
        instanceProvenance = new HashMap<>();
        time = OffsetDateTime.MIN;
    } else {
        instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap());
        time = old.getTrainingTime();
    }
    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        instanceProvenance.put("original-provenance-hash", new HashProvenance(opt.hashType, "original-provenance-hash", provenanceHash));
    }
    return new EnsembleModelProvenance(old.getClassName(), time, datasetProvenance, trainerProvenance, instanceProvenance, memberProvenance);
}
Also used : HashProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance) ObjectProvenance(com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) EmptyTrainerProvenance(org.tribuo.provenance.impl.EmptyTrainerProvenance) EmptyDatasetProvenance(org.tribuo.provenance.impl.EmptyDatasetProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) ListProvenance(com.oracle.labs.mlrg.olcut.provenance.ListProvenance) OffsetDateTime(java.time.OffsetDateTime)
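
The selection logic in cleanEnsembleProvenance above (replace a part with an empty placeholder when either ALL or that part's own flag is requested, otherwise keep the old value) can be sketched with JDK types only. This is a simplified illustration rather than Tribuo's code: the class StripFieldsSketch, the enum Part, the EMPTY placeholder string, and the use of plain strings in place of provenance objects are all assumptions made to keep the example dependency-free.

```java
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.Map;
import java.util.Set;

final class StripFieldsSketch {
    // Hypothetical stand-in for the removal flags used by the stripping tool.
    enum Part { ALL, DATASET, TRAINER, INSTANCE }

    // Hypothetical placeholder playing the role of an "empty" provenance.
    static final String EMPTY = "<empty>";

    // Returns a copy of the old parts where every requested part is replaced by EMPTY.
    static Map<Part, String> strip(Map<Part, String> old, Set<Part> remove) {
        Map<Part, String> cleaned = new EnumMap<>(Part.class);
        for (Part part : EnumSet.of(Part.DATASET, Part.TRAINER, Part.INSTANCE)) {
            // ALL overrides the per-part flags, mirroring the repeated "ALL || X" checks.
            boolean drop = remove.contains(Part.ALL) || remove.contains(part);
            cleaned.put(part, drop ? EMPTY : old.get(part));
        }
        return cleaned;
    }

    public static void main(String[] args) {
        Map<Part, String> old = new EnumMap<>(Part.class);
        old.put(Part.DATASET, "dataset-info");
        old.put(Part.TRAINER, "trainer-info");
        old.put(Part.INSTANCE, "instance-info");
        System.out.println(strip(old, EnumSet.of(Part.TRAINER)));
    }
}
```

Looping over the parts keeps the "ALL or this part" rule in one place, whereas the original method repeats the check once per field because each field has a different static type.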

Example 2 with Provenance

use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project tribuo by oracle.

the class CSVLoaderTest method testLoadNoHeader.

@Test
public void testLoadNoHeader() throws IOException {
    URL path = CSVLoader.class.getResource("/org/tribuo/data/csv/test.csv");
    CSVLoader<MockOutput> loader = new CSVLoader<>(new MockOutputFactory());
    // 
    // Currently, passing the header into loader.load when the CSV has a header row will cause an error. This is
    // because CSVIterator does not skip the first line in this case.
    // TODO do we want this behavior?
    String[] header = new String[] { "A", "B", "C", "D", "RESPONSE" };
    assertThrows(NumberFormatException.class, () -> loader.load(Paths.get(path.toURI()), "RESPONSE", header));
    assertThrows(NumberFormatException.class, () -> loader.load(Paths.get(path.toURI()), Collections.singleton("RESPONSE"), header));
    // 
    // Test behavior when CSV file does not have a header row and the user instead supplies the header.
    URL noheader = CSVLoader.class.getResource("/org/tribuo/data/csv/test-noheader.csv");
    DataSource<MockOutput> source = loader.loadDataSource(noheader, "RESPONSE", header);
    // Check that the source persisted the headers in the provenance
    CSVDataSource.CSVDataSourceProvenance prov = (CSVDataSource.CSVDataSourceProvenance) source.getProvenance();
    Provenance headerProv = prov.getConfiguredParameters().get("headers");
    assertTrue(headerProv instanceof ListProvenance);
    @SuppressWarnings("unchecked") ListProvenance<StringProvenance> listProv = (ListProvenance<StringProvenance>) headerProv;
    assertEquals(header.length, listProv.getList().size());
    assertEquals(Arrays.asList(header), listProv.getList().stream().map(StringProvenance::getValue).collect(Collectors.toList()));
    // Check the data loaded correctly.
    checkDataTestCsv(source);
    checkDataTestCsv(loader.loadDataSource(noheader, Collections.singleton("RESPONSE"), header));
}
Also used : Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) StringProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) ListProvenance(com.oracle.labs.mlrg.olcut.provenance.ListProvenance) MockOutput(org.tribuo.test.MockOutput) MockOutputFactory(org.tribuo.test.MockOutputFactory) URL(java.net.URL) Test(org.junit.jupiter.api.Test)

Example 3 with Provenance

use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project tribuo by oracle.

the class ONNXExternalModel method createOnnxModel.

/**
 * Creates an {@code ONNXExternalModel} by loading the model from disk.
 *
 * @param factory            The output factory to use.
 * @param featureMapping     The feature mapping between Tribuo names and ONNX integer ids.
 * @param outputMapping      The output mapping between Tribuo outputs and ONNX integer ids.
 * @param featureTransformer The transformation function for the features.
 * @param outputTransformer  The transformation function for the outputs.
 * @param opts               The session options for the ONNX model.
 * @param path               The model path.
 * @param inputName          The name of the input node.
 * @param <T>                The type of the output.
 * @return An ONNXExternalModel ready to score new inputs.
 * @throws OrtException If the onnx-runtime native library call failed.
 */
public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, ExampleTransformer featureTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions opts, Path path, String inputName) throws OrtException {
    try {
        byte[] modelArray = Files.readAllBytes(path);
        URL provenanceLocation = path.toUri().toURL();
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        OffsetDateTime now = OffsetDateTime.now();
        ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
        DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
        HashMap<String, Provenance> runProvenance = new HashMap<>();
        runProvenance.put("input-name", new StringProvenance("input-name", inputName));
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession session = env.createSession(modelArray)) {
            OnnxModelMetadata metadata = session.getMetadata();
            runProvenance.put("model-producer", new StringProvenance("model-producer", metadata.getProducerName()));
            runProvenance.put("model-domain", new StringProvenance("model-domain", metadata.getDomain()));
            runProvenance.put("model-description", new StringProvenance("model-description", metadata.getDescription()));
            runProvenance.put("model-graphname", new StringProvenance("model-graphname", metadata.getGraphName()));
            runProvenance.put("model-version", new LongProvenance("model-version", metadata.getVersion()));
            for (Map.Entry<String, String> e : metadata.getCustomMetadata().entrySet()) {
                if (!e.getKey().equals(ONNXExportable.PROVENANCE_METADATA_FIELD)) {
                    String keyName = "model-metadata-" + e.getKey();
                    runProvenance.put(keyName, new StringProvenance(keyName, e.getValue()));
                }
            }
        } catch (OrtException e) {
            throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
        }
        ModelProvenance provenance = new ModelProvenance(ONNXExternalModel.class.getName(), now, datasetProvenance, trainerProvenance, runProvenance);
        return new ONNXExternalModel<>("external-model", provenance, featureMap, outputInfo, featureMapping, modelArray, opts, inputName, featureTransformer, outputTransformer);
    } catch (IOException e) {
        throw new IllegalArgumentException("Unable to load model from path " + path, e);
    }
}
Also used : ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) StringProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) LongProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance) ExternalDatasetProvenance(org.tribuo.interop.ExternalDatasetProvenance) HashMap(java.util.HashMap) OrtException(ai.onnxruntime.OrtException) URL(java.net.URL) OrtEnvironment(ai.onnxruntime.OrtEnvironment) OnnxModelMetadata(ai.onnxruntime.OnnxModelMetadata) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) IOException(java.io.IOException) OrtSession(ai.onnxruntime.OrtSession) OffsetDateTime(java.time.OffsetDateTime) Map(java.util.Map)

Example 4 with Provenance

use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project tribuo by oracle.

the class SkeletalTrainerProvenance method extractProvenanceInfo.

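/**
 * Splits a provenance map into the class name, the host short name, the required
 * instance values (train invocation count, is-sequence flag, Tribuo version), and
 * the remaining entries, which are treated as configured parameters.
 * @param map The provenance map to split. It is copied, not modified.
 * @return The extracted information.
 * @throws ProvenanceException If a required entry is missing or has the wrong type.
 */
protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {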
    String className;
    String hostTypeStringName;
    Map<String, Provenance> configuredParameters = new HashMap<>(map);
    Map<String, PrimitiveProvenance<?>> instanceValues = new HashMap<>();
    if (configuredParameters.containsKey(ObjectProvenance.CLASS_NAME)) {
        className = configuredParameters.remove(ObjectProvenance.CLASS_NAME).toString();
    } else {
        throw new ProvenanceException("Failed to find class name when constructing SkeletalTrainerProvenance");
    }
    if (configuredParameters.containsKey(SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME)) {
        hostTypeStringName = configuredParameters.remove(SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME).toString();
    } else {
        throw new ProvenanceException("Failed to find host type short name when constructing SkeletalTrainerProvenance");
    }
    if (configuredParameters.containsKey(TrainerProvenance.TRAIN_INVOCATION_COUNT)) {
        Provenance tmpProv = configuredParameters.remove(TrainerProvenance.TRAIN_INVOCATION_COUNT);
        if (tmpProv instanceof IntProvenance) {
            instanceValues.put(TRAIN_INVOCATION_COUNT, (IntProvenance) tmpProv);
        } else {
            throw new ProvenanceException(TRAIN_INVOCATION_COUNT + " was not of type IntProvenance in class " + className);
        }
    } else {
        throw new ProvenanceException("Failed to find invocation count when constructing SkeletalTrainerProvenance");
    }
    if (configuredParameters.containsKey(TrainerProvenance.IS_SEQUENCE)) {
        Provenance tmpProv = configuredParameters.remove(TrainerProvenance.IS_SEQUENCE);
        if (tmpProv instanceof BooleanProvenance) {
            instanceValues.put(IS_SEQUENCE, (BooleanProvenance) tmpProv);
        } else {
            throw new ProvenanceException(IS_SEQUENCE + " was not of type BooleanProvenance in class " + className);
        }
    } else {
        throw new ProvenanceException("Failed to find is-sequence when constructing SkeletalTrainerProvenance");
    }
    if (configuredParameters.containsKey(TrainerProvenance.TRIBUO_VERSION_STRING)) {
        Provenance tmpProv = configuredParameters.remove(TrainerProvenance.TRIBUO_VERSION_STRING);
        if (tmpProv instanceof StringProvenance) {
            instanceValues.put(TRIBUO_VERSION_STRING, (StringProvenance) tmpProv);
        } else {
            throw new ProvenanceException(TRIBUO_VERSION_STRING + " was not of type StringProvenance in class " + className);
        }
    } else {
        throw new ProvenanceException("Failed to find Tribuo version when constructing SkeletalTrainerProvenance");
    }
    return new ExtractedInfo(className, hostTypeStringName, configuredParameters, instanceValues);
}
Also used : PrimitiveProvenance(com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance) BooleanProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance) ObjectProvenance(com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) StringProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) SkeletalConfiguredObjectProvenance(com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance) IntProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance) ProvenanceException(com.oracle.labs.mlrg.olcut.provenance.ProvenanceException) HashMap(java.util.HashMap)
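
extractProvenanceInfo repeats one pattern for each required entry: check that the key exists, remove it from the copied map, verify its runtime type, and throw otherwise. A minimal sketch of that pattern as a single generic helper, using only JDK types, might look like the following. The class RequiredKeySketch, the helper takeRequired, the example keys, and the use of IllegalArgumentException in place of ProvenanceException are assumptions for illustration and are not part of Tribuo or OLCUT.

```java
import java.util.HashMap;
import java.util.Map;

final class RequiredKeySketch {
    // Hypothetical helper: removes a required key from the map and checks its runtime type.
    static <T> T takeRequired(Map<String, Object> map, String key, Class<T> type) {
        if (!map.containsKey(key)) {
            throw new IllegalArgumentException("Failed to find " + key);
        }
        Object value = map.remove(key);
        if (!type.isInstance(value)) {
            throw new IllegalArgumentException(key + " was not of type " + type.getSimpleName());
        }
        return type.cast(value);
    }

    public static void main(String[] args) {
        // Work on a copy so the caller's map is left untouched, as the original method does.
        Map<String, Object> remaining = new HashMap<>();
        remaining.put("invocation-count", 3);   // illustrative key names
        remaining.put("is-sequence", false);
        remaining.put("seed", 42L);
        int count = takeRequired(remaining, "invocation-count", Integer.class);
        boolean isSequence = takeRequired(remaining, "is-sequence", Boolean.class);
        // Whatever is left after the required keys are removed is the "configured parameters" part.
        System.out.println(count + " " + isSequence + " " + remaining);
    }
}
```

The original method spells each check out by hand, which lets every failure carry its own message. A helper like this trades those tailored messages for less repetition.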

Example 5 with Provenance

use of com.oracle.labs.mlrg.olcut.provenance.Provenance in project tribuo by oracle.

the class ONNXExternalModel method getTribuoProvenance.

/**
 * Returns the model provenance from the ONNX model if that
 * model was trained in Tribuo.
 * <p>
 * Tribuo's ONNX export functionality stores the model provenance inside the
 * ONNX file in the metadata field {@link ONNXExportable#PROVENANCE_METADATA_FIELD},
 * and this method provides the access point for it.
 * <p>
 * Note it is different from the {@link Model#getProvenance()} call which
 * returns information about the ONNX file itself, and when the {@code ONNXExternalModel}
 * was created. It does not replace that provenance because instantiating this provenance
 * may require classes which are not present on the classpath at deployment time.
 *
 * @return The model provenance from the original Tribuo training run, if it exists, and
 * returns {@link Optional#empty()} otherwise.
 */
public Optional<ModelProvenance> getTribuoProvenance() {
    try {
        OnnxModelMetadata metadata = session.getMetadata();
        Optional<String> value = metadata.getCustomMetadataValue(ONNXExportable.PROVENANCE_METADATA_FIELD);
        if (value.isPresent()) {
            Provenance prov = ONNXExportable.SERIALIZER.deserializeAndUnmarshal(value.get());
            if (prov instanceof ModelProvenance) {
                return Optional.of((ModelProvenance) prov);
            } else {
                logger.log(Level.WARNING, "Found invalid provenance object, " + prov.toString());
                return Optional.empty();
            }
        } else {
            return Optional.empty();
        }
    } catch (OrtException e) {
        logger.log(Level.WARNING, "ORTException when reading session metadata", e);
        return Optional.empty();
    } catch (ProvenanceSerializationException e) {
        logger.log(Level.WARNING, "Failed to parse provenance from value.", e);
        return Optional.empty();
    }
}
Also used : ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) ExternalTrainerProvenance(org.tribuo.interop.ExternalTrainerProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) StringProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) LongProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance) ExternalDatasetProvenance(org.tribuo.interop.ExternalDatasetProvenance) OnnxModelMetadata(ai.onnxruntime.OnnxModelMetadata) OrtException(ai.onnxruntime.OrtException) ProvenanceSerializationException(com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerializationException)
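
getTribuoProvenance returns Optional.empty() in three distinct cases: the metadata field is absent, the stored value fails to parse, or the parsed object is of an unexpected kind. The same control flow can be sketched without ONNX Runtime or OLCUT on the classpath. In this simplified illustration a Map<String, String> stands in for the model metadata and Long.parseLong stands in for provenance deserialization. The class OptionalMetadataSketch and the method readLongField are hypothetical names.

```java
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;

final class OptionalMetadataSketch {
    private static final Logger logger = Logger.getLogger(OptionalMetadataSketch.class.getName());

    // Hypothetical reader: absent and malformed fields both collapse to Optional.empty().
    static Optional<Long> readLongField(Map<String, String> metadata, String field) {
        String value = metadata.get(field);
        if (value == null) {
            // Field not stored: nothing to report, simply no value.
            return Optional.empty();
        }
        try {
            return Optional.of(Long.parseLong(value));
        } catch (NumberFormatException e) {
            // Parsing stands in for deserialization; failures are logged, not propagated.
            logger.log(Level.WARNING, "Failed to parse field " + field, e);
            return Optional.empty();
        }
    }

    public static void main(String[] args) {
        Map<String, String> metadata = Map.of("version", "7", "broken", "seven");
        System.out.println(readLongField(metadata, "version").isPresent());
        System.out.println(readLongField(metadata, "broken").isPresent());
        System.out.println(readLongField(metadata, "missing").isPresent());
    }
}
```

Logging the failure and returning an empty Optional means callers cannot tell "absent" from "unreadable" without the logs. That suits an accessor for optional, best-effort metadata like this one.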

Aggregations

Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance) 7
StringProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) 5
DatasetProvenance (org.tribuo.provenance.DatasetProvenance) 5
ModelProvenance (org.tribuo.provenance.ModelProvenance) 5
OffsetDateTime (java.time.OffsetDateTime) 4
ListProvenance (com.oracle.labs.mlrg.olcut.provenance.ListProvenance) 3
ObjectProvenance (com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance) 3
HashMap (java.util.HashMap) 3
ExternalDatasetProvenance (org.tribuo.interop.ExternalDatasetProvenance) 3
ExternalTrainerProvenance (org.tribuo.interop.ExternalTrainerProvenance) 3
OnnxModelMetadata (ai.onnxruntime.OnnxModelMetadata) 2
OrtException (ai.onnxruntime.OrtException) 2
HashProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance) 2
LongProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance) 2
IOException (java.io.IOException) 2
URL (java.net.URL) 2
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap) 2
EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance) 2
TrainerProvenance (org.tribuo.provenance.TrainerProvenance) 2
EmptyDatasetProvenance (org.tribuo.provenance.impl.EmptyDatasetProvenance) 2