Example 1 with CSVLoader

Use of org.tribuo.data.csv.CSVLoader in project tribuo by oracle.

From class CSVLoaderWithMultiOutputsTest, method loadsMultiRegressor.

@Test
public void loadsMultiRegressor() throws IOException {
    Path path = Resources.copyResourceToTmp("/org/tribuo/tests/csv/multioutput-regression.csv");
    CSVLoader<Regressor> loader = new CSVLoader<>(new RegressionFactory());
    String[] responseNames = new String[] { "R1", "R2" };
    MutableDataset<Regressor> data = loader.load(path, new HashSet<>(Arrays.asList(responseNames)));
    assertEquals(5, data.size());
    Example<Regressor> x0 = data.getExample(0);
    assertArrayEquals(responseNames, x0.getOutput().getNames());
    assertArrayEquals(new double[] { 0.1, 0.2 }, x0.getOutput().getValues());
    Example<Regressor> x1 = data.getExample(1);
    assertArrayEquals(responseNames, x1.getOutput().getNames());
    assertArrayEquals(new double[] { 0.0, 0.0 }, x1.getOutput().getValues());
}
Also used: Path(java.nio.file.Path) CSVLoader(org.tribuo.data.csv.CSVLoader) RegressionFactory(org.tribuo.regression.RegressionFactory) Regressor(org.tribuo.regression.Regressor) Test(org.junit.jupiter.api.Test)
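
For comparison, a minimal sketch of the single-response variant of load, assuming a hypothetical file named regression.csv whose response lives in one column called "R1" (the file name and column layout are illustrative, not taken from the test resources):

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.tribuo.MutableDataset;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;

public class SingleResponseSketch {
    public static void main(String[] args) throws IOException {
        // Hypothetical CSV with a header row; every non-response column
        // is parsed as a numeric feature.
        Path path = Paths.get("regression.csv");
        CSVLoader<Regressor> loader = new CSVLoader<>(new RegressionFactory());
        // Single-String overload of load: one named column is the response.
        MutableDataset<Regressor> data = loader.load(path, "R1");
        System.out.println("Loaded " + data.size() + " examples");
    }
}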

Example 2 with CSVLoader

Use of org.tribuo.data.csv.CSVLoader in project tribuo by oracle.

From class CSVSaverWithMultiOutputsTest, method loaderCanReconstructSavedMultiLabel.

@Test
public void loaderCanReconstructSavedMultiLabel() throws IOException {
    Path path = Resources.copyResourceToTmp("/org/tribuo/tests/csv/multilabel.csv");
    Set<String> responses = new HashSet<>(Arrays.asList("R1", "R2"));
    // 
    // Load the csv
    CSVLoader<MultiLabel> loader = new CSVLoader<>(new MultiLabelFactory());
    MutableDataset<MultiLabel> before = loader.load(path, responses);
    // 
    // Save the dataset
    File tmpFile = File.createTempFile("tribuo-csv-test", "csv");
    tmpFile.deleteOnExit();
    Path tmp = tmpFile.toPath();
    new CSVSaver().save(tmp, before, responses);
    // 
    // Reload and check that before & after are equivalent.
    MutableDataset<MultiLabel> after = loader.load(tmp, responses);
    // 
    // TODO: better check for dataset equivalence?
    assertEquals(before.getData(), after.getData());
    assertEquals(before.getOutputIDInfo().size(), after.getOutputIDInfo().size());
    assertEquals(before.getFeatureIDMap().size(), after.getFeatureIDMap().size());
}
Also used: Path(java.nio.file.Path) MultiLabel(org.tribuo.multilabel.MultiLabel) CSVLoader(org.tribuo.data.csv.CSVLoader) MultiLabelFactory(org.tribuo.multilabel.MultiLabelFactory) CSVSaver(org.tribuo.data.csv.CSVSaver) File(java.io.File) HashSet(java.util.HashSet) Test(org.junit.jupiter.api.Test)
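
The TODO in this test asks for a stricter notion of dataset equivalence than comparing the example lists and map sizes. One possible sketch, using only public Tribuo accessors; checkEquivalent is a hypothetical helper, not part of Tribuo:

import org.tribuo.Dataset;
import org.tribuo.Output;
import org.tribuo.VariableInfo;

public final class DatasetChecks {

    private DatasetChecks() {}

    /**
     * Hypothetical helper: true when both datasets have the same size and
     * the same per-feature names and observation counts.
     */
    public static <T extends Output<T>> boolean checkEquivalent(Dataset<T> a, Dataset<T> b) {
        if (a.size() != b.size() || a.getFeatureIDMap().size() != b.getFeatureIDMap().size()) {
            return false;
        }
        // Every feature in a must appear in b with the same observation count.
        for (VariableInfo info : a.getFeatureIDMap()) {
            VariableInfo other = b.getFeatureIDMap().get(info.getName());
            if (other == null || other.getCount() != info.getCount()) {
                return false;
            }
        }
        return true;
    }
}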

Example 3 with CSVLoader

Use of org.tribuo.data.csv.CSVLoader in project tribuo by oracle.

From class CSVSaverWithMultiOutputsTest, method savesMultipleRegression.

@Test
public void savesMultipleRegression() throws IOException {
    String[] vars = new String[] { "dim1", "dim2" };
    Set<String> responseNames = new HashSet<>(Arrays.asList("dim1", "dim2"));
    RegressionFactory factory = new RegressionFactory();
    MutableDataset<Regressor> before = new MutableDataset<>(null, factory);
    ArrayExample<Regressor> e = new ArrayExample<>(new Regressor(vars, new double[] { 0.1, 0.0 }));
    e.add(new Feature("A", 1.0));
    e.add(new Feature("B", 0.0));
    e.add(new Feature("C", 0.0));
    before.add(e);
    ArrayExample<Regressor> b = new ArrayExample<>(new Regressor(vars, new double[] { 0.0, 0.0 }));
    b.add(new Feature("A", 1.0));
    b.add(new Feature("B", 0.0));
    b.add(new Feature("C", 0.1));
    before.add(b);
    CSVSaver saver = new CSVSaver();
    File tmpFile = File.createTempFile("tribuo-csv-test", "csv");
    tmpFile.deleteOnExit();
    Path tmp = tmpFile.toPath();
    saver.save(tmp, before, responseNames);
    // TODO use this to compare literal saver outputs
    // ByteArrayOutputStream baos = new ByteArrayOutputStream();
    // saver.save(baos, before, responseNames);
    // baos.flush();
    // System.out.println(new String(baos.toByteArray()));
    CSVLoader<Regressor> loader = new CSVLoader<>(factory);
    MutableDataset<Regressor> after = loader.load(tmp, responseNames);
    assertEquals(before.getData(), after.getData());
    assertEquals(before.getOutputIDInfo().size(), after.getOutputIDInfo().size());
    assertEquals(before.getFeatureIDMap().size(), after.getFeatureIDMap().size());
}
Also used: Path(java.nio.file.Path) RegressionFactory(org.tribuo.regression.RegressionFactory) CSVLoader(org.tribuo.data.csv.CSVLoader) Feature(org.tribuo.Feature) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) CSVSaver(org.tribuo.data.csv.CSVSaver) MutableDataset(org.tribuo.MutableDataset) File(java.io.File) HashSet(java.util.HashSet) Test(org.junit.jupiter.api.Test)
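
Until the commented-out stream comparison in this test is wired up, the saved file can be checked by reading it back as text. A small sketch, assuming tmp points at the CSV that CSVSaver just wrote (standard java.nio only):

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

public class InspectSavedCsv {
    public static void main(String[] args) throws IOException {
        // Hypothetical path; in the test this is the temp file written by CSVSaver.
        Path tmp = Paths.get("tribuo-csv-test.csv");
        List<String> lines = Files.readAllLines(tmp);
        // The first line is the header: feature columns plus response columns.
        System.out.println("Header: " + lines.get(0));
        for (String row : lines.subList(1, lines.size())) {
            System.out.println("Row: " + row);
        }
    }
}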

Example 4 with CSVLoader

Use of org.tribuo.data.csv.CSVLoader in project tribuo by oracle.

From class CSVSaverWithMultiOutputsTest, method savesMultiLabel.

@Test
public void savesMultiLabel() throws IOException {
    Set<String> responseNames = new HashSet<>(Arrays.asList("MONKEY", "PUZZLE", "TREE"));
    MultiLabelFactory factory = new MultiLabelFactory();
    MutableDataset<MultiLabel> before = new MutableDataset<>(null, factory);
    ArrayExample<MultiLabel> e = new ArrayExample<>(factory.generateOutput("MONKEY"));
    e.add(new Feature("A-MONKEY", 1.0));
    e.add(new Feature("B-PUZZLE", 0.0));
    e.add(new Feature("C-TREE", 0.0));
    before.add(e);
    ArrayExample<MultiLabel> b = new ArrayExample<>(factory.generateOutput("MONKEY,TREE"));
    b.add(new Feature("A-MONKEY", 1.0));
    b.add(new Feature("C-TREE", 1.0));
    before.add(b);
    CSVSaver saver = new CSVSaver();
    File tmpFile = File.createTempFile("tribuo-csv-test", "csv");
    tmpFile.deleteOnExit();
    Path tmp = tmpFile.toPath();
    saver.save(tmp, before, responseNames);
    // TODO use this to compare literal saver outputs
    // ByteArrayOutputStream baos = new ByteArrayOutputStream();
    // saver.save(baos, before, responseNames);
    // baos.flush();
    // System.out.println(new String(baos.toByteArray()));
    CSVLoader<MultiLabel> loader = new CSVLoader<>(factory);
    MutableDataset<MultiLabel> after = loader.load(tmp, responseNames);
    assertEquals(before.getData(), after.getData());
    assertEquals(before.getOutputIDInfo().size(), after.getOutputIDInfo().size());
    assertEquals(before.getFeatureIDMap().size(), after.getFeatureIDMap().size());
}
Also used: Path(java.nio.file.Path) MultiLabel(org.tribuo.multilabel.MultiLabel) CSVLoader(org.tribuo.data.csv.CSVLoader) Feature(org.tribuo.Feature) ArrayExample(org.tribuo.impl.ArrayExample) MultiLabelFactory(org.tribuo.multilabel.MultiLabelFactory) CSVSaver(org.tribuo.data.csv.CSVSaver) MutableDataset(org.tribuo.MutableDataset) File(java.io.File) HashSet(java.util.HashSet) Test(org.junit.jupiter.api.Test)
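
The comma-separated string handed to generateOutput in this test is split into a set of labels. A short sketch of that behaviour; getLabelSet() is assumed here to be MultiLabel's public accessor for the parsed label set:

import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.MultiLabelFactory;

public class MultiLabelParseSketch {
    public static void main(String[] args) {
        MultiLabelFactory factory = new MultiLabelFactory();
        // "MONKEY,TREE" should parse into a two-element label set.
        MultiLabel combo = factory.generateOutput("MONKEY,TREE");
        System.out.println("Labels: " + combo.getLabelSet());
        // A single name yields a singleton set.
        MultiLabel single = factory.generateOutput("MONKEY");
        System.out.println("Single: " + single.getLabelSet());
    }
}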

Example 5 with CSVLoader

Use of org.tribuo.data.csv.CSVLoader in project tribuo by oracle.

From class DataOptions, method load.

/**
 * Loads the training and testing data from {@link #trainingPath} and {@link #testingPath}
 * according to the other parameters specified in this class.
 * @param outputFactory The output factory to use to process the inputs.
 * @param <T> The dataset output type.
 * @return A pair containing the training and testing datasets. The training dataset is element 'A' and the
 * testing dataset is element 'B'.
 * @throws IOException If the paths could not be loaded.
 */
public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
    logger.info(String.format("Loading data from %s", trainingPath));
    Dataset<T> train;
    Dataset<T> test;
    char separator;
    switch(inputFormat) {
        case SERIALIZED:
            // 
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + trainingPath);
            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                @SuppressWarnings("unchecked") Dataset<T> tmp = (Dataset<T>) ois.readObject();
                train = tmp;
                if (minCount > 0) {
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Removing features that occur fewer than " + minCount + " times.");
                    train = new MinimumCardinalityDataset<>(train, minCount);
                }
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                @SuppressWarnings("unchecked") Dataset<T> deserTest = (Dataset<T>) oits.readObject();
                test = new ImmutableDataset<>(deserTest, deserTest.getSourceProvenance(), deserTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised files", e);
            }
            break;
        case LIBSVM:
            // 
            // Load the libsvm text-based data format.
            LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath, outputFactory);
            train = new MutableDataset<>(trainSVMSource);
            boolean zeroIndexed = trainSVMSource.isZeroIndexed();
            int maxFeatureID = trainSVMSource.getMaxFeatureID();
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case TEXT:
            // 
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<T> extractor;
            if (hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
            }
            TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
            train = new MutableDataset<>(trainSource);
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
            test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            separator = delimiter.value;
            CSVLoader<T> loader = new CSVLoader<>(separator, outputFactory);
            train = new MutableDataset<>(loader.loadDataSource(trainingPath, csvResponseName));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(loader.loadDataSource(testingPath, csvResponseName));
            break;
        case COLUMNAR:
            if (rowProcessor == null) {
                throw new IllegalArgumentException("Please supply a RowProcessor");
            }
            OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
            if (!rowOutputFactory.equals(outputFactory)) {
                throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getName());
            }
            // checked by the if statement above
            @SuppressWarnings("unchecked") RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
            separator = delimiter.value;
            train = new MutableDataset<>(new CSVDataSource<>(trainingPath, typedRowProcessor, true, separator, csvQuoteChar));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(new CSVDataSource<>(testingPath, typedRowProcessor, true, separator, csvQuoteChar));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + inputFormat);
    }
    logger.info(String.format("Loaded %d testing examples", test.size()));
    if (scaleFeatures) {
        logger.info("Fitting feature scaling");
        TransformationMap map = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformerMap transformers = train.createTransformers(map, scaleIncZeros);
        logger.info("Applying scaling to training dataset");
        train = transformers.transformDataset(train);
        logger.info("Applying scaling to testing dataset");
        test = transformers.transformDataset(test);
    }
    return new Pair<>(train, test);
}
Also used: TransformerMap(org.tribuo.transform.TransformerMap) CSVDataSource(org.tribuo.data.csv.CSVDataSource) SimpleTextDataSource(org.tribuo.data.text.impl.SimpleTextDataSource) TransformationMap(org.tribuo.transform.TransformationMap) LinearScalingTransformation(org.tribuo.transform.transformations.LinearScalingTransformation) BufferedInputStream(java.io.BufferedInputStream) LibSVMDataSource(org.tribuo.datasource.LibSVMDataSource) RowProcessor(org.tribuo.data.columnar.RowProcessor) Pair(com.oracle.labs.mlrg.olcut.util.Pair) CSVLoader(org.tribuo.data.csv.CSVLoader) ImmutableDataset(org.tribuo.ImmutableDataset) Dataset(org.tribuo.Dataset) MinimumCardinalityDataset(org.tribuo.dataset.MinimumCardinalityDataset) MutableDataset(org.tribuo.MutableDataset) FileInputStream(java.io.FileInputStream) BreakIteratorTokenizer(org.tribuo.util.tokens.impl.BreakIteratorTokenizer) TokenPipeline(org.tribuo.data.text.impl.TokenPipeline) ObjectInputStream(java.io.ObjectInputStream)
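
DataOptions is normally populated from command-line arguments by OLCUT's ConfigurationManager, so the fields referenced in load (trainingPath, testingPath, inputFormat, minCount, and so on) arrive pre-filled. A hedged usage sketch for the single-label classification case; the exact flag names come from DataOptions' @Option annotations, and argument parsing is otherwise elided:

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import org.tribuo.Dataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.data.DataOptions;

public class LoadSketch {
    public static void main(String[] args) throws IOException {
        DataOptions opts = new DataOptions();
        // Parses flags such as the training/testing file paths into opts' fields.
        ConfigurationManager cm = new ConfigurationManager(args, opts);
        // Element 'A' is the training set, 'B' the testing set, per the javadoc on load.
        Pair<Dataset<Label>, Dataset<Label>> pair = opts.load(new LabelFactory());
        System.out.println("train=" + pair.getA().size() + ", test=" + pair.getB().size());
    }
}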

Aggregations

CSVLoader (org.tribuo.data.csv.CSVLoader): 9 uses
Path (java.nio.file.Path): 8 uses
Test (org.junit.jupiter.api.Test): 7 uses
HashSet (java.util.HashSet): 5 uses
File (java.io.File): 4 uses
CSVSaver (org.tribuo.data.csv.CSVSaver): 4 uses
RegressionFactory (org.tribuo.regression.RegressionFactory): 4 uses
Regressor (org.tribuo.regression.Regressor): 4 uses
MutableDataset (org.tribuo.MutableDataset): 3 uses
MultiLabel (org.tribuo.multilabel.MultiLabel): 3 uses
MultiLabelFactory (org.tribuo.multilabel.MultiLabelFactory): 3 uses
Pair (com.oracle.labs.mlrg.olcut.util.Pair): 2 uses
BufferedInputStream (java.io.BufferedInputStream): 2 uses
FileInputStream (java.io.FileInputStream): 2 uses
ObjectInputStream (java.io.ObjectInputStream): 2 uses
Dataset (org.tribuo.Dataset): 2 uses
Feature (org.tribuo.Feature): 2 uses
ImmutableDataset (org.tribuo.ImmutableDataset): 2 uses
SimpleTextDataSource (org.tribuo.data.text.impl.SimpleTextDataSource): 2 uses
TokenPipeline (org.tribuo.data.text.impl.TokenPipeline): 2 uses