use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
the class MutableMultiLabelInfo method observe.
/**
 * Records an observation of a {@link MultiLabel}.
 * <p>
 * If {@code output} is {@code MultiLabelFactory.UNKNOWN_MULTILABEL} only the unknown count is
 * incremented. Otherwise, the count for each label name in the output is incremented (a
 * {@link MultiLabel} is registered for names not seen before), and the total count is
 * incremented once.
 * <p>
 * Throws IllegalStateException if the MultiLabel contains a Label which has a "," in it.
 * Such labels are disallowed. There should be an exception thrown when one is constructed
 * too. Every label is validated before any count is updated, so a rejected output leaves
 * the recorded counts unchanged.
 * @param output The observed output.
 * @throws IllegalStateException If any label name contains a ','.
 */
@Override
public void observe(MultiLabel output) {
    if (output == MultiLabelFactory.UNKNOWN_MULTILABEL) {
        unknownCount++;
        return;
    }
    // Validate first: throwing part-way through the counting loop would leave earlier
    // labels counted while totalCount was not incremented.
    for (String label : output.getNameSet()) {
        if (label.contains(",")) {
            throw new IllegalStateException("MultiLabel cannot use a Label which contains ','. The supplied label was " + label + ".");
        }
    }
    for (String label : output.getNameSet()) {
        MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
        labels.computeIfAbsent(label, MultiLabel::new);
        value.increment();
    }
    totalCount++;
}
use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
the class CategoricalInfo method observe.
/**
 * Records a single observed value.
 * <p>
 * Values equal to {@code 0.0} are ignored by this method. Other values are passed to
 * {@code super.observe} and then counted. While only one distinct value (within
 * {@code COMPARISON_THRESHOLD}) has been seen, it is tracked in {@code observedValue} /
 * {@code observedCount}. Once a second distinct value arrives, the counts move into the
 * {@code valueCounts} map.
 * @param value The observed value.
 */
@Override
protected void observe(double value) {
    if (value == 0.0) {
        return;
    }
    super.observe(value);
    if (valueCounts == null) {
        if (Double.isNaN(observedValue)) {
            // Nothing tracked yet: this value becomes the single tracked value.
            observedValue = value;
            observedCount++;
        } else if (Math.abs(value - observedValue) < COMPARISON_THRESHOLD) {
            // Same value as the one already tracked (within the threshold).
            observedCount++;
        } else {
            // Observed two values for this CategoricalInfo, now it needs a HashMap.
            valueCounts = new HashMap<>(4);
            valueCounts.put(observedValue, new MutableLong(observedCount));
            valueCounts.put(value, new MutableLong(1));
            observedValue = Double.NaN;
            observedCount = 0;
        }
    } else {
        valueCounts.computeIfAbsent(value, k -> new MutableLong()).increment();
    }
    // NOTE(review): `values` appears to be derived state that must be rebuilt after any
    // new observation — confirm against the rest of CategoricalInfo.
    values = null;
}
use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
the class MutableLabelInfo method observe.
/**
 * Records an observation of a {@link Label}.
 * <p>
 * The sentinel {@code LabelFactory.UNKNOWN_LABEL} only bumps the unknown count. Any other
 * label has its per-name count incremented, and a {@link Label} is registered for names not
 * seen before.
 * @param output The observed output.
 */
@Override
public void observe(Label output) {
    if (output != LabelFactory.UNKNOWN_LABEL) {
        String name = output.getLabel();
        MutableLong counter = labelCounts.computeIfAbsent(name, k -> new MutableLong());
        labels.computeIfAbsent(name, Label::new);
        counter.increment();
        return;
    }
    unknownCount++;
}
use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
the class PairDistribution method constructFromLists.
/**
 * Generates the counts for two lists of paired values. Returns a PairDistribution containing
 * the joint count, and the two marginal counts.
 * <p>
 * Element {@code i} of {@code first} is paired with element {@code i} of {@code second}.
 * The lists are walked with iterators rather than {@code get(i)}, so non-random-access lists
 * (e.g. {@link java.util.LinkedList}) are processed in linear rather than quadratic time.
 * @param <T1> Type of the first list's elements.
 * @param <T2> Type of the second list's elements.
 * @param first A list of values.
 * @param second Another list of values, the same length as {@code first}.
 * @return The joint counts and the two marginal counts.
 * @throws IllegalArgumentException If the two lists have different lengths.
 */
