Search in sources:

Example 1 with ListProvenance

Use of com.oracle.labs.mlrg.olcut.provenance.ListProvenance in project tribuo by Oracle.

Class StripProvenance, method cleanEnsembleProvenance.

/**
 * Builds a replacement ensemble provenance from {@code old}, blanking out whichever
 * sections the options ask to strip.
 * <p>
 * A section is stripped when {@code opt.removeProvenances} contains either {@code ALL} or that
 * section's type:
 * <ul>
 *     <li>The dataset section becomes an {@code EmptyDatasetProvenance}.</li>
 *     <li>The trainer section becomes an {@code EmptyTrainerProvenance}.</li>
 *     <li>The instance section becomes an empty map, and the training time becomes
 *     {@link OffsetDateTime#MIN}.</li>
 * </ul>
 * When {@code opt.storeHash} is set, {@code provenanceHash} is recorded in the instance map
 * under the key {@code "original-provenance-hash"}, whether or not the instance section was
 * stripped.
 * @param old The ensemble provenance being cleaned.
 * @param memberProvenance The provenances of the ensemble members, used as-is in the result.
 * @param provenanceHash The hash of the original ensemble provenance.
 * @param opt The program options.
 * @return A new ensemble provenance with the requested sections removed.
 */
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    boolean stripEverything = opt.removeProvenances.contains(ALL);

    DatasetProvenance cleanedDataset = (stripEverything || opt.removeProvenances.contains(DATASET))
            ? new EmptyDatasetProvenance()
            : old.getDatasetProvenance();

    TrainerProvenance cleanedTrainer = (stripEverything || opt.removeProvenances.contains(TRAINER))
            ? new EmptyTrainerProvenance()
            : old.getTrainerProvenance();

    // The instance map is always a fresh mutable copy so the hash entry below can be added to it.
    Map<String, Provenance> cleanedInstance;
    OffsetDateTime trainingTime;
    boolean stripInstance = stripEverything || opt.removeProvenances.contains(INSTANCE);
    if (!stripInstance) {
        cleanedInstance = new HashMap<>(old.getInstanceProvenance().getMap());
        trainingTime = old.getTrainingTime();
    } else {
        cleanedInstance = new HashMap<>();
        trainingTime = OffsetDateTime.MIN;
    }

    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        String hashKey = "original-provenance-hash";
        cleanedInstance.put(hashKey, new HashProvenance(opt.hashType, hashKey, provenanceHash));
    }

    return new EnsembleModelProvenance(old.getClassName(), trainingTime, cleanedDataset, cleanedTrainer, cleanedInstance, memberProvenance);
}
Also used: HashProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance), ObjectProvenance (com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance), ModelProvenance (org.tribuo.provenance.ModelProvenance), DatasetProvenance (org.tribuo.provenance.DatasetProvenance), EmptyTrainerProvenance (org.tribuo.provenance.impl.EmptyTrainerProvenance), EmptyDatasetProvenance (org.tribuo.provenance.impl.EmptyDatasetProvenance), Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance), TrainerProvenance (org.tribuo.provenance.TrainerProvenance), EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance), ListProvenance (com.oracle.labs.mlrg.olcut.provenance.ListProvenance), OffsetDateTime (java.time.OffsetDateTime)

Example 2 with ListProvenance

Use of com.oracle.labs.mlrg.olcut.provenance.ListProvenance in project tribuo by Oracle.

Class StripProvenance, method convertModel.

/**
 * Creates a copy of the old model with the requested provenance removed.
 * <p>
 * Ensembles are handled recursively. Each member model is converted first, and the resulting
 * member provenances are wrapped in a {@link ListProvenance} and passed to
 * {@code cleanEnsembleProvenance}. The copy itself is produced by reflectively invoking the
 * {@code copy} method declared on the model's runtime class. The copy's name is the original
 * name with {@code "-deprovenanced"} appended, or just {@code "deprovenanced"} if the original
 * name is empty.
 * @param oldModel The model to remove provenance from.
 * @param provenanceHash A hash of the old provenance (the same hash is also passed down to each ensemble member).
 * @param opt The program options.
 * @param <T> The output type.
 * @return A copy of the model with redacted provenance, paired with that cleaned provenance.
 * @throws InvocationTargetException If the model's copy method itself throws an exception.
 * @throws IllegalAccessException If the model's copy method is not accessible.
 * @throws NoSuchMethodException If the model's runtime class does not declare the expected copy method.
 */
