Search in sources:

Example 1 with LinearScalingTransformation

Use of org.tribuo.transform.transformations.LinearScalingTransformation in the Tribuo project by Oracle.

In class DataOptions, method load.

/**
 * Loads the training and testing data from {@link #trainingPath} and {@link #testingPath}
 * according to the other parameters specified in this class.
 * <p>
 * For every input format the training data is loaded first. If {@code minCount} is positive,
 * features occurring fewer than {@code minCount} times in the training data are removed. The
 * testing data is then loaded. If {@code scaleFeatures} is set, a
 * {@link LinearScalingTransformation} is fitted on the training data and applied to both the
 * training and testing datasets.
 * @param outputFactory The output factory to use to process the inputs.
 * @param <T> The dataset output type.
 * @return A pair containing the training and testing datasets. The training dataset is element 'A' and the
 * testing dataset is element 'B'.
 * @throws IOException If the paths could not be loaded.
 * @throws IllegalArgumentException If the input format is unsupported, a serialised file contains an
 * unknown class, the CSV response column name or RowProcessor is missing, or the RowProcessor's
 * output factory differs from {@code outputFactory}.
 */
public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
    logger.info(String.format("Loading data from %s", trainingPath));
    Dataset<T> train;
    Dataset<T> test;
    char separator;
    switch(inputFormat) {
        case SERIALIZED:
            // Load Tribuo serialised datasets.
            // NOTE(review): this uses Java-native deserialisation, which must never be pointed at
            // untrusted files; consider an ObjectInputFilter if these paths can come from elsewhere.
            logger.info("Deserialising dataset from " + trainingPath);
            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                @SuppressWarnings("unchecked") Dataset<T> tmp = (Dataset<T>) ois.readObject();
                train = tmp;
                if (minCount > 0) {
                    // Report the pre-filter feature count so the effect of the cutoff is visible.
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                }
                train = applyMinCount(train);
                logTrainingStats(train);
                @SuppressWarnings("unchecked") Dataset<T> deserTest = (Dataset<T>) oits.readObject();
                // Re-wrap the test set so it shares the training feature and output domains.
                test = new ImmutableDataset<>(deserTest, deserTest.getSourceProvenance(), deserTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised files", e);
            }
            break;
        case LIBSVM:
            // Load the libsvm text-based data format.
            LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath, outputFactory);
            train = new MutableDataset<>(trainSVMSource);
            // Read the indexing scheme and max feature id from the training file so the test file
            // is parsed the same way.
            boolean zeroIndexed = trainSVMSource.isZeroIndexed();
            int maxFeatureID = trainSVMSource.getMaxFeatureID();
            train = applyMinCount(train);
            logTrainingStats(train);
            test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case TEXT:
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<T> extractor;
            if (hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
            }
            TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
            train = new MutableDataset<>(trainSource);
            train = applyMinCount(train);
            logTrainingStats(train);
            TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
            test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            separator = delimiter.value;
            CSVLoader<T> loader = new CSVLoader<>(separator, outputFactory);
            train = new MutableDataset<>(loader.loadDataSource(trainingPath, csvResponseName));
            // Previously minCount was silently ignored for CSV input; apply it like the other formats.
            train = applyMinCount(train);
            logTrainingStats(train);
            test = new MutableDataset<>(loader.loadDataSource(testingPath, csvResponseName));
            break;
        case COLUMNAR:
            if (rowProcessor == null) {
                throw new IllegalArgumentException("Please supply a RowProcessor");
            }
            OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
            if (!rowOutputFactory.equals(outputFactory)) {
                throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getSimpleName());
            }
            // checked by the if statement above
            @SuppressWarnings("unchecked") RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
            separator = delimiter.value;
            train = new MutableDataset<>(new CSVDataSource<>(trainingPath, typedRowProcessor, true, separator, csvQuoteChar));
            // Previously minCount was silently ignored for columnar input; apply it like the other formats.
            train = applyMinCount(train);
            logTrainingStats(train);
            test = new MutableDataset<>(new CSVDataSource<>(testingPath, typedRowProcessor, true, separator, csvQuoteChar));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + inputFormat);
    }
    logger.info(String.format("Loaded %d testing examples", test.size()));
    if (scaleFeatures) {
        // Scaling parameters are fitted on the training data only, then applied to both splits.
        logger.info("Fitting feature scaling");
        TransformationMap map = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformerMap transformers = train.createTransformers(map, scaleIncZeros);
        logger.info("Applying scaling to training dataset");
        train = transformers.transformDataset(train);
        logger.info("Applying scaling to testing dataset");
        test = transformers.transformDataset(test);
    }
    return new Pair<>(train, test);
}

