Use of com.oracle.labs.mlrg.olcut.provenance.ListProvenance in project Tribuo by Oracle.
In class StripProvenance, method cleanEnsembleProvenance.
/**
 * Builds a replacement ensemble provenance derived from {@code old}, blanking out whichever
 * sections the options ask to strip.
 * <p>
 * The dataset and trainer sections are replaced by empty placeholders when either ALL or their
 * specific flag is requested. When ALL or INSTANCE is requested, the instance map starts empty
 * and the training time becomes {@link OffsetDateTime#MIN}. If {@code opt.storeHash} is set, the
 * original provenance hash is recorded in the instance map under "original-provenance-hash".
 * @param old The old ensemble provenance.
 * @param memberProvenance The new member provenances.
 * @param provenanceHash The old ensemble provenance hash.
 * @param opt The program options.
 * @return The new ensemble provenance with the requested fields removed.
 */
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    final boolean stripAll = opt.removeProvenances.contains(ALL);

    final DatasetProvenance datasetSection = (stripAll || opt.removeProvenances.contains(DATASET))
            ? new EmptyDatasetProvenance()
            : old.getDatasetProvenance();

    final TrainerProvenance trainerSection = (stripAll || opt.removeProvenances.contains(TRAINER))
            ? new EmptyTrainerProvenance()
            : old.getTrainerProvenance();

    // Instance provenance and the training timestamp are stripped together.
    final boolean stripInstance = stripAll || opt.removeProvenances.contains(INSTANCE);
    final Map<String, Provenance> instanceSection = stripInstance
            ? new HashMap<>()
            : new HashMap<>(old.getInstanceProvenance().getMap());
    final OffsetDateTime trainingTime = stripInstance ? OffsetDateTime.MIN : old.getTrainingTime();

    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        instanceSection.put("original-provenance-hash", new HashProvenance(opt.hashType, "original-provenance-hash", provenanceHash));
    }
    return new EnsembleModelProvenance(old.getClassName(), trainingTime, datasetSection, trainerSection, instanceSection, memberProvenance);
}
Use of com.oracle.labs.mlrg.olcut.provenance.ListProvenance in project Tribuo by Oracle.
In class StripProvenance, method convertModel.
/**
 * Creates a copy of the old model with the requested provenance removed.
 * <p>
 * For an {@link EnsembleModel}, each member model is converted recursively. The ensemble is then
 * rebuilt from the converted members with a cleaned ensemble provenance. Any other model is copied
 * with a cleaned model provenance. In both cases the copy is named {@code <name>-deprovenanced},
 * or {@code deprovenanced} when the original name is empty.
 * <p>
 * The copy is produced by reflectively invoking the {@code copy} method declared on the model's
 * runtime class. The lookup uses {@link Class#getDeclaredMethod}, so a {@code copy} inherited
 * from a superclass is not found.
 * @param oldModel The model to remove provenance from.
 * @param provenanceHash A hash of the old provenance.
 * @param opt The program options.
 * @param <T> The output type.
 * @return A copy of the model with redacted provenance.
 * @throws InvocationTargetException If the model's copy method itself throws an exception.
 * @throws IllegalAccessException If the model's copy method is not accessible.
 * @throws NoSuchMethodException If the model's runtime class does not declare the expected copy method.
 */
// Unchecked casts: the model to EnsembleModel<T>, and the Object returned by the reflective copy call.
@SuppressWarnings("unchecked")
private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> oldModel, String provenanceHash, StripProvenanceOptions opt) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    if (oldModel instanceof EnsembleModel) {
        EnsembleModel<T> oldEnsemble = (EnsembleModel<T>) oldModel;
        EnsembleModelProvenance oldProvenance = oldEnsemble.getProvenance();
        List<ModelProvenance> newProvenances = new ArrayList<>();
        List<Model<T>> newModels = new ArrayList<>();
        for (Model<T> member : oldEnsemble.getModels()) {
            ModelTuple<T> tuple = convertModel(member, provenanceHash, opt);
            newProvenances.add(tuple.provenance);
            newModels.add(tuple.model);
        }
        ListProvenance<ModelProvenance> listProv = new ListProvenance<>(newProvenances);
        EnsembleModelProvenance cleanedProvenance = cleanEnsembleProvenance(oldProvenance, listProv, provenanceHash, opt);
        Method copyMethod = oldModel.getClass().getDeclaredMethod("copy", String.class, ModelProvenance.class, List.class);
        EnsembleModel<T> output = (EnsembleModel<T>) invokeCopyMethod(copyMethod, oldModel, deprovenancedName(oldModel.getName()), cleanedProvenance, newModels);
        return new ModelTuple<>(output, cleanedProvenance);
    } else {
        ModelProvenance cleanedProvenance = cleanProvenance(oldModel.getProvenance(), provenanceHash, opt);
        Method copyMethod = oldModel.getClass().getDeclaredMethod("copy", String.class, ModelProvenance.class);
        Model<T> output = (Model<T>) invokeCopyMethod(copyMethod, oldModel, deprovenancedName(oldModel.getName()), cleanedProvenance);
        return new ModelTuple<>(output, cleanedProvenance);
    }
}