// cast of model after call to copy which returns model.
@SuppressWarnings("unchecked")
private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> oldModel, String provenanceHash, StripProvenanceOptions opt) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    if (oldModel instanceof EnsembleModel) {
        EnsembleModel<T> oldEnsemble = (EnsembleModel<T>) oldModel;
        EnsembleModelProvenance oldProvenance = oldEnsemble.getProvenance();
        List<ModelProvenance> newProvenances = new ArrayList<>();
        List<Model<T>> newModels = new ArrayList<>();
        for (Model<T> member : oldEnsemble.getModels()) {
            ModelTuple<T> tuple = convertModel(member, provenanceHash, opt);
            newProvenances.add(tuple.provenance);
            newModels.add(tuple.model);
        }
        ListProvenance<ModelProvenance> listProv = new ListProvenance<>(newProvenances);
        EnsembleModelProvenance cleanedProvenance = cleanEnsembleProvenance(oldProvenance, listProv, provenanceHash, opt);
        // Ensemble copy methods additionally take the (already converted) member models.
        EnsembleModel<T> output = (EnsembleModel<T>) invokeDeclaredCopy(oldModel,
                new Class<?>[]{String.class, ModelProvenance.class, List.class},
                deprovenancedName(oldModel), cleanedProvenance, newModels);
        return new ModelTuple<>(output, cleanedProvenance);
    } else {
        ModelProvenance oldProvenance = oldModel.getProvenance();
        ModelProvenance cleanedProvenance = cleanProvenance(oldProvenance, provenanceHash, opt);
        Model<T> output = (Model<T>) invokeDeclaredCopy(oldModel,
                new Class<?>[]{String.class, ModelProvenance.class},
                deprovenancedName(oldModel), cleanedProvenance);
        return new ModelTuple<>(output, cleanedProvenance);
    }
}

/**
 * Computes the name given to a deprovenanced copy of a model.
 * @param model The model being copied.
 * @return {@code "deprovenanced"} if the model's name is empty, otherwise the model's name
 * followed by {@code "-deprovenanced"}.
 */
private static String deprovenancedName(Model<?> model) {
    String name = model.getName();
    return name.isEmpty() ? "deprovenanced" : name + "-deprovenanced";
}

/**
 * Reflectively invokes the {@code copy} method declared directly on the model's runtime class.
 * <p>
 * The method is made accessible for the call. Its previous accessibility flag is restored
 * afterwards, including when the invocation throws.
 * @param model The model whose copy method is invoked.
 * @param parameterTypes The parameter types identifying the copy overload.
 * @param args The arguments passed to the copy method.
 * @return The value returned by the copy method.
 * @throws InvocationTargetException If the copy method itself throws an exception.
 * @throws IllegalAccessException If the copy method is not accessible.
 * @throws NoSuchMethodException If the runtime class does not declare a matching copy method.
 */
private static Object invokeDeclaredCopy(Model<?> model, Class<?>[] parameterTypes, Object... args) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    Method copyMethod = model.getClass().getDeclaredMethod("copy", parameterTypes);
    boolean accessible = copyMethod.isAccessible();
    copyMethod.setAccessible(true);
    try {
        return copyMethod.invoke(model, args);
    } finally {
        // Restore the original flag even if copy() throws.
        copyMethod.setAccessible(accessible);
    }
}
Also used: ModelProvenance (org.tribuo.provenance.ModelProvenance), EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance), ArrayList (java.util.ArrayList), Method (java.lang.reflect.Method), DATASET (org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET), Model (org.tribuo.Model), EnsembleModel (org.tribuo.ensemble.EnsembleModel), ListProvenance (com.oracle.labs.mlrg.olcut.provenance.ListProvenance)

