Use of com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance in the Tribuo project by Oracle.
In class LibSVMDataSourceTest, method testLibSVMSaving:
/**
 * Round-trips a LibSVM data source through a temporary file and checks provenance path handling.
 * <p>
 * The test loads {@code test-1.libsvm} from the classpath and writes it back out with
 * {@link LibSVMDataSource#writeLibSVMFormat}. It reloads the written file and asserts that the two
 * sources compare equal. It then reloads the file through a path containing a {@code "."} element.
 * Finally it asserts that the {@code url} and {@code path} provenance fields recorded for that source
 * agree.
 *
 * @throws IOException If the temporary file cannot be created, written or read.
 */
@Test
public void testLibSVMSaving() throws IOException {
    MockOutputFactory factory = new MockOutputFactory();
    URL dataFile = LibSVMDataSourceTest.class.getResource("/org/tribuo/datasource/test-1.libsvm");
    LibSVMDataSource<MockOutput> source = new LibSVMDataSource<>(dataFile, factory);
    // createTempFile appends the suffix verbatim, so it needs its own leading dot to form an extension.
    File temp = File.createTempFile("tribuo-lib-svm-test", ".libsvm");
    temp.deleteOnExit();
    MutableDataset<MockOutput> dataset = new MutableDataset<>(source);
    try (PrintStream stream = new PrintStream(temp, StandardCharsets.UTF_8.name())) {
        // The mock labels are parsed as integers to produce the LibSVM output column.
        LibSVMDataSource.writeLibSVMFormat(dataset, stream, false, (MockOutput a) -> Integer.parseInt(a.label));
    }
    LibSVMDataSource<MockOutput> loadedSource = new LibSVMDataSource<>(temp.toPath(), factory);
    assertTrue(compareDataSources(source, loadedSource), "Saved data source was not the same as the loaded one.");
    // Now we check provenance path normalization on the saved file.
    // First generate a path with a relative element in it (i.e., ".").
    Path newPath = temp.toPath().resolveSibling(Paths.get(".", temp.getName()));
    // Load the data source back in using the path containing the relative element.
    LibSVMDataSource<MockOutput> newSource = new LibSVMDataSource<>(newPath, factory);
    LibSVMDataSource.LibSVMDataSourceProvenance sourceProv = (LibSVMDataSource.LibSVMDataSourceProvenance) newSource.getProvenance();
    // Extract the two provenance fields.
    // The unchecked cast is acceptable because this is test code and the "url" provenance entry is known to exist.
    @SuppressWarnings("unchecked") URL sourceURL = ((PrimitiveProvenance<URL>) sourceProv.getConfiguredParameters().get("url")).getValue();
    // The unchecked cast is acceptable because this is test code and the "path" provenance entry is known to exist.
    @SuppressWarnings("unchecked") File sourceFile = ((PrimitiveProvenance<File>) sourceProv.getConfiguredParameters().get("path")).getValue();
    // Assert that the recorded path and URL refer to the same location.
    assertEquals(sourceFile.toPath().toUri().toURL(), sourceURL);
}
Use of com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance in the Tribuo project by Oracle.
In class SkeletalTrainerProvenance, method extractProvenanceInfo:
/**
 * Splits a raw provenance map into the pieces needed to rebuild a {@code SkeletalTrainerProvenance}.
 * <p>
 * The supplied map is copied and only the copy is modified. The following entries are removed from
 * the copy, in this order:
 * <ul>
 *   <li>the class name;</li>
 *   <li>the host short name;</li>
 *   <li>the train invocation count ({@link IntProvenance});</li>
 *   <li>the is-sequence flag ({@link BooleanProvenance});</li>
 *   <li>the Tribuo version string ({@link StringProvenance}).</li>
 * </ul>
 * Whatever remains in the copy is returned as the configured parameters.
 *
 * @param map The provenance map to unpack; it is not modified.
 * @return An {@link ExtractedInfo} holding the class name, host short name, the remaining configured
 *         parameters and the extracted instance values.
 * @throws ProvenanceException If a required entry is missing, or an instance value is not of the
 *                             expected provenance type.
 */
protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
    // Work on a copy so that the caller's map is left untouched.
    Map<String, Provenance> configuredParameters = new HashMap<>(map);
    Map<String, PrimitiveProvenance<?>> instanceValues = new HashMap<>();

    String className = removeRequiredEntry(configuredParameters, ObjectProvenance.CLASS_NAME,
            "Failed to find class name when constructing SkeletalTrainerProvenance").toString();
    String hostTypeStringName = removeRequiredEntry(configuredParameters, SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME,
            "Failed to find host type short name when constructing SkeletalTrainerProvenance").toString();

    instanceValues.put(TRAIN_INVOCATION_COUNT, removeRequiredPrimitive(configuredParameters,
            TrainerProvenance.TRAIN_INVOCATION_COUNT, IntProvenance.class, className,
            "Failed to find invocation count when constructing SkeletalTrainerProvenance"));
    instanceValues.put(IS_SEQUENCE, removeRequiredPrimitive(configuredParameters,
            TrainerProvenance.IS_SEQUENCE, BooleanProvenance.class, className,
            "Failed to find is-sequence when constructing SkeletalTrainerProvenance"));
    instanceValues.put(TRIBUO_VERSION_STRING, removeRequiredPrimitive(configuredParameters,
            TrainerProvenance.TRIBUO_VERSION_STRING, StringProvenance.class, className,
            "Failed to find Tribuo version when constructing SkeletalTrainerProvenance"));

    return new ExtractedInfo(className, hostTypeStringName, configuredParameters, instanceValues);
}

/**
 * Removes and returns the entry stored under {@code key}, failing if the key is absent.
 * <p>
 * {@code containsKey} is checked before removal so that a key mapped to {@code null} still counts as
 * present. In that case {@code null} is returned.
 *
 * @param params         The mutable parameter map to remove the entry from.
 * @param key            The key to look up.
 * @param missingMessage The exception message to use if the key is absent.
 * @return The removed value, possibly {@code null}.
 * @throws ProvenanceException If {@code params} does not contain {@code key}.
 */
private static Provenance removeRequiredEntry(Map<String, Provenance> params, String key, String missingMessage) {
    if (!params.containsKey(key)) {
        throw new ProvenanceException(missingMessage);
    }
    return params.remove(key);
}

/**
 * Removes the entry stored under {@code key} and checks that it has the expected primitive provenance type.
 *
 * @param params         The mutable parameter map to remove the entry from.
 * @param key            The key to look up; it is also used in the type-mismatch message.
 * @param type           The expected provenance class.
 * @param className      The class name reported in the type-mismatch message.
 * @param missingMessage The exception message to use if the key is absent.
 * @param <T>            The expected provenance type.
 * @return The removed value, cast to {@code T}.
 * @throws ProvenanceException If the key is absent, or its value is {@code null} or not an instance of {@code type}.
 */
private static <T extends PrimitiveProvenance<?>> T removeRequiredPrimitive(Map<String, Provenance> params, String key,
                                                                             Class<T> type, String className,
                                                                             String missingMessage) {
    Provenance prov = removeRequiredEntry(params, key, missingMessage);
    // isInstance is false for null, matching the original instanceof check.
    if (!type.isInstance(prov)) {
        throw new ProvenanceException(key + " was not of type " + type.getSimpleName() + " in class " + className);
    }
    return type.cast(prov);
}
Aggregations