Search in sources :

Example 1 with TransformationMap

use of org.tribuo.transform.TransformationMap in project tribuo by oracle.

Source: the class DataOptions, method load.

/**
 * Loads the training and testing data from {@link #trainingPath} and {@link #testingPath}
 * according to the other parameters specified in this class.
 * <p>
 * All formats build the testing dataset against the training dataset's feature and output
 * maps so that feature/output ids line up between the two.
 * @param outputFactory The output factory to use to process the inputs.
 * @param <T> The dataset output type.
 * @return A pair containing the training and testing datasets. The training dataset is element 'A' and the
 * testing dataset is element 'B'.
 * @throws IOException If the paths could not be loaded.
 */
public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
    logger.info(String.format("Loading data from %s", trainingPath));
    Dataset<T> train;
    Dataset<T> test;
    char separator;
    switch(inputFormat) {
        case SERIALIZED:
            // 
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + trainingPath);
            try (ObjectInputStream trainStream = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                ObjectInputStream testStream = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                // Unchecked cast is unavoidable: serialisation erases the generic output type.
                @SuppressWarnings("unchecked") Dataset<T> tmp = (Dataset<T>) trainStream.readObject();
                train = tmp;
                if (minCount > 0) {
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Removing features that occur fewer than " + minCount + " times.");
                    train = new MinimumCardinalityDataset<>(train, minCount);
                }
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                @SuppressWarnings("unchecked") Dataset<T> deserTest = (Dataset<T>) testStream.readObject();
                // Rewrap the deserialised test set with the *training* feature and output maps,
                // dropping any features unknown to the training data (final flag = true).
                test = new ImmutableDataset<>(deserTest, deserTest.getSourceProvenance(), deserTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised files", e);
            }
            break;
        case LIBSVM:
            // 
            // Load the libsvm text-based data format.
            LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath, outputFactory);
            train = new MutableDataset<>(trainSVMSource);
            // Capture the indexing convention and feature space size observed in the
            // training file so the test file is parsed compatibly.
            boolean zeroIndexed = trainSVMSource.isZeroIndexed();
            int maxFeatureID = trainSVMSource.getMaxFeatureID();
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case TEXT:
            // 
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<T> extractor;
            if (hashDim > 0) {
                // Hash the ngram features into hashDim buckets to bound the feature space.
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
            }
            TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
            train = new MutableDataset<>(trainSource);
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
            test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            separator = delimiter.value;
            CSVLoader<T> loader = new CSVLoader<>(separator, outputFactory);
            train = new MutableDataset<>(loader.loadDataSource(trainingPath, csvResponseName));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(loader.loadDataSource(testingPath, csvResponseName));
            break;
        case COLUMNAR:
            if (rowProcessor == null) {
                throw new IllegalArgumentException("Please supply a RowProcessor");
            }
            OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
            if (!rowOutputFactory.equals(outputFactory)) {
                // Use getSimpleName() for both factories so the message compares like with like.
                throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getSimpleName());
            }
            // checked by the if statement above
            @SuppressWarnings("unchecked") RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
            separator = delimiter.value;
            train = new MutableDataset<>(new CSVDataSource<>(trainingPath, typedRowProcessor, true, separator, csvQuoteChar));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(new CSVDataSource<>(testingPath, typedRowProcessor, true, separator, csvQuoteChar));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + inputFormat);
    }
    logger.info(String.format("Loaded %d testing examples", test.size()));
    if (scaleFeatures) {
        logger.info("Fitting feature scaling");
        // Fit a linear [0,1] scaling on the training data only, then apply the same
        // fitted transformers to both datasets to avoid test-set leakage.
        TransformationMap map = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformerMap transformers = train.createTransformers(map, scaleIncZeros);
        logger.info("Applying scaling to training dataset");
        train = transformers.transformDataset(train);
        logger.info("Applying scaling to testing dataset");
        test = transformers.transformDataset(test);
    }
    return new Pair<>(train, test);
}
Also used : TransformerMap(org.tribuo.transform.TransformerMap) CSVDataSource(org.tribuo.data.csv.CSVDataSource) SimpleTextDataSource(org.tribuo.data.text.impl.SimpleTextDataSource) TransformationMap(org.tribuo.transform.TransformationMap) LinearScalingTransformation(org.tribuo.transform.transformations.LinearScalingTransformation) BufferedInputStream(java.io.BufferedInputStream) LibSVMDataSource(org.tribuo.datasource.LibSVMDataSource) RowProcessor(org.tribuo.data.columnar.RowProcessor) Pair(com.oracle.labs.mlrg.olcut.util.Pair) CSVLoader(org.tribuo.data.csv.CSVLoader) ImmutableDataset(org.tribuo.ImmutableDataset) Dataset(org.tribuo.Dataset) MinimumCardinalityDataset(org.tribuo.dataset.MinimumCardinalityDataset) MutableDataset(org.tribuo.MutableDataset) FileInputStream(java.io.FileInputStream) BreakIteratorTokenizer(org.tribuo.util.tokens.impl.BreakIteratorTokenizer) TokenPipeline(org.tribuo.data.text.impl.TokenPipeline) ObjectInputStream(java.io.ObjectInputStream)