Example 3 with ListProvenance

Use of com.oracle.labs.mlrg.olcut.provenance.ListProvenance in project tribuo by Oracle.

Class CSVLoaderTest, method testLoadNoHeader.

@Test
public void testLoadNoHeader() throws IOException {
    URL withHeader = CSVLoader.class.getResource("/org/tribuo/data/csv/test.csv");
    CSVLoader<MockOutput> csvLoader = new CSVLoader<>(new MockOutputFactory());
    String[] columnNames = new String[] { "A", "B", "C", "D", "RESPONSE" };

    // Case 1: the CSV file already has a header row and the caller also supplies one.
    // CSVIterator does not skip the file's first line in this case, so the header text is
    // parsed as data and loading fails with a NumberFormatException.
    // TODO decide whether this is the behavior we want.
    assertThrows(NumberFormatException.class, () -> csvLoader.load(Paths.get(withHeader.toURI()), "RESPONSE", columnNames));
    assertThrows(NumberFormatException.class, () -> csvLoader.load(Paths.get(withHeader.toURI()), Collections.singleton("RESPONSE"), columnNames));

    // Case 2: the CSV file has no header row, and the caller supplies the column names instead.
    URL withoutHeader = CSVLoader.class.getResource("/org/tribuo/data/csv/test-noheader.csv");
    DataSource<MockOutput> singleResponseSource = csvLoader.loadDataSource(withoutHeader, "RESPONSE", columnNames);

    // The supplied column names must be recorded, in order, in the data source's provenance.
    CSVDataSource.CSVDataSourceProvenance sourceProvenance = (CSVDataSource.CSVDataSourceProvenance) singleResponseSource.getProvenance();
    Provenance headersProvenance = sourceProvenance.getConfiguredParameters().get("headers");
    assertTrue(headersProvenance instanceof ListProvenance);
    @SuppressWarnings("unchecked")
    ListProvenance<StringProvenance> headerList = (ListProvenance<StringProvenance>) headersProvenance;
    assertEquals(columnNames.length, headerList.getList().size());
    assertEquals(Arrays.asList(columnNames),
            headerList.getList().stream()
                    .map(StringProvenance::getValue)
                    .collect(Collectors.toList()));

    // Both loadDataSource overloads (single response name and response-name set) must load the expected rows.
    checkDataTestCsv(singleResponseSource);
    checkDataTestCsv(csvLoader.loadDataSource(withoutHeader, Collections.singleton("RESPONSE"), columnNames));
}
Also used: Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance), StringProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance), ListProvenance (com.oracle.labs.mlrg.olcut.provenance.ListProvenance), MockOutput (org.tribuo.test.MockOutput), MockOutputFactory (org.tribuo.test.MockOutputFactory), URL (java.net.URL), Test (org.junit.jupiter.api.Test)

Aggregations

ListProvenance (com.oracle.labs.mlrg.olcut.provenance.ListProvenance): 3, Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance): 2, EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance): 2, ModelProvenance (org.tribuo.provenance.ModelProvenance): 2, ObjectProvenance (com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance): 1, HashProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance): 1, StringProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance): 1, Method (java.lang.reflect.Method): 1, URL (java.net.URL): 1, OffsetDateTime (java.time.OffsetDateTime): 1, ArrayList (java.util.ArrayList): 1, Test (org.junit.jupiter.api.Test): 1, Model (org.tribuo.Model): 1, EnsembleModel (org.tribuo.ensemble.EnsembleModel): 1, DATASET (org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET): 1, DatasetProvenance (org.tribuo.provenance.DatasetProvenance): 1, TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 1, EmptyDatasetProvenance (org.tribuo.provenance.impl.EmptyDatasetProvenance): 1, EmptyTrainerProvenance (org.tribuo.provenance.impl.EmptyTrainerProvenance): 1, MockOutput (org.tribuo.test.MockOutput): 1