/**
 * Removes features occurring fewer than {@code minCount} times from the supplied training dataset.
 * @param train The training dataset.
 * @param <T> The dataset output type.
 * @return {@code train} unchanged if {@code minCount <= 0}, otherwise a filtered
 * {@link MinimumCardinalityDataset} wrapping it.
 */
private <T extends Output<T>> Dataset<T> applyMinCount(Dataset<T> train) {
    if (minCount <= 0) {
        return train;
    }
    logger.info("Removing features that occur fewer than " + minCount + " times.");
    return new MinimumCardinalityDataset<>(train, minCount);
}

/**
 * Logs the example count, output set, feature count and response dimension count of a training dataset.
 * @param train The training dataset.
 * @param <T> The dataset output type.
 */
private <T extends Output<T>> void logTrainingStats(Dataset<T> train) {
    logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
    logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
}
Also used : TransformerMap(org.tribuo.transform.TransformerMap) CSVDataSource(org.tribuo.data.csv.CSVDataSource) SimpleTextDataSource(org.tribuo.data.text.impl.SimpleTextDataSource) TransformationMap(org.tribuo.transform.TransformationMap) LinearScalingTransformation(org.tribuo.transform.transformations.LinearScalingTransformation) BufferedInputStream(java.io.BufferedInputStream) LibSVMDataSource(org.tribuo.datasource.LibSVMDataSource) RowProcessor(org.tribuo.data.columnar.RowProcessor) Pair(com.oracle.labs.mlrg.olcut.util.Pair) CSVLoader(org.tribuo.data.csv.CSVLoader) ImmutableDataset(org.tribuo.ImmutableDataset) Dataset(org.tribuo.Dataset) MinimumCardinalityDataset(org.tribuo.dataset.MinimumCardinalityDataset) MutableDataset(org.tribuo.MutableDataset) FileInputStream(java.io.FileInputStream) BreakIteratorTokenizer(org.tribuo.util.tokens.impl.BreakIteratorTokenizer) TokenPipeline(org.tribuo.data.text.impl.TokenPipeline) ObjectInputStream(java.io.ObjectInputStream)

Aggregations

Pair (com.oracle.labs.mlrg.olcut.util.Pair)1 BufferedInputStream (java.io.BufferedInputStream)1 FileInputStream (java.io.FileInputStream)1 ObjectInputStream (java.io.ObjectInputStream)1 Dataset (org.tribuo.Dataset)1 ImmutableDataset (org.tribuo.ImmutableDataset)1 MutableDataset (org.tribuo.MutableDataset)1 RowProcessor (org.tribuo.data.columnar.RowProcessor)1 CSVDataSource (org.tribuo.data.csv.CSVDataSource)1 CSVLoader (org.tribuo.data.csv.CSVLoader)1 SimpleTextDataSource (org.tribuo.data.text.impl.SimpleTextDataSource)1 TokenPipeline (org.tribuo.data.text.impl.TokenPipeline)1 MinimumCardinalityDataset (org.tribuo.dataset.MinimumCardinalityDataset)1 LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource)1 TransformationMap (org.tribuo.transform.TransformationMap)1 TransformerMap (org.tribuo.transform.TransformerMap)1 LinearScalingTransformation (org.tribuo.transform.transformations.LinearScalingTransformation)1 BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer)1