public static <T1, T2> PairDistribution<T1, T2> constructFromLists(List<T1> first, List<T2> second) {
    if (first.size() != second.size()) {
        throw new IllegalArgumentException("Counting requires arrays of the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
    }
    LinkedHashMap<CachedPair<T1, T2>, MutableLong> abCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
    LinkedHashMap<T1, MutableLong> aCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
    LinkedHashMap<T2, MutableLong> bCountDist = new LinkedHashMap<>(InformationTheory.DEFAULT_MAP_SIZE);
    long count = 0;
    // Iterate in lock-step; List.get(i) is O(i) on linked lists.
    java.util.Iterator<T1> firstItr = first.iterator();
    java.util.Iterator<T2> secondItr = second.iterator();
    while (firstItr.hasNext() && secondItr.hasNext()) {
        T1 a = firstItr.next();
        T2 b = secondItr.next();
        CachedPair<T1, T2> pair = new CachedPair<>(a, b);
        abCountDist.computeIfAbsent(pair, k -> new MutableLong()).increment();
        aCountDist.computeIfAbsent(a, k -> new MutableLong()).increment();
        bCountDist.computeIfAbsent(b, k -> new MutableLong()).increment();
        count++;
    }
    return new PairDistribution<>(count, abCountDist, aCountDist, bCountDist);
}
use of com.oracle.labs.mlrg.olcut.util.MutableLong in project tribuo by oracle.
the class TripleDistribution method constructFromLists.
/**
 * Constructs a TripleDistribution from three lists of the same length.
 * <p>
 * Element {@code i} of each list forms one observed triple. Joint, pairwise and marginal
 * counts are accumulated. The lists are walked with iterators rather than {@code get(i)}, so
 * non-random-access lists (e.g. {@link java.util.LinkedList}) are processed in linear time.
 * <p>
 * If they are not the same length it throws IllegalArgumentException.
 * @param first The first list.
 * @param second The second list.
 * @param third The third list.
 * @param <T1> The first type.
 * @param <T2> The second type.
 * @param <T3> The third type.
 * @return The TripleDistribution.
 * @throws IllegalArgumentException If the three lists do not all have the same length.
 */
public static <T1, T2, T3> TripleDistribution<T1, T2, T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third) {
    if ((first.size() != second.size()) || (first.size() != third.size())) {
        throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size());
    }
    Map<CachedTriple<T1, T2, T3>, MutableLong> jointCount = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
    Map<CachedPair<T1, T2>, MutableLong> abCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<CachedPair<T1, T3>, MutableLong> acCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<CachedPair<T2, T3>, MutableLong> bcCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<T1, MutableLong> aCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<T2, MutableLong> bCount = new HashMap<>(DEFAULT_MAP_SIZE);
    Map<T3, MutableLong> cCount = new HashMap<>(DEFAULT_MAP_SIZE);
    long count = first.size();
    // Iterate in lock-step; List.get(i) is O(i) on linked lists.
    java.util.Iterator<T1> firstItr = first.iterator();
    java.util.Iterator<T2> secondItr = second.iterator();
    java.util.Iterator<T3> thirdItr = third.iterator();
    while (firstItr.hasNext() && secondItr.hasNext() && thirdItr.hasNext()) {
        T1 a = firstItr.next();
        T2 b = secondItr.next();
        T3 c = thirdItr.next();
        CachedTriple<T1, T2, T3> abc = new CachedTriple<>(a, b, c);
        jointCount.computeIfAbsent(abc, k -> new MutableLong()).increment();
        abCount.computeIfAbsent(abc.getAB(), k -> new MutableLong()).increment();
        acCount.computeIfAbsent(abc.getAC(), k -> new MutableLong()).increment();
        bcCount.computeIfAbsent(abc.getBC(), k -> new MutableLong()).increment();
        aCount.computeIfAbsent(a, k -> new MutableLong()).increment();
        bCount.computeIfAbsent(b, k -> new MutableLong()).increment();
        cCount.computeIfAbsent(c, k -> new MutableLong()).increment();
    }
    return new TripleDistribution<>(count, jointCount, abCount, acCount, bcCount, aCount, bCount, cCount);
}
Aggregations