Search in sources :

Example 6 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

The class MutableMultiLabelInfo, method observe.

/**
 * Records an observed {@link MultiLabel}, updating the per-label counts.
 * <p>
 * The sentinel {@code MultiLabelFactory.UNKNOWN_MULTILABEL} (matched by reference) only
 * increments the unknown count. Any other output increments the count of every label
 * name it contains, registers each name in the label map if unseen, and increments the
 * total count once.
 * <p>
 * Labels containing a "," are disallowed. There should be an exception thrown when one
 * is constructed too. All label names are validated before any counter is modified, so
 * a rejected output leaves this info unchanged.
 * @param output The observed output.
 * @throws IllegalStateException If the MultiLabel contains a Label which has a "," in it.
 */
@Override
public void observe(MultiLabel output) {
    if (output == MultiLabelFactory.UNKNOWN_MULTILABEL) {
        unknownCount++;
        return;
    }
    Iterable<String> names = output.getNameSet();
    // Validate first: throwing part-way through the counting loop would leave the
    // counts for earlier labels incremented without totalCount being updated.
    for (String label : names) {
        if (label.contains(",")) {
            throw new IllegalStateException("MultiLabel cannot use a Label which contains ','. The supplied label was " + label + ".");
        }
    }
    for (String label : names) {
        MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
        labels.computeIfAbsent(label, MultiLabel::new);
        value.increment();
    }
    totalCount++;
}
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong)

Example 7 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

The class CategoricalInfo, method observe.

/**
 * Records one observed value for this categorical feature.
 * <p>
 * Values equal to {@code 0.0} are ignored entirely: they are neither passed to
 * {@code super.observe} nor counted. Presumably this is because zeros are implicit in
 * sparse features; confirm against the class contract.
 * <p>
 * Counting has two modes:
 * <ul>
 * <li>While {@code valueCounts} is {@code null}, a single distinct value is tracked in
 * {@code observedValue}/{@code observedCount}. {@code NaN} in {@code observedValue}
 * means "nothing observed yet".</li>
 * <li>When a second distinct value arrives, both values move into a {@code HashMap}.
 * All later values are counted there.</li>
 * </ul>
 * <p>
 * NOTE(review): in single-value mode, two values count as equal when they differ by
 * less than {@code COMPARISON_THRESHOLD}. Once the map exists, values are keyed by exact
 * {@code double} equality.
 * <p>
 * NOTE(review): because {@code NaN} is the "empty" sentinel, a {@code NaN} value observed
 * before any other value (if {@code super.observe} permits one) bumps
 * {@code observedCount} but leaves the tracker looking empty. Confirm whether either
 * behavior matters.
 * @param value The observed value.
 */
@Override
protected void observe(double value) {
    if (value != 0.0) {
        super.observe(value);
        if (valueCounts != null) {
            // Multi-value mode: exact-key counting in the map.
            MutableLong count = valueCounts.computeIfAbsent(value, k -> new MutableLong());
            count.increment();
        } else {
            if (Double.isNaN(observedValue)) {
                // First non-zero value seen: start single-value tracking.
                observedValue = value;
                observedCount++;
            } else if (Math.abs(value - observedValue) < COMPARISON_THRESHOLD) {
                // Same value again (within tolerance).
                observedCount++;
            } else {
                // Observed two values for this CategoricalInfo, now it needs a HashMap.
                valueCounts = new HashMap<>(4);
                valueCounts.put(observedValue, new MutableLong(observedCount));
                valueCounts.put(value, new MutableLong(1));
                // Reset the single-value tracker; valueCounts is authoritative from now on.
                observedValue = Double.NaN;
                observedCount = 0;
            }
        }
        // Cleared after every non-zero observation. Presumably a cached view derived
        // from the counts that must be rebuilt; confirm where it is repopulated.
        values = null;
    }
}
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) HashMap(java.util.HashMap)

Example 8 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

The class MutableLabelInfo, method observe.

/**
 * Records a single observed {@link Label}.
 * <p>
 * The {@code LabelFactory.UNKNOWN_LABEL} sentinel (matched by reference) only bumps the
 * unknown counter. Any other label has its per-name count incremented and is registered
 * in the label map if it has not been seen before.
 * @param output The observed output.
 */
@Override
public void observe(Label output) {
    if (output != LabelFactory.UNKNOWN_LABEL) {
        String name = output.getLabel();
        MutableLong tally = labelCounts.computeIfAbsent(name, k -> new MutableLong());
        labels.computeIfAbsent(name, Label::new);
        tally.increment();
        return;
    }
    unknownCount++;
}
Also used : MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong)

Example 9 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

The class PairDistribution, method constructFromLists.

/**
 * Generates the counts for two lists. Returns a PairDistribution containing the joint
 * count, and the two marginal counts.
 * <p>
 * Element {@code i} of {@code first} is paired with element {@code i} of {@code second}.
 * All three count maps preserve first-seen insertion order.
 * <p>
 * The lists are walked with iterators rather than {@code get(i)}. This means
 * non-random-access lists such as {@code LinkedList} are counted in linear time.
 * @param <T1> Type of the first list.
 * @param <T2> Type of the second list.
 * @param first A list of values.
 * @param second Another list of values, the same length as {@code first}.
 * @return The joint counts and the two marginal counts.
 * @throws IllegalArgumentException If the two lists have different lengths.
 */
