Search in sources :

Example 1 with TransformStatistics

Use of org.tribuo.transform.TransformStatistics in project Tribuo by Oracle.

Method createTransformers in class Dataset.

/**
 * Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
 * observing all the values in this dataset.
 * <p>
 * Does not mutate the dataset. If you wish to apply the TransformerMap, use
 * {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
 * <p>
 * TransformerMaps operate on feature values which are present; sparse values
 * are ignored and not transformed. If the zeros should be transformed, call
 * {@link MutableDataset#densify} on the datasets before applying a transformer.
 * See {@link org.tribuo.transform} for a more detailed discussion of densify and includeImplicitZeroFeatures.
 * <p>
 * Feature-specific transformations are fitted first. Any global transformations are then
 * appended to the transformation chain of every feature in this dataset's feature map.
 * Each transformation in a chain is fitted on the values produced by the transformers
 * fitted before it, so this method makes one pass over the data per transformation in
 * the longest chain. Features with no transformations to fit do not appear in the
 * returned map.
 * <p>
 * A single feature-specific regex may match many features, but each feature name may be
 * matched by at most one regex.
 * @param transformations The transformations to fit.
 * @param includeImplicitZeroFeatures Use the implicit zero feature values to construct the transformations.
 * @return A TransformerMap which can apply the transformations to a dataset.
 * @throws IllegalArgumentException If a feature name matches more than one feature-specific regex.
 */
public TransformerMap createTransformers(TransformationMap transformations, boolean includeImplicitZeroFeatures) {
    List<String> featureNames = new ArrayList<>(getFeatureMap().keySet());

    // Resolve each feature-specific regex against the feature names. A regex may match
    // many features, but a feature matched by two regexes has an ambiguous chain.
    if (logger.isLoggable(Level.FINE)) {
        logger.fine(String.format("Processing %d feature specific transforms", transformations.getFeatureTransformations().size()));
    }
    Map<String, List<Transformation>> featureTransformations = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
        Pattern pattern = Pattern.compile(entry.getKey());
        for (String name : featureNames) {
            if (pattern.matcher(name).matches()) {
                List<Transformation> oldTransformations = featureTransformations.put(name, entry.getValue());
                // A previous regex already claimed this feature.
                if (oldTransformations != null) {
                    throw new IllegalArgumentException("Feature name '" + name + "' matches multiple regexes, at least one of which was '" + entry.getKey() + "'.");
                }
            }
        }
    }

    // Per-feature queue of statistics still to be fitted, in fitting order.
    Map<String, Queue<TransformStatistics>> featureStats = new HashMap<>();
    // sparseCount tracks how many examples did not contain each feature. It starts at the
    // number of examples and is decremented for every observation during the first pass.
    Map<String, MutableLong> sparseCount = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : featureTransformations.entrySet()) {
        Queue<TransformStatistics> queue = new LinkedList<>();
        for (Transformation t : entry.getValue()) {
            queue.add(t.createStats());
        }
        featureStats.put(entry.getKey(), queue);
        sparseCount.put(entry.getKey(), new MutableLong(data.size()));
    }
    if (!transformations.getGlobalTransformations().isEmpty()) {
        // Append the global transformations after any feature-specific ones, for every feature.
        int ntransform = featureNames.size();
        if (logger.isLoggable(Level.FINE)) {
            logger.fine(String.format("Starting %,d global transformations", ntransform));
        }
        int ndone = 0;
        for (String v : featureNames) {
            // computeIfAbsent both fetches and inserts the queue, so no separate put is needed.
            Queue<TransformStatistics> queue = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
            for (Transformation t : transformations.getGlobalTransformations()) {
                queue.add(t.createStats());
            }
            // Initialise the sparse count to the number of examples, unless already present.
            sparseCount.putIfAbsent(v, new MutableLong(data.size()));
            ndone++;
            if (logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
                logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
            }
        }
    }
    // Drop features with nothing to fit (e.g. a regex mapped to an empty transformation list
    // with no global transformations); otherwise peek()/poll() below would return null.
    featureStats.values().removeIf(Queue::isEmpty);

    Map<String, List<Transformer>> output = new LinkedHashMap<>();
    boolean initialisedSparseCounts = false;
    // Each pass fits the head of every remaining queue, so the loop runs once per
    // transformation in the longest queue.
    while (!featureStats.isEmpty()) {
        for (Example<T> example : data) {
            for (Feature f : example) {
                String name = f.getName();
                Queue<TransformStatistics> stats = featureStats.get(name);
                if (stats != null) {
                    if (!initialisedSparseCounts) {
                        sparseCount.get(name).decrement();
                    }
                    // Observe the value as transformed by the transformers fitted so far.
                    double fValue = TransformerMap.applyTransformerList(f.getValue(), output.get(name));
                    stats.peek().observeValue(fValue);
                }
            }
        }
        // The first pass has counted every observation, so the sparse counts are now final.
        initialisedSparseCounts = true;
        // Emit the newly fitted transformers onto the end of each feature's output list.
        for (Map.Entry<String, Queue<TransformStatistics>> entry : featureStats.entrySet()) {
            TransformStatistics currentStats = entry.getValue().poll();
            if (includeImplicitZeroFeatures) {
                // Account for the implicit zeros in examples which lacked this feature.
                int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
                currentStats.observeSparse(unobservedFeatures);
            }
            output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>()).add(currentStats.generateTransformer());
        }
        // Remove features whose queues are exhausted, ensuring featureStats eventually empties.
        featureStats.values().removeIf(Queue::isEmpty);
    }
    return new TransformerMap(output, getProvenance(), transformations.getProvenance());
}
Also used: LinkedHashSet (java.util.LinkedHashSet), TransformerMap (org.tribuo.transform.TransformerMap), Transformation (org.tribuo.transform.Transformation), Transformer (org.tribuo.transform.Transformer), HashMap (java.util.HashMap), LinkedHashMap (java.util.LinkedHashMap), ArrayList (java.util.ArrayList), LinkedList (java.util.LinkedList), List (java.util.List), Queue (java.util.Queue), TransformStatistics (org.tribuo.transform.TransformStatistics), Pattern (java.util.regex.Pattern), MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong), Map (java.util.Map), TransformationMap (org.tribuo.transform.TransformationMap)

Aggregations

MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong): 1, ArrayList (java.util.ArrayList): 1, HashMap (java.util.HashMap): 1, LinkedHashMap (java.util.LinkedHashMap): 1, LinkedHashSet (java.util.LinkedHashSet): 1, LinkedList (java.util.LinkedList): 1, List (java.util.List): 1, Map (java.util.Map): 1, Queue (java.util.Queue): 1, Pattern (java.util.regex.Pattern): 1, TransformStatistics (org.tribuo.transform.TransformStatistics): 1, Transformation (org.tribuo.transform.Transformation): 1, TransformationMap (org.tribuo.transform.TransformationMap): 1, Transformer (org.tribuo.transform.Transformer): 1, TransformerMap (org.tribuo.transform.TransformerMap): 1