Search in sources:

Example 1 with PrimitiveProvenance

Use of com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance in project tribuo by Oracle.

Class LibSVMDataSourceTest, method testLibSVMSaving.

/**
 * Writes a LibSVM data source out with {@code writeLibSVMFormat}, reloads it, and checks that
 * the reloaded source matches the original. It then reopens the saved file through a path
 * containing a redundant {@code "."} element and checks that the recorded {@code path} and
 * {@code url} provenance fields agree with each other.
 *
 * @throws IOException If the temporary file cannot be created or written.
 */
@Test
public void testLibSVMSaving() throws IOException {
    MockOutputFactory outputFactory = new MockOutputFactory();
    URL resource = LibSVMDataSourceTest.class.getResource("/org/tribuo/datasource/test-1.libsvm");
    LibSVMDataSource<MockOutput> original = new LibSVMDataSource<>(resource, outputFactory);

    File savedFile = File.createTempFile("tribuo-lib-svm-test", "libsvm");
    savedFile.deleteOnExit();

    // Round trip: dataset -> LibSVM text file -> new data source.
    MutableDataset<MockOutput> dataset = new MutableDataset<>(original);
    try (PrintStream out = new PrintStream(savedFile, StandardCharsets.UTF_8.name())) {
        LibSVMDataSource.writeLibSVMFormat(dataset, out, false, (MockOutput output) -> Integer.parseInt(output.label));
    }
    LibSVMDataSource<MockOutput> reloaded = new LibSVMDataSource<>(savedFile.toPath(), outputFactory);
    assertTrue(compareDataSources(original, reloaded), "Saved data source was not the same as the loaded one.");

    // Point at the same file again, this time via a path with a "." segment in it,
    // to exercise path normalisation in the provenance.
    Path savedPath = savedFile.toPath();
    Path dottedPath = savedPath.resolveSibling(Paths.get(".", savedFile.getName()));
    LibSVMDataSource<MockOutput> dottedSource = new LibSVMDataSource<>(dottedPath, outputFactory);
    LibSVMDataSource.LibSVMDataSourceProvenance provenance =
            (LibSVMDataSource.LibSVMDataSourceProvenance) dottedSource.getProvenance();

    // Unchecked casts are acceptable here: this is test code and both fields are expected to exist.
    @SuppressWarnings("unchecked")
    PrimitiveProvenance<URL> urlField = (PrimitiveProvenance<URL>) provenance.getConfiguredParameters().get("url");
    URL recordedUrl = urlField.getValue();
    @SuppressWarnings("unchecked")
    PrimitiveProvenance<File> pathField = (PrimitiveProvenance<File>) provenance.getConfiguredParameters().get("path");
    File recordedFile = pathField.getValue();

    // The URL derived from the recorded path must equal the recorded URL.
    assertEquals(recordedFile.toPath().toUri().toURL(), recordedUrl);
}
Also used : Path(java.nio.file.Path) PrintStream(java.io.PrintStream) MockOutput(org.tribuo.test.MockOutput) MockOutputFactory(org.tribuo.test.MockOutputFactory) URL(java.net.URL) PrimitiveProvenance(com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance) File(java.io.File) MutableDataset(org.tribuo.MutableDataset) Test(org.junit.jupiter.api.Test)

Example 2 with PrimitiveProvenance

Use of com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance in project tribuo by Oracle.

Class SkeletalTrainerProvenance, method extractProvenanceInfo.

/**
 * Splits a raw provenance map into the fields shared by every skeletal trainer provenance
 * and the remaining configured parameters.
 * <p>
 * The input map is copied first, so it is not modified. The class name and host type short
 * name are removed and converted with {@code toString()}. The train invocation count, the
 * is-sequence flag and the Tribuo version string are removed, type checked and collected as
 * instance values. All entries left over become the configured parameters.
 *
 * @param map The provenance map to unpack.
 * @return The class name, host short name, configured parameters and instance values.
 * @throws ProvenanceException If a required field is missing, or an instance value is not of
 *                             the expected provenance type.
 */
protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
    // Work on a copy; each recognised key is removed so only configured parameters remain.
    Map<String, Provenance> configuredParameters = new HashMap<>(map);
    Map<String, PrimitiveProvenance<?>> instanceValues = new HashMap<>();

    if (!configuredParameters.containsKey(ObjectProvenance.CLASS_NAME)) {
        throw new ProvenanceException("Failed to find class name when constructing SkeletalTrainerProvenance");
    }
    String className = configuredParameters.remove(ObjectProvenance.CLASS_NAME).toString();

    if (!configuredParameters.containsKey(SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME)) {
        throw new ProvenanceException("Failed to find host type short name when constructing SkeletalTrainerProvenance");
    }
    String hostTypeStringName = configuredParameters.remove(SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME).toString();

    // The same key constant is used for lookup, removal, insertion and error reporting.
    instanceValues.put(TrainerProvenance.TRAIN_INVOCATION_COUNT,
            removeRequiredPrimitive(configuredParameters, TrainerProvenance.TRAIN_INVOCATION_COUNT,
                    IntProvenance.class, className, "invocation count"));
    instanceValues.put(TrainerProvenance.IS_SEQUENCE,
            removeRequiredPrimitive(configuredParameters, TrainerProvenance.IS_SEQUENCE,
                    BooleanProvenance.class, className, "is-sequence"));
    instanceValues.put(TrainerProvenance.TRIBUO_VERSION_STRING,
            removeRequiredPrimitive(configuredParameters, TrainerProvenance.TRIBUO_VERSION_STRING,
                    StringProvenance.class, className, "Tribuo version"));

    return new ExtractedInfo(className, hostTypeStringName, configuredParameters, instanceValues);
}

/**
 * Removes a required, typed primitive provenance entry from {@code params}.
 *
 * @param params      The mutable parameter map to remove the entry from.
 * @param key         The key of the required entry.
 * @param type        The provenance class the entry must be an instance of.
 * @param className   The host class name, used in the type-mismatch error message.
 * @param description Human readable name of the field, used in the missing-field error message.
 * @param <T>         The expected provenance type.
 * @return The removed entry, cast to {@code type}.
 * @throws ProvenanceException If {@code key} is absent, or its value is not an instance of {@code type}.
 */
private static <T extends PrimitiveProvenance<?>> T removeRequiredPrimitive(Map<String, Provenance> params,
                                                                           String key,
                                                                           Class<T> type,
                                                                           String className,
                                                                           String description) {
    if (!params.containsKey(key)) {
        throw new ProvenanceException("Failed to find " + description + " when constructing SkeletalTrainerProvenance");
    }
    Provenance value = params.remove(key);
    if (!type.isInstance(value)) {
        // getSimpleName() gives e.g. "IntProvenance", matching the original hand-written messages.
        throw new ProvenanceException(key + " was not of type " + type.getSimpleName() + " in class " + className);
    }
    return type.cast(value);
}
Also used: PrimitiveProvenance(com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance) BooleanProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance) ObjectProvenance(com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) StringProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance) SkeletalConfiguredObjectProvenance(com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance) IntProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance) ProvenanceException(com.oracle.labs.mlrg.olcut.provenance.ProvenanceException) HashMap(java.util.HashMap)

Aggregations

PrimitiveProvenance (com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance): 2
ObjectProvenance (com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance): 1
Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance): 1
ProvenanceException (com.oracle.labs.mlrg.olcut.provenance.ProvenanceException): 1
SkeletalConfiguredObjectProvenance (com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance): 1
BooleanProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance): 1
IntProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance): 1
StringProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance): 1
File (java.io.File): 1
PrintStream (java.io.PrintStream): 1
URL (java.net.URL): 1
Path (java.nio.file.Path): 1
HashMap (java.util.HashMap): 1
Test (org.junit.jupiter.api.Test): 1
MutableDataset (org.tribuo.MutableDataset): 1
MockOutput (org.tribuo.test.MockOutput): 1
MockOutputFactory (org.tribuo.test.MockOutputFactory): 1