Example 2 with TransformationMap

use of org.tribuo.transform.TransformationMap in project tribuo by oracle.

Source: the class Dataset, method createTransformers.

/**
 * Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
 * observing all the values in this dataset.
 * <p>
 * Does not mutate the dataset, if you wish to apply the TransformerMap, use
 * {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
 * <p>
 * TransformerMaps operate on feature values which are present, sparse values
 * are ignored and not transformed. If the zeros should be transformed, call
 * {@link MutableDataset#densify} on the datasets before applying a transformer.
 * See {@link org.tribuo.transform} for a more detailed discussion of densify and includeImplicitZeroFeatures.
 * <p>
 * Throws {@link IllegalArgumentException} if the TransformationMap object has
 * regexes which apply to multiple features.
 * @param transformations The transformations to fit.
 * @param includeImplicitZeroFeatures Use the implicit zero feature values to construct the transformations.
 * @return A TransformerMap which can apply the transformations to a dataset.
 */
public TransformerMap createTransformers(TransformationMap transformations, boolean includeImplicitZeroFeatures) {
    ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());
    // Validate map by checking no regex applies to multiple features.
    logger.fine(String.format("Processing %d feature specific transforms", transformations.getFeatureTransformations().size()));
    Map<String, List<Transformation>> featureTransformations = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
        // Compile the regex.
        Pattern pattern = Pattern.compile(entry.getKey());
        // Check all the feature names
        for (String name : featureNames) {
            // If the regex matches
            if (pattern.matcher(name).matches()) {
                List<Transformation> oldTransformations = featureTransformations.put(name, entry.getValue());
                // A non-null previous value means another regex already claimed this feature.
                if (oldTransformations != null) {
                    throw new IllegalArgumentException("Feature name '" + name + "' matches multiple regexes, at least one of which was '" + entry.getKey() + "'.");
                }
            }
        }
    }
    // Populate the feature transforms map.
    Map<String, Queue<TransformStatistics>> featureStats = new HashMap<>();
    // sparseCount tracks how many times a feature was not observed
    Map<String, MutableLong> sparseCount = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : featureTransformations.entrySet()) {
        // Create the queue of feature transformations for this feature
        Queue<TransformStatistics> l = new LinkedList<>();
        for (Transformation t : entry.getValue()) {
            l.add(t.createStats());
        }
        // Add the queue to the map for that feature
        featureStats.put(entry.getKey(), l);
        // Start the sparse count at the number of examples; it is decremented per observation.
        sparseCount.put(entry.getKey(), new MutableLong(data.size()));
    }
    if (!transformations.getGlobalTransformations().isEmpty()) {
        // Append all the global transformations
        int ntransform = featureNames.size();
        logger.fine(String.format("Starting %,d global transformations", ntransform));
        int ndone = 0;
        for (String v : featureNames) {
            // Fetch (or create and register) the queue of transformations for this feature;
            // computeIfAbsent stores the new queue so no separate put is needed.
            Queue<TransformStatistics> l = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
            for (Transformation t : transformations.getGlobalTransformations()) {
                l.add(t.createStats());
            }
            // Generate the sparse count initialised to the number of examples.
            sparseCount.putIfAbsent(v, new MutableLong(data.size()));
            ndone++;
            if (logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
                logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
            }
        }
    }
    Map<String, List<Transformer>> output = new LinkedHashMap<>();
    Set<String> removeSet = new LinkedHashSet<>();
    boolean initialisedSparseCounts = false;
    // Iterate through the dataset max(transformations.length) times.
    // Each pass fits the head of every feature's statistics queue, then converts it
    // into a Transformer which is applied during subsequent passes (chained fitting).
    while (!featureStats.isEmpty()) {
        for (Example<T> example : data) {
            for (Feature f : example) {
                if (featureStats.containsKey(f.getName())) {
                    if (!initialisedSparseCounts) {
                        sparseCount.get(f.getName()).decrement();
                    }
                    List<Transformer> curTransformers = output.get(f.getName());
                    // Apply all current transformations
                    double fValue = TransformerMap.applyTransformerList(f.getValue(), curTransformers);
                    // Observe the transformed value
                    featureStats.get(f.getName()).peek().observeValue(fValue);
                }
            }
        }
        // Sparse counts are now fully populated; skip decrementing on later passes.
        initialisedSparseCounts = true;
        removeSet.clear();
        // Emit the new transformers onto the end of the list in the output map.
        for (Map.Entry<String, Queue<TransformStatistics>> entry : featureStats.entrySet()) {
            TransformStatistics currentStats = entry.getValue().poll();
            if (includeImplicitZeroFeatures) {
                // Observe all the sparse feature values
                int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
                currentStats.observeSparse(unobservedFeatures);
            }
            // Get the transformer list for that feature (if absent)
            List<Transformer> l = output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>());
            // Generate the transformer and add it to the appropriate list.
            l.add(currentStats.generateTransformer());
            // If the queue is empty, remove that feature, ensuring that featureStats is eventually empty.
            if (entry.getValue().isEmpty()) {
                removeSet.add(entry.getKey());
            }
        }
        // Remove the features with empty queues.
        for (String s : removeSet) {
            featureStats.remove(s);
        }
    }
    return new TransformerMap(output, getProvenance(), transformations.getProvenance());
}
Also used : LinkedHashSet(java.util.LinkedHashSet) TransformerMap(org.tribuo.transform.TransformerMap) Transformation(org.tribuo.transform.Transformation) Transformer(org.tribuo.transform.Transformer) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) List(java.util.List) Queue(java.util.Queue) TransformStatistics(org.tribuo.transform.TransformStatistics) Pattern(java.util.regex.Pattern) LinkedList(java.util.LinkedList) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map) TransformerMap(org.tribuo.transform.TransformerMap) TransformationMap(org.tribuo.transform.TransformationMap)

Aggregations

TransformationMap (org.tribuo.transform.TransformationMap)2 TransformerMap (org.tribuo.transform.TransformerMap)2 MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong)1 Pair (com.oracle.labs.mlrg.olcut.util.Pair)1 BufferedInputStream (java.io.BufferedInputStream)1 FileInputStream (java.io.FileInputStream)1 ObjectInputStream (java.io.ObjectInputStream)1 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 LinkedHashMap (java.util.LinkedHashMap)1 LinkedHashSet (java.util.LinkedHashSet)1 LinkedList (java.util.LinkedList)1 List (java.util.List)1 Map (java.util.Map)1 Queue (java.util.Queue)1 Pattern (java.util.regex.Pattern)1 Dataset (org.tribuo.Dataset)1 ImmutableDataset (org.tribuo.ImmutableDataset)1 MutableDataset (org.tribuo.MutableDataset)1 RowProcessor (org.tribuo.data.columnar.RowProcessor)1