public static <T1, T2> PairDistribution<T1, T2> constructFromLists(List<T1> first, List<T2> second) {
    if (first.size() != second.size()) {
        throw new IllegalArgumentException("Counting requires arrays of the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
    }
    LinkedHashMap<CachedPair<T1, T2>, MutableLong> abCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
    LinkedHashMap<T1, MutableLong> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
    LinkedHashMap<T2, MutableLong> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
    long count = 0;
    // Iterate both lists in lock-step; sizes were checked above so next() cannot run dry.
    java.util.Iterator<T2> secondIt = second.iterator();
    for (T1 a : first) {
        T2 b = secondIt.next();
        CachedPair<T1, T2> pair = new CachedPair<>(a, b);
        abCountDist.computeIfAbsent(pair, k -> new MutableLong()).increment();
        aCountDist.computeIfAbsent(a, k -> new MutableLong()).increment();
        bCountDist.computeIfAbsent(b, k -> new MutableLong()).increment();
        count++;
    }
    return new PairDistribution<>(count, abCountDist, aCountDist, bCountDist);
}
Also used : LinkedHashMap(java.util.LinkedHashMap) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong)

Example 10 with MutableLong

use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.

The class TripleDistribution, method constructFromLists.

/**
 * Constructs a TripleDistribution from three lists of the same length.
 * <p>
 * Element {@code i} of each list forms one triple. Alongside the joint counts, this
 * records every pairwise count (ab, ac, bc) and every marginal count (a, b, c). The
 * joint count map preserves first-seen insertion order; the other maps are unordered.
 * <p>
 * If they are not the same length it throws IllegalArgumentException.
 * <p>
 * The lists are walked with iterators rather than {@code get(i)}. This means
 * non-random-access lists such as {@code LinkedList} are counted in linear time.
 * @param first The first list.
 * @param second The second list.
 * @param third The third list.
 * @param <T1> The first type.
 * @param <T2> The second type.
 * @param <T3> The third type.
 * @return The TripleDistribution.
 * @throws IllegalArgumentException If the three lists are not all the same length.
 */
public static <T1, T2, T3> TripleDistribution<T1, T2, T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third) {
    if ((first.size() != second.size()) || (first.size() != third.size())) {
        throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size());
    }
    Map<CachedTriple<T1, T2, T3>, MutableLong> jointCount = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
    Map<CachedPair<T1, T2>, MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<CachedPair<T1, T3>, MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<CachedPair<T2, T3>, MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<T1, MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<T2, MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<T3, MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
    long count = first.size();
    // Iterate all three lists in lock-step; sizes were checked above so next() cannot run dry.
    java.util.Iterator<T2> secondIt = second.iterator();
    java.util.Iterator<T3> thirdIt = third.iterator();
    for (T1 a : first) {
        T2 b = secondIt.next();
        T3 c = thirdIt.next();
        CachedTriple<T1, T2, T3> abc = new CachedTriple<>(a, b, c);
        CachedPair<T1, T2> ab = abc.getAB();
        CachedPair<T1, T3> ac = abc.getAC();
        CachedPair<T2, T3> bc = abc.getBC();
        jointCount.computeIfAbsent(abc, k -> new MutableLong()).increment();
        abCount.computeIfAbsent(ab, k -> new MutableLong()).increment();
        acCount.computeIfAbsent(ac, k -> new MutableLong()).increment();
        bcCount.computeIfAbsent(bc, k -> new MutableLong()).increment();
        aCount.computeIfAbsent(a, k -> new MutableLong()).increment();
        bCount.computeIfAbsent(b, k -> new MutableLong()).increment();
        cCount.computeIfAbsent(c, k -> new MutableLong()).increment();
    }
    return new TripleDistribution<>(count, jointCount, abCount, acCount, bcCount, aCount, bCount, cCount);
}
Also used : LinkedHashMap(java.util.LinkedHashMap) HashMap(java.util.HashMap) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong)

Aggregations

MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong): 15; HashMap (java.util.HashMap): 6; ArrayList (java.util.ArrayList): 3; LinkedHashMap (java.util.LinkedHashMap): 3; List (java.util.List): 3; Map (java.util.Map): 2; ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 2; ClusterID (org.tribuo.clustering.ClusterID): 2; ImmutableClusteringInfo (org.tribuo.clustering.ImmutableClusteringInfo): 2; DenseVector (org.tribuo.math.la.DenseVector): 2; SGDVector (org.tribuo.math.la.SGDVector): 2; ModelProvenance (org.tribuo.provenance.ModelProvenance): 2; TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 2; MutableDouble (com.oracle.labs.mlrg.olcut.util.MutableDouble): 1; LinkedHashSet (java.util.LinkedHashSet): 1; LinkedList (java.util.LinkedList): 1; Queue (java.util.Queue): 1; SplittableRandom (java.util.SplittableRandom): 1; ExecutionException (java.util.concurrent.ExecutionException): 1; ForkJoinPool (java.util.concurrent.ForkJoinPool): 1