Search in sources :

Example 11 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class MutableMockMultiOutputInfo method observe.

/**
 * Throws IllegalStateException if the MockMultiOutput contains a Label which has a "," in it.
 * Such labels are disallowed. There should be an exception thrown when one is constructed
 * too.
 * @param output The observed output.
 */
public void observe(MockMultiOutput output) {
    // Records one observed MockMultiOutput against the per-label state held in labelCounts and labels.
    // NOTE(review): this listing is an excerpt; some statements and every closing brace are
    // missing, so the comments below describe only what is visible.
    if (output == MockMultiOutputFactory.UNKNOWN_MULTILABEL) {
        // Reference comparison against the unknown sentinel; its handling is not shown in this excerpt.
    } else {
        for (String label : output.getNameSet()) {
            // Label names containing a comma are rejected outright.
            if (label.contains(",")) {
                throw new IllegalStateException("MockMultiOutput cannot use a Label which contains ','. The supplied label was " + label + ".");
            // Fetch the counter for this label, creating a new MutableLong the first time the label is seen.
            // What is done with `value` afterwards is not visible in this excerpt.
            MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
            // Store a MockMultiOutput built from the label name if none is stored yet.
            labels.computeIfAbsent(label, MockMultiOutput::new);
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong)

Example 12 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class Dataset method createTransformers.

/**
 * Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
 * observing all the values in this dataset.
 * <p>
 * Does not mutate the dataset. If you wish to apply the TransformerMap, use
 * {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
 * <p>
 * TransformerMaps operate on feature values which are present, sparse values
 * are ignored and not transformed. If the zeros should be transformed, call
 * {@link MutableDataset#densify} on the datasets before applying a transformer.
 * See {@link org.tribuo.transform} for a more detailed discussion of densify and includeImplicitZeroFeatures.
 * <p>
 * Throws {@link IllegalArgumentException} if the TransformationMap object has
 * regexes which apply to multiple features.
 * @param transformations The transformations to fit.
 * @param includeImplicitZeroFeatures Use the implicit zero feature values to construct the transformations.
 * @return A TransformerMap which can apply the transformations to a dataset.
 */
public TransformerMap createTransformers(TransformationMap transformations, boolean includeImplicitZeroFeatures) {
    // NOTE(review): this listing is an excerpt; several statements and all closing braces are
    // missing, so the comments below describe only what is visible.
    ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());
    // Validate map by checking no regex applies to multiple features.
    logger.fine(String.format("Processing %d feature specific transforms", transformations.getFeatureTransformations().size()));
    // Maps each concrete feature name to the transformation list of the regex that matched it.
    Map<String, List<Transformation>> featureTransformations = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
        // Compile the regex.
        Pattern pattern = Pattern.compile(entry.getKey());
        // Check all the feature names
        for (String name : featureNames) {
            // If the regex matches (matches() requires the whole name to match, not a substring)
            if (pattern.matcher(name).matches()) {
                List<Transformation> oldTransformations = featureTransformations.put(name, entry.getValue());
                // See if there is already a transformation list for that name.
                // put() returns the previous mapping, so non-null means a second regex matched this name.
                if (oldTransformations != null) {
                    throw new IllegalArgumentException("Feature name '" + name + "' matches multiple regexes, at least one of which was '" + entry.getKey() + "'.");
    // Populate the feature transforms map.
    Map<String, Queue<TransformStatistics>> featureStats = new HashMap<>();
    // sparseCount tracks how many times a feature was not observed
    Map<String, MutableLong> sparseCount = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : featureTransformations.entrySet()) {
        // Create the queue of feature transformations for this feature
        Queue<TransformStatistics> l = new LinkedList<>();
        // The body of this loop (presumably filling the queue from each Transformation) is not shown.
        for (Transformation t : entry.getValue()) {
        // Add the queue to the map for that feature
        featureStats.put(entry.getKey(), l);
        // Start the count at data.size(), the number of examples in this dataset.
        sparseCount.put(entry.getKey(), new MutableLong(data.size()));
    if (!transformations.getGlobalTransformations().isEmpty()) {
        // Append all the global transformations
        int ntransform = featureNames.size();
        logger.fine(String.format("Starting %,d global transformations", ntransform));
        int ndone = 0;
        for (String v : featureNames) {
            // Create the queue of feature transformations for this feature
            // (reuses the queue built from the feature specific transformations when one exists).
            Queue<TransformStatistics> l = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
            // The body of this loop is not shown in this excerpt.
            for (Transformation t : transformations.getGlobalTransformations()) {
            // Add the queue to the map for that feature
            featureStats.put(v, l);
            // Generate the sparse count initialised to data.size(). data is iterated as examples
            // below, so this is the number of examples, not the number of features.
            sparseCount.putIfAbsent(v, new MutableLong(data.size()));
            // Progress logging; the statement that advances ndone is not visible in this excerpt.
            if (logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
                logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
    Map<String, List<Transformer>> output = new LinkedHashMap<>();
    Set<String> removeSet = new LinkedHashSet<>();
    boolean initialisedSparseCounts = false;
    // Iterate through the dataset max(transformations.length) times.
    while (!featureStats.isEmpty()) {
        for (Example<T> example : data) {
            for (Feature f : example) {
                // Only features that still have pending statistics are processed.
                if (featureStats.containsKey(f.getName())) {
                    // Guarded work for the first pass only; the guarded statement is not shown here.
                    if (!initialisedSparseCounts) {
                    List<Transformer> curTransformers = output.get(f.getName());
                    // Apply all current transformations
                    double fValue = TransformerMap.applyTransformerList(f.getValue(), curTransformers);
                    // Observe the transformed value
        // Sparse counts are updated (this could be protected by an if statement)
        initialisedSparseCounts = true;
        // Emit the new transformers onto the end of the list in the output map.
        for (Map.Entry<String, Queue<TransformStatistics>> entry : featureStats.entrySet()) {
            // Take the statistics at the head of this feature's queue.
            TransformStatistics currentStats = entry.getValue().poll();
            if (includeImplicitZeroFeatures) {
                // Observe all the sparse feature values
                // NOTE(review): intValue() narrows the MutableLong counter to an int.
                int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
            // Get the transformer list for that feature (if absent)
            List<Transformer> l = output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>());
            // Generate the transformer and add it to the appropriate list.
            // If the queue is empty, remove that feature, ensuring that featureStats is eventually empty.
            if (entry.getValue().isEmpty()) {
        // Remove the features with empty queues.
        for (String s : removeSet) {
    return new TransformerMap(output, getProvenance(), transformations.getProvenance());
Also used : LinkedHashSet(java.util.LinkedHashSet) TransformerMap(org.tribuo.transform.TransformerMap) Transformation(org.tribuo.transform.Transformation) Transformer(org.tribuo.transform.Transformer) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) List(java.util.List) Queue(java.util.Queue) TransformStatistics(org.tribuo.transform.TransformStatistics) Pattern(java.util.regex.Pattern) LinkedList(java.util.LinkedList) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map) TransformerMap(org.tribuo.transform.TransformerMap) TransformationMap(org.tribuo.transform.TransformationMap)

Example 13 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class KMeansTrainer method train.

public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // NOTE(review): this listing is an excerpt; several statements, parts of some expressions and
    // all closing braces are missing, so the comments below describe only what is visible.
    // Creates a new local RNG and adds one to the invocation count.
    TrainerProvenance trainerProvenance;
    SplittableRandom localRNG;
    // The RNG split and the provenance capture happen together under this trainer's monitor.
    synchronized (this) {
        // Handling of an explicitly supplied invocation count; the guarded statement is not shown.
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        localRNG = rng.split();
        trainerProvenance = getProvenance();
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    // oldCentre[i] holds the centroid index last assigned to example i.
    int[] oldCentre = new int[examples.size()];
    SGDVector[] data = new SGDVector[examples.size()];
    double[] weights = new double[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        weights[n] = example.getWeight();
        // An example with as many features as the feature map is stored densely, otherwise sparsely.
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        // -1 marks "no centroid assigned yet". The increment of n is not visible in this excerpt.
        oldCentre[n] = -1;
    DenseVector[] centroidVectors;
    // NOTE(review): the break statements and the default label appear to have been lost from this
    // switch in the excerpt; as shown, the cases would fall through to the throw.
    switch(initialisationType) {
        case RANDOM:
            centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
        case PLUSPLUS:
            centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distType);
            throw new IllegalStateException("Unknown initialisation" + initialisationType);
    // Maps each centroid index to the list of example indices currently assigned to it.
    Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
    boolean parallel = numThreads > 1;
    for (int i = 0; i < centroids; i++) {
        // Synchronized lists are used only when the E step will run on multiple threads.
        clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
    // Counts the examples whose assigned centroid changed; read after each E step below.
    AtomicInteger changeCounter = new AtomicInteger(0);
    // E step for a single example: find the centroid with the smallest distance to the vector.
    Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
        double minDist = Double.POSITIVE_INFINITY;
        int clusterID = -1;
        int id = e.idx;
        SGDVector vector = e.vector;
        for (int j = 0; j < centroids; j++) {
            DenseVector cluster = centroidVectors[j];
            double distance = DistanceType.getDistance(cluster, vector, distType);
            // Strict comparison, so the lowest-indexed centroid wins on ties.
            if (distance < minDist) {
                minDist = distance;
                clusterID = j;
        // The update of clusterAssignments and changeCounter is not visible in this excerpt.
        if (oldCentre[id] != clusterID) {
            // Changed the centroid of this vector.
            oldCentre[id] = clusterID;
    boolean converged = false;
    ForkJoinPool fjp = null;
    try {
        if (parallel) {
            // With a security manager installed the pool is built with the explicit THREAD_FACTORY.
            if (System.getSecurityManager() == null) {
                fjp = new ForkJoinPool(numThreads);
            } else {
                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
        // Runs until the iteration limit is reached or an iteration changes no assignments.
        for (int i = 0; (i < iterations) && !converged; i++) {
            logger.log(Level.FINE, "Beginning iteration " + i);
            // Per-cluster work at the start of the iteration; the loop body is not shown.
            for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
            // E step
            // NOTE(review): the right-hand side of the next assignment was lost in this excerpt.
            Stream<SGDVector> vecStream =;
            Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
            // NOTE(review): the start of the next right-hand side was lost in this excerpt; the
            // surviving arguments suggest it pairs an index stream with vecStream - confirm against the original.
            Stream<IntAndVector> zipStream =, vecStream, IntAndVector::new);
            if (parallel) {
                Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
                try {
                    // Submit to the dedicated pool and block until every element has been processed.
                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
                } catch (InterruptedException | ExecutionException e) {
                    // NOTE(review): as shown, the interrupt status is not restored before rethrowing.
                    throw new RuntimeException("Parallel execution failed", e);
            } else {
            // NOTE(review): this message says "words" while the INFO message below says "examples".
            logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");
            // M step: recompute the centroids from the current assignments.
            mStep(fjp, centroidVectors, clusterAssignments, data, weights);
            logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");
            if (changeCounter.get() == 0) {
                converged = true;
                logger.log(Level.INFO, "K-Means converged at iteration " + i);
    } finally {
        // Cleanup of the pool when one was created; the statement itself is not shown.
        if (fjp != null) {
    // Record the final size of each cluster for the output info.
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    // NOTE(review): the second constructor argument below was lost in this excerpt.
    ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(),, examples.getProvenance(), trainerProvenance, runProvenance);
    return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, distType);
Also used : ClusterID(org.tribuo.clustering.ClusterID) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SGDVector(org.tribuo.math.la.SGDVector) ArrayList(java.util.ArrayList) List(java.util.List) ExecutionException(java.util.concurrent.ExecutionException) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) SplittableRandom(java.util.SplittableRandom) DenseVector(org.tribuo.math.la.DenseVector) ForkJoinPool(java.util.concurrent.ForkJoinPool) ImmutableClusteringInfo(org.tribuo.clustering.ImmutableClusteringInfo)

Example 14 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class ClusteringFactory method constructInfoForExternalModel.

/**
 * Unlike the other info types, clustering directly uses the integer IDs as the stored value,
 * so this mapping discards the cluster IDs and just uses the supplied integers.
 * @param mapping The mapping to use.
 * @return An {@link ImmutableOutputInfo} for the clustering.
 */
public ImmutableOutputInfo<ClusterID> constructInfoForExternalModel(Map<ClusterID, Integer> mapping) {
    // Validate inputs are dense
    // NOTE(review): no validation statement follows the comment above in this excerpt, and the
    // closing braces are missing too; the listing appears to be truncated.
    Map<Integer, MutableLong> countsMap = new HashMap<>();
    // Key each counter by the supplied integer id; the ClusterID keys of the mapping are not used.
    for (Map.Entry<ClusterID, Integer> e : mapping.entrySet()) {
        // Every supplied id gets a counter constructed with 1.
        countsMap.put(e.getValue(), new MutableLong(1));
    return new ImmutableClusteringInfo(countsMap);
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) HashMap(java.util.HashMap) Map(java.util.Map) HashMap(java.util.HashMap)

Example 15 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

the class MutableClusteringInfo method observe.

public void observe(ClusterID output) {
    // Records one observed ClusterID against the per-cluster counters in clusterCounts.
    // NOTE(review): this listing is an excerpt; some statements and the closing braces are missing.
    if (output == ClusteringFactory.UNASSIGNED_CLUSTER_ID) {
        // Reference comparison against the unassigned sentinel; its handling is not shown in this excerpt.
    } else {
        int id = output.getID();
        // Fetch the counter for this cluster id, creating a new MutableLong the first time the id is seen.
        // What is done with `value` afterwards is not visible in this excerpt.
        MutableLong value = clusterCounts.computeIfAbsent(id, k -> new MutableLong());
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong)


MutableLong ( HashMap (java.util.HashMap)6 ArrayList (java.util.ArrayList)3 LinkedHashMap (java.util.LinkedHashMap)3 List (java.util.List)3 Map (java.util.Map)2 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)2 ClusterID (org.tribuo.clustering.ClusterID)2 ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo)2 DenseVector ( SGDVector ( ModelProvenance (org.tribuo.provenance.ModelProvenance)2 TrainerProvenance (org.tribuo.provenance.TrainerProvenance)2 MutableDouble ( LinkedHashSet (java.util.LinkedHashSet)1 LinkedList (java.util.LinkedList)1 Queue (java.util.Queue)1 SplittableRandom (java.util.SplittableRandom)1 ExecutionException (java.util.concurrent.ExecutionException)1 ForkJoinPool (java.util.concurrent.ForkJoinPool)1