/**
 * Returns the name given to a de-provenanced copy of a model.
 * @param name The original model name.
 * @return {@code "deprovenanced"} if {@code name} is empty, otherwise {@code name + "-deprovenanced"}.
 */
private static String deprovenancedName(String name) {
    return name.isEmpty() ? "deprovenanced" : name + "-deprovenanced";
}

/**
 * Reflectively invokes {@code copyMethod} on {@code target}, temporarily forcing it accessible.
 * The method object's original accessibility flag is restored even if the invocation throws.
 * @param copyMethod The copy method to invoke.
 * @param target The model instance to invoke it on.
 * @param args The arguments passed to the copy method.
 * @return The value returned by the copy method.
 * @throws InvocationTargetException If the copy method throws.
 * @throws IllegalAccessException If the copy method cannot be accessed.
 */
// isAccessible() is deprecated since Java 9; its replacement canAccess(Object) would require Java 9+.
@SuppressWarnings("deprecation")
private static Object invokeCopyMethod(Method copyMethod, Model<?> target, Object... args) throws InvocationTargetException, IllegalAccessException {
    boolean accessible = copyMethod.isAccessible();
    copyMethod.setAccessible(true);
    try {
        return copyMethod.invoke(target, args);
    } finally {
        copyMethod.setAccessible(accessible);
    }
}
Use of com.oracle.labs.mlrg.olcut.provenance.ListProvenance in project Tribuo by Oracle.
In class CSVLoaderTest, method testLoadNoHeader.
@Test
public void testLoadNoHeader() throws IOException {
    URL withHeaderCsv = CSVLoader.class.getResource("/org/tribuo/data/csv/test.csv");
    CSVLoader<MockOutput> csvLoader = new CSVLoader<>(new MockOutputFactory());
    String[] suppliedHeader = new String[] { "A", "B", "C", "D", "RESPONSE" };

    // Passing a header to loader.load for a CSV that already has a header row currently fails.
    // CSVIterator does not skip the file's first line in this case, and the test expects a
    // NumberFormatException as a result.
    // TODO do we want this behavior?
    assertThrows(NumberFormatException.class,
            () -> csvLoader.load(Paths.get(withHeaderCsv.toURI()), "RESPONSE", suppliedHeader));
    assertThrows(NumberFormatException.class,
            () -> csvLoader.load(Paths.get(withHeaderCsv.toURI()), Collections.singleton("RESPONSE"), suppliedHeader));

    // A CSV file with no header row, where the caller supplies the header instead.
    URL noHeaderCsv = CSVLoader.class.getResource("/org/tribuo/data/csv/test-noheader.csv");
    DataSource<MockOutput> dataSource = csvLoader.loadDataSource(noHeaderCsv, "RESPONSE", suppliedHeader);

    // The supplied header must be persisted in the source's provenance under "headers".
    Provenance headersProvenance = ((CSVDataSource.CSVDataSourceProvenance) dataSource.getProvenance())
            .getConfiguredParameters()
            .get("headers");
    assertTrue(headersProvenance instanceof ListProvenance);
    @SuppressWarnings("unchecked")
    ListProvenance<StringProvenance> headerList = (ListProvenance<StringProvenance>) headersProvenance;
    assertEquals(suppliedHeader.length, headerList.getList().size());
    assertEquals(Arrays.asList(suppliedHeader),
            headerList.getList().stream().map(StringProvenance::getValue).collect(Collectors.toList()));

    // Both loadDataSource overloads (single response name and response-name set) must load the data correctly.
    checkDataTestCsv(dataSource);
    checkDataTestCsv(csvLoader.loadDataSource(noHeaderCsv, Collections.singleton("RESPONSE"), suppliedHeader));
}
Aggregations