
Example 11 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class MutableMockMultiOutputInfo method observe.

/**
 * Throws IllegalStateException if the MockMultiOutput contains a Label which has a "," in it.
 *
 * Such labels are disallowed. An exception should also be thrown when such a label
 * is constructed.
 * @param output The observed output.
 */
@Override
public void observe(MockMultiOutput output) {
    if (output == MockMultiOutputFactory.UNKNOWN_MULTILABEL) {
        unknownCount++;
    } else {
        for (String label : output.getNameSet()) {
            if (label.contains(",")) {
                throw new IllegalStateException("MockMultiOutput cannot use a Label which contains ','. The supplied label was " + label + ".");
            }
            MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
            labels.computeIfAbsent(label, MockMultiOutput::new);
            value.increment();
        }
        totalCount++;
    }
}
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong)
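
The observe method above uses the standard OLCUT counting idiom: computeIfAbsent installs a fresh zero-valued MutableLong the first time a key is seen, and increment bumps it thereafter. A minimal standalone sketch of the same idiom (the label data here is hypothetical):

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.Map;

public class LabelCountSketch {
    public static void main(String[] args) {
        Map<String, MutableLong> labelCounts = new HashMap<>();
        for (String label : new String[] {"SPORT", "NEWS", "SPORT"}) {
            // Reject labels containing ',' just as the observe method does.
            if (label.contains(",")) {
                throw new IllegalStateException("Labels containing ',' are disallowed: " + label);
            }
            // Insert a zero counter on first sight of a label, then increment it.
            labelCounts.computeIfAbsent(label, k -> new MutableLong()).increment();
        }
        // Prints SPORT=2 and NEWS=1 (HashMap iteration order is unspecified).
        labelCounts.forEach((k, v) -> System.out.println(k + "=" + v.intValue()));
    }
}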

Example 12 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class Dataset method createTransformers.

/**
 * Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
 * observing all the values in this dataset.
 * <p>
 * Does not mutate the dataset; if you wish to apply the TransformerMap, use
 * {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
 * <p>
 * TransformerMaps operate on feature values which are present; sparse values
 * are ignored and not transformed. If the zeros should be transformed, call
 * {@link MutableDataset#densify} on the dataset before applying a transformer.
 * See {@link org.tribuo.transform} for a more detailed discussion of densify and includeImplicitZeroFeatures.
 * <p>
 * Throws {@link IllegalArgumentException} if the TransformationMap object has
 * regexes which apply to multiple features.
 * @param transformations The transformations to fit.
 * @param includeImplicitZeroFeatures Use the implicit zero feature values to construct the transformations.
 * @return A TransformerMap which can apply the transformations to a dataset.
 */
public TransformerMap createTransformers(TransformationMap transformations, boolean includeImplicitZeroFeatures) {
    ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());
    // Validate map by checking no regex applies to multiple features.
    logger.fine(String.format("Processing %d feature specific transforms", transformations.getFeatureTransformations().size()));
    Map<String, List<Transformation>> featureTransformations = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
        // Compile the regex.
        Pattern pattern = Pattern.compile(entry.getKey());
        // Check all the feature names
        for (String name : featureNames) {
            // If the regex matches
            if (pattern.matcher(name).matches()) {
                List<Transformation> oldTransformations = featureTransformations.put(name, entry.getValue());
                // See if there is already a transformation list for that name.
                if (oldTransformations != null) {
                    throw new IllegalArgumentException("Feature name '" + name + "' matches multiple regexes, at least one of which was '" + entry.getKey() + "'.");
                }
            }
        }
    }
    // Populate the feature transforms map.
    Map<String, Queue<TransformStatistics>> featureStats = new HashMap<>();
    // sparseCount tracks how many times a feature was not observed
    Map<String, MutableLong> sparseCount = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : featureTransformations.entrySet()) {
        // Create the queue of feature transformations for this feature
        Queue<TransformStatistics> l = new LinkedList<>();
        for (Transformation t : entry.getValue()) {
            l.add(t.createStats());
        }
        // Add the queue to the map for that feature
        featureStats.put(entry.getKey(), l);
        sparseCount.put(entry.getKey(), new MutableLong(data.size()));
    }
    if (!transformations.getGlobalTransformations().isEmpty()) {
        // Append all the global transformations
        int ntransform = featureNames.size();
        logger.fine(String.format("Starting %,d global transformations", ntransform));
        int ndone = 0;
        for (String v : featureNames) {
            // Create the queue of feature transformations for this feature
            Queue<TransformStatistics> l = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
            for (Transformation t : transformations.getGlobalTransformations()) {
                l.add(t.createStats());
            }
            // Add the queue to the map for that feature
            featureStats.put(v, l);
            // Generate the sparse count, initialised to the number of examples.
            sparseCount.putIfAbsent(v, new MutableLong(data.size()));
            ndone++;
            if (logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
                logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
            }
        }
    }
    Map<String, List<Transformer>> output = new LinkedHashMap<>();
    Set<String> removeSet = new LinkedHashSet<>();
    boolean initialisedSparseCounts = false;
    // Iterate through the dataset max(transformations.length) times.
    while (!featureStats.isEmpty()) {
        for (Example<T> example : data) {
            for (Feature f : example) {
                if (featureStats.containsKey(f.getName())) {
                    if (!initialisedSparseCounts) {
                        sparseCount.get(f.getName()).decrement();
                    }
                    List<Transformer> curTransformers = output.get(f.getName());
                    // Apply all current transformations
                    double fValue = TransformerMap.applyTransformerList(f.getValue(), curTransformers);
                    // Observe the transformed value
                    featureStats.get(f.getName()).peek().observeValue(fValue);
                }
            }
        }
        // Sparse counts are updated (this could be protected by an if statement)
        initialisedSparseCounts = true;
        removeSet.clear();
        // Emit the new transformers onto the end of the list in the output map.
        for (Map.Entry<String, Queue<TransformStatistics>> entry : featureStats.entrySet()) {
            TransformStatistics currentStats = entry.getValue().poll();
            if (includeImplicitZeroFeatures) {
                // Observe all the sparse feature values
                int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
                currentStats.observeSparse(unobservedFeatures);
            }
            // Get the transformer list for that feature, creating it if absent.
            List<Transformer> l = output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>());
            // Generate the transformer and add it to the appropriate list.
            l.add(currentStats.generateTransformer());
            // If the queue is empty, remove that feature, ensuring that featureStats is eventually empty.
            if (entry.getValue().isEmpty()) {
                removeSet.add(entry.getKey());
            }
        }
        // Remove the features with empty queues.
        for (String s : removeSet) {
            featureStats.remove(s);
        }
    }
    return new TransformerMap(output, getProvenance(), transformations.getProvenance());
}
Also used : LinkedHashSet(java.util.LinkedHashSet) TransformerMap(org.tribuo.transform.TransformerMap) Transformation(org.tribuo.transform.Transformation) Transformer(org.tribuo.transform.Transformer) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) List(java.util.List) Queue(java.util.Queue) TransformStatistics(org.tribuo.transform.TransformStatistics) Pattern(java.util.regex.Pattern) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) Map(java.util.Map) TransformationMap(org.tribuo.transform.TransformationMap)
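
A hypothetical usage sketch of createTransformers, assuming an existing MutableDataset<Regressor> and a single global org.tribuo.transform.transformations.LinearScalingTransformation; any other Transformation would be fitted the same way:

import java.util.Collections;
import org.tribuo.MutableDataset;
import org.tribuo.regression.Regressor;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.TransformerMap;
import org.tribuo.transform.transformations.LinearScalingTransformation;

public static TransformerMap fitAndApply(MutableDataset<Regressor> train) {
    // Rescale every feature into [0,1] via a single global transformation.
    TransformationMap transformations = new TransformationMap(
            Collections.singletonList(new LinearScalingTransformation(0.0, 1.0)));
    // Fit the statistics over the observed values; implicit zeros are ignored here.
    TransformerMap transformers = train.createTransformers(transformations, false);
    // Apply in place; transformers.transformDataset(train) would return a copy instead.
    train.transform(transformers);
    return transformers;
}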

Example 13 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class KMeansTrainer method train.

@Override
public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // Creates a new local RNG and adds one to the invocation count.
    TrainerProvenance trainerProvenance;
    SplittableRandom localRNG;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    int[] oldCentre = new int[examples.size()];
    SGDVector[] data = new SGDVector[examples.size()];
    double[] weights = new double[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        weights[n] = example.getWeight();
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        }
        oldCentre[n] = -1;
        n++;
    }
    DenseVector[] centroidVectors;
    switch(initialisationType) {
        case RANDOM:
            centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
            break;
        case PLUSPLUS:
            centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distType);
            break;
        default:
            throw new IllegalStateException("Unknown initialisation" + initialisationType);
    }
    Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
    boolean parallel = numThreads > 1;
    for (int i = 0; i < centroids; i++) {
        clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
    }
    AtomicInteger changeCounter = new AtomicInteger(0);
    Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
        double minDist = Double.POSITIVE_INFINITY;
        int clusterID = -1;
        int id = e.idx;
        SGDVector vector = e.vector;
        for (int j = 0; j < centroids; j++) {
            DenseVector cluster = centroidVectors[j];
            double distance = DistanceType.getDistance(cluster, vector, distType);
            if (distance < minDist) {
                minDist = distance;
                clusterID = j;
            }
        }
        clusterAssignments.get(clusterID).add(id);
        if (oldCentre[id] != clusterID) {
            // Changed the centroid of this vector.
            oldCentre[id] = clusterID;
            changeCounter.incrementAndGet();
        }
    };
    boolean converged = false;
    ForkJoinPool fjp = null;
    try {
        if (parallel) {
            if (System.getSecurityManager() == null) {
                fjp = new ForkJoinPool(numThreads);
            } else {
                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
            }
        }
        for (int i = 0; (i < iterations) && !converged; i++) {
            logger.log(Level.FINE, "Beginning iteration " + i);
            changeCounter.set(0);
            for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
                e.getValue().clear();
            }
            // E step
            Stream<SGDVector> vecStream = Arrays.stream(data);
            Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
            Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
            if (parallel) {
                Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
                try {
                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
                } catch (InterruptedException | ExecutionException e) {
                    throw new RuntimeException("Parallel execution failed", e);
                }
            } else {
                zipStream.forEach(eStepFunc);
            }
            logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " examples updated.");
            mStep(fjp, centroidVectors, clusterAssignments, data, weights);
            logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");
            if (changeCounter.get() == 0) {
                converged = true;
                logger.log(Level.INFO, "K-Means converged at iteration " + i);
            }
        }
    } finally {
        if (fjp != null) {
            fjp.shutdown();
        }
    }
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    }
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, distType);
}
Also used : ClusterID(org.tribuo.clustering.ClusterID) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SGDVector(org.tribuo.math.la.SGDVector) List(java.util.List) ExecutionException(java.util.concurrent.ExecutionException) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) SplittableRandom(java.util.SplittableRandom) DenseVector(org.tribuo.math.la.DenseVector) ForkJoinPool(java.util.concurrent.ForkJoinPool) ImmutableClusteringInfo(org.tribuo.clustering.ImmutableClusteringInfo)
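
For context, a sketch of invoking this trainer; the hyperparameters are hypothetical, and the distance enum moved between Tribuo versions (KMeansTrainer.Distance in 4.2, org.tribuo.math.distance.DistanceType later), so adjust to the version you build against:

import org.tribuo.Dataset;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.clustering.kmeans.KMeansTrainer;

// 5 centroids, at most 10 iterations, Euclidean distance, 4 threads, RNG seed 1.
public static KMeansModel cluster(Dataset<ClusterID> data) {
    KMeansTrainer trainer = new KMeansTrainer(5, 10, KMeansTrainer.Distance.EUCLIDEAN, 4, 1L);
    return trainer.train(data);
}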

Example 14 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class ClusteringFactory method constructInfoForExternalModel.

/**
 * Unlike the other info types, clustering directly uses the integer IDs as the stored value,
 * so this mapping discards the cluster IDs and just uses the supplied integers.
 * @param mapping The mapping to use.
 * @return An {@link ImmutableOutputInfo} for the clustering.
 */
@Override
public ImmutableOutputInfo<ClusterID> constructInfoForExternalModel(Map<ClusterID, Integer> mapping) {
    // Validate inputs are dense
    OutputFactory.validateMapping(mapping);
    Map<Integer, MutableLong> countsMap = new HashMap<>();
    for (Map.Entry<ClusterID, Integer> e : mapping.entrySet()) {
        countsMap.put(e.getValue(), new MutableLong(1));
    }
    return new ImmutableClusteringInfo(countsMap);
}
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) HashMap(java.util.HashMap) Map(java.util.Map)
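
A small sketch of how this might be called when wrapping an external clustering model; the three-cluster mapping is hypothetical, and validateMapping requires the supplied integers to densely cover [0, n):

import java.util.HashMap;
import java.util.Map;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;

public static ImmutableOutputInfo<ClusterID> externalClusterInfo() {
    Map<ClusterID, Integer> mapping = new HashMap<>();
    // Three clusters from the external model, mapped to the dense ids 0-2.
    mapping.put(new ClusterID(0), 0);
    mapping.put(new ClusterID(1), 1);
    mapping.put(new ClusterID(2), 2);
    return new ClusteringFactory().constructInfoForExternalModel(mapping);
}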

Example 15 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class MutableClusteringInfo method observe.

@Override
public void observe(ClusterID output) {
    if (output == ClusteringFactory.UNASSIGNED_CLUSTER_ID) {
        unknownCount++;
    } else {
        int id = output.getID();
        MutableLong value = clusterCounts.computeIfAbsent(id, k -> new MutableLong());
        value.increment();
    }
}
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong)

Aggregations

MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong) 15
HashMap (java.util.HashMap) 6
ArrayList (java.util.ArrayList) 3
LinkedHashMap (java.util.LinkedHashMap) 3
List (java.util.List) 3
Map (java.util.Map) 2
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap) 2
ClusterID (org.tribuo.clustering.ClusterID) 2
ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo) 2
DenseVector (org.tribuo.math.la.DenseVector) 2
SGDVector (org.tribuo.math.la.SGDVector) 2
ModelProvenance (org.tribuo.provenance.ModelProvenance) 2
TrainerProvenance (org.tribuo.provenance.TrainerProvenance) 2
MutableDouble (com.oracle.labs.mlrg.olcut.util.MutableDouble) 1
LinkedHashSet (java.util.LinkedHashSet) 1
LinkedList (java.util.LinkedList) 1
Queue (java.util.Queue) 1
SplittableRandom (java.util.SplittableRandom) 1
ExecutionException (java.util.concurrent.ExecutionException) 1
ForkJoinPool (java.util.concurrent.ForkJoinPool) 1