Use of com.oracle.labs.mlrg.olcut.util.MutableDouble in the Tribuo project by Oracle.
Class MutableRegressionInfo, method observe.
/**
 * Records one observed regressor, updating the per-dimension statistics.
 * <p>
 * The sentinel {@code RegressionFactory.UNKNOWN_REGRESSOR} only increments the unknown counter.
 * For any other regressor, once at least one non-unknown regressor has been recorded, the new
 * regressor must have the same number of dimensions as already seen, each with an already-known name.
 * For every dimension, the minimum, maximum and count are updated. The mean and the running sum of
 * squared deviations are then updated with Welford's online algorithm.
 * @param output The regressor to record.
 * @throws IllegalArgumentException If the regressor's dimensions differ from the ones already observed.
 */
@Override
public void observe(Regressor output) {
    if (output == RegressionFactory.UNKNOWN_REGRESSOR) {
        unknownCount++;
        return;
    }
    if (overallCount != 0) {
        // Validation runs before any map is touched, so a rejected regressor leaves the statistics unchanged.
        String[] dimensionNames = output.getNames();
        int expected = countMap.size();
        if (dimensionNames.length != expected) {
            throw new IllegalArgumentException("Expected this Regressor to contain " + expected + " dimensions, found " + dimensionNames.length);
        }
        for (String dimensionName : dimensionNames) {
            if (!countMap.containsKey(dimensionName)) {
                throw new IllegalArgumentException("Regressor contains unexpected dimension named '" + dimensionName + "'");
            }
        }
    }
    for (Regressor.DimensionTuple dim : output) {
        String dimName = dim.getName();
        double observed = dim.getValue();
        // Retain the smaller (resp. larger) of the stored extreme and the new observation.
        minMap.merge(dimName, new MutableDouble(observed), (stored, fresh) -> stored.doubleValue() < fresh.doubleValue() ? stored : fresh);
        maxMap.merge(dimName, new MutableDouble(observed), (stored, fresh) -> stored.doubleValue() > fresh.doubleValue() ? stored : fresh);
        MutableLong seen = countMap.computeIfAbsent(dimName, k -> new MutableLong());
        seen.increment();
        // Welford: mean += (x - oldMean) / count; M2 += (x - oldMean) * (x - newMean).
        MutableDouble runningMean = meanMap.computeIfAbsent(dimName, k -> new MutableDouble());
        double deviationBefore = observed - runningMean.doubleValue();
        runningMean.increment(deviationBefore / seen.longValue());
        double deviationAfter = observed - runningMean.doubleValue();
        sumSquaresMap.computeIfAbsent(dimName, k -> new MutableDouble()).increment(deviationBefore * deviationAfter);
    }
    overallCount++;
}
Use of com.oracle.labs.mlrg.olcut.util.MutableDouble in the Tribuo project by Oracle.
Class RegressionInfo, method getDomain.
/**
 * Builds the domain of the observed regression outputs: one single-dimension Regressor per
 * observed dimension, holding the minimum value recorded for that dimension.
 * @return An unmodifiable set of single-dimension Regressors, ordered by dimension name.
 */
@Override
public Set<Regressor> getDomain() {
    Comparator<DimensionTuple> byName = Comparator.comparing(DimensionTuple::getName);
    TreeSet<DimensionTuple> dimensions = new TreeSet<>(byName);
    minMap.forEach((dimensionName, minimum) -> dimensions.add(new DimensionTuple(dimensionName, minimum.doubleValue())));
    // Every element is a DimensionTuple (a Regressor subtype) and the view is read-only,
    // so presenting it as a set of Regressors cannot admit a foreign element.
    @SuppressWarnings("unchecked")
    SortedSet<Regressor> domain = (SortedSet<Regressor>) (SortedSet) Collections.unmodifiableSortedSet(dimensions);
    return domain;
}
Use of com.oracle.labs.mlrg.olcut.util.MutableDouble in the Tribuo project by Oracle.
Class XGBoostModel, method getTopFeatures.
/**
 * Returns the highest-scoring features of each underlying booster, using the scores reported by
 * {@code Booster.getFeatureScore("")}.
 * <p>
 * Scores whose XGBoost keys resolve to the same feature name are summed. With a single booster the
 * list is stored under {@code Model.ALL_OUTPUTS}; otherwise each list is keyed by the string form of
 * the corresponding output.
 * @param n The maximum number of features per booster. A negative value limits the list only by the
 *          size of the feature map; zero yields empty lists.
 * @return A map from output name to (feature name, score) pairs in descending order of absolute
 *         score, or an empty map if XGBoost reported an error.
 */
@Override
public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
    try {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        for (int i = 0; i < models.size(); i++) {
            Booster model = models.get(i);
            Map<String, MutableDouble> outputMap = new HashMap<>();
            Map<String, Integer> xgboostMap = model.getFeatureScore("");
            for (Map.Entry<String, Integer> f : xgboostMap.entrySet()) {
                // Keys are expected to be a one-character prefix followed by the Tribuo feature id.
                int id = Integer.parseInt(f.getKey().substring(1));
                String name = featureIDMap.get(id).getName();
                MutableDouble curVal = outputMap.computeIfAbsent(name, (k) -> new MutableDouble());
                curVal.increment(f.getValue());
            }
            // Min-heap of the best maxFeatures entries seen so far; its root is the weakest kept entry.
            // PriorityQueue rejects capacities below 1, and sizing it from n alone could allocate a
            // huge array for large n, so bound the capacity by the number of candidates.
            int capacity = Math.max(1, Math.min(maxFeatures, outputMap.size()));
            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(capacity, comparator);
            for (Map.Entry<String, MutableDouble> e : outputMap.entrySet()) {
                Pair<String, Double> cur = new Pair<>(e.getKey(), e.getValue().doubleValue());
                if (q.size() < maxFeatures) {
                    q.offer(cur);
                } else if (!q.isEmpty() && comparator.compare(cur, q.peek()) > 0) {
                    // The isEmpty() guard keeps maxFeatures == 0 from comparing against a null root.
                    q.poll();
                    q.offer(cur);
                }
            }
            List<Pair<String, Double>> list = new ArrayList<>(q.size());
            while (!q.isEmpty()) {
                list.add(q.poll());
            }
            // Draining a min-heap gives ascending scores; reverse so the strongest feature is first.
            Collections.reverse(list);
            if (models.size() == 1) {
                map.put(Model.ALL_OUTPUTS, list);
            } else {
                String dimensionName = outputIDInfo.getOutput(i).toString();
                map.put(dimensionName, list);
            }
        }
        return map;
    } catch (XGBoostError e) {
        logger.log(Level.SEVERE, "XGBoost threw an error", e);
        return Collections.emptyMap();
    }
}
Aggregations