Search in sources:

Example 1 with MutableDouble

Use of com.oracle.labs.mlrg.olcut.util.MutableDouble in project tribuo by Oracle.

Class MutableRegressionInfo, method observe.

/**
 * Records one observed regressor, updating the per-dimension statistics.
 * <p>
 * The sentinel {@code RegressionFactory.UNKNOWN_REGRESSOR} only increments the unknown
 * counter. Any other regressor is checked first, once at least one known regressor has
 * been recorded. It must have the same number of dimensions as those already seen, and
 * every dimension name must already be known. The check runs before any statistic is
 * modified. For each dimension the running min, max, count, mean and sum of squared
 * deviations are then updated using Welford's online algorithm.
 *
 * @param output the regressor to record.
 * @throws IllegalArgumentException if the regressor's dimensions do not match those already observed.
 */
@Override
public void observe(Regressor output) {
    if (output == RegressionFactory.UNKNOWN_REGRESSOR) {
        unknownCount++;
        return;
    }
    if (overallCount != 0) {
        // Every known regressor after the first must carry the dimensions already recorded.
        String[] observedNames = output.getNames();
        int expectedDims = countMap.size();
        if (observedNames.length != expectedDims) {
            throw new IllegalArgumentException("Expected this Regressor to contain " + expectedDims + " dimensions, found " + observedNames.length);
        }
        for (int i = 0; i < observedNames.length; i++) {
            String candidate = observedNames[i];
            if (!countMap.containsKey(candidate)) {
                throw new IllegalArgumentException("Regressor contains unexpected dimension named '" + candidate + "'");
            }
        }
    }
    for (Regressor.DimensionTuple dim : output) {
        final String dimName = dim.getName();
        final double x = dim.getValue();
        // Keep whichever of the stored and incoming values is more extreme.
        minMap.merge(dimName, new MutableDouble(x), (prev, cur) -> prev.doubleValue() < cur.doubleValue() ? prev : cur);
        maxMap.merge(dimName, new MutableDouble(x), (prev, cur) -> prev.doubleValue() > cur.doubleValue() ? prev : cur);
        MutableLong dimCount = countMap.computeIfAbsent(dimName, unused -> new MutableLong());
        dimCount.increment();
        // Welford step: shift the mean by (x - oldMean) / count, then accumulate
        // (x - oldMean) * (x - newMean) into the sum of squared deviations.
        MutableDouble runningMean = meanMap.computeIfAbsent(dimName, unused -> new MutableDouble());
        double deviationBefore = x - runningMean.doubleValue();
        runningMean.increment(deviationBefore / dimCount.longValue());
        double deviationAfter = x - runningMean.doubleValue();
        MutableDouble sumSq = sumSquaresMap.computeIfAbsent(dimName, unused -> new MutableDouble());
        sumSq.increment(deviationBefore * deviationAfter);
    }
    overallCount++;
}
Also used: MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong), MutableDouble (com.oracle.labs.mlrg.olcut.util.MutableDouble)

Example 2 with MutableDouble

Use of com.oracle.labs.mlrg.olcut.util.MutableDouble in project tribuo by Oracle.

Class RegressionInfo, method getDomain.

/**
 * Builds the domain of this output info: one single-dimension Regressor per observed
 * dimension, carrying that dimension's minimum observed value.
 * <p>
 * The returned set is ordered by dimension name and cannot be modified.
 * @return A set of Regressors, each with one active dimension.
 */
@Override
public Set<Regressor> getDomain() {
    Comparator<DimensionTuple> byName = Comparator.comparing(DimensionTuple::getName);
    TreeSet<DimensionTuple> tuples = new TreeSet<>(byName);
    minMap.forEach((dimensionName, minimum) -> tuples.add(new DimensionTuple(dimensionName, minimum.doubleValue())));
    // Widening the element type is safe here: DimensionTuple is a subtype of Regressor,
    // and the view is read-only, so nothing can be inserted through the widened type.
    @SuppressWarnings({"unchecked", "rawtypes"})
    SortedSet<Regressor> readOnlyView = (SortedSet<Regressor>) (SortedSet) Collections.unmodifiableSortedSet(tuples);
    return readOnlyView;
}
Also used: TreeSet (java.util.TreeSet), MutableDouble (com.oracle.labs.mlrg.olcut.util.MutableDouble), LinkedHashMap (java.util.LinkedHashMap), TreeMap (java.util.TreeMap), Map (java.util.Map), SortedSet (java.util.SortedSet), DimensionTuple (org.tribuo.regression.Regressor.DimensionTuple)

Example 3 with MutableDouble

Use of com.oracle.labs.mlrg.olcut.util.MutableDouble in project tribuo by Oracle.

Class XGBoostModel, method getTopFeatures.

/**
 * Returns the highest-scoring features for each XGBoost booster in this model.
 * <p>
 * Scores come from {@code Booster.getFeatureScore("")}. Each score key is assumed to be
 * a one-character prefix followed by the numeric feature id (e.g. {@code f12}). The id
 * is mapped back to a feature name through {@code featureIDMap}, and scores whose keys
 * resolve to the same name are summed. Each returned list is ordered by descending
 * absolute score. With a single booster, the list is stored under
 * {@code Model.ALL_OUTPUTS}; otherwise each list is keyed by the string form of
 * {@code outputIDInfo.getOutput(i)}.
 *
 * @param n the maximum number of features to return per output. A negative value
 *          requests all features, and zero yields empty lists.
 * @return a map from output name to its top features, or an empty map if XGBoost raised an error.
 */
@Override
public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
    try {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        // Min-heap on |score|: the head is always the weakest of the retained candidates.
        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        for (int i = 0; i < models.size(); i++) {
            Booster model = models.get(i);
            Map<String, MutableDouble> outputMap = new HashMap<>();
            Map<String, Integer> xgboostMap = model.getFeatureScore("");
            for (Map.Entry<String, Integer> f : xgboostMap.entrySet()) {
                // Drop the one-character prefix before the numeric feature id.
                int id = Integer.parseInt(f.getKey().substring(1));
                String name = featureIDMap.get(id).getName();
                MutableDouble curVal = outputMap.computeIfAbsent(name, (k) -> new MutableDouble());
                curVal.increment(f.getValue());
            }
            // PriorityQueue rejects an initial capacity below 1, which happens when
            // n == 0 or when n < 0 and the feature map is empty; clamp it.
            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(Math.max(1, maxFeatures), comparator);
            for (Map.Entry<String, MutableDouble> e : outputMap.entrySet()) {
                Pair<String, Double> cur = new Pair<>(e.getKey(), e.getValue().doubleValue());
                if (q.size() < maxFeatures) {
                    q.offer(cur);
                } else if (!q.isEmpty() && comparator.compare(cur, q.peek()) > 0) {
                    // Evict the weakest retained feature in favour of a stronger one.
                    // The isEmpty guard avoids comparing against a null head when maxFeatures == 0.
                    q.poll();
                    q.offer(cur);
                }
            }
            List<Pair<String, Double>> list = new ArrayList<>(q.size());
            while (!q.isEmpty()) {
                list.add(q.poll());
            }
            // Polling yields ascending |score|; reverse to get descending order.
            Collections.reverse(list);
            if (models.size() == 1) {
                map.put(Model.ALL_OUTPUTS, list);
            } else {
                String dimensionName = outputIDInfo.getOutput(i).toString();
                map.put(dimensionName, list);
            }
        }
        return map;
    } catch (XGBoostError e) {
        logger.log(Level.SEVERE, "XGBoost threw an error", e);
        return Collections.emptyMap();
    }
}
Also used: HashMap (java.util.HashMap), MutableDouble (com.oracle.labs.mlrg.olcut.util.MutableDouble), Booster (ml.dmlc.xgboost4j.java.Booster), ArrayList (java.util.ArrayList), XGBoostError (ml.dmlc.xgboost4j.java.XGBoostError), PriorityQueue (java.util.PriorityQueue), List (java.util.List), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), Map (java.util.Map), Pair (com.oracle.labs.mlrg.olcut.util.Pair)

Aggregations

MutableDouble (com.oracle.labs.mlrg.olcut.util.MutableDouble): 3; Map (java.util.Map): 2; MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong): 1; Pair (com.oracle.labs.mlrg.olcut.util.Pair): 1; ArrayList (java.util.ArrayList): 1; HashMap (java.util.HashMap): 1; LinkedHashMap (java.util.LinkedHashMap): 1; List (java.util.List): 1; PriorityQueue (java.util.PriorityQueue): 1; SortedSet (java.util.SortedSet): 1; TreeMap (java.util.TreeMap): 1; TreeSet (java.util.TreeSet): 1; Booster (ml.dmlc.xgboost4j.java.Booster): 1; XGBoostError (ml.dmlc.xgboost4j.java.XGBoostError): 1; ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 1; DimensionTuple (org.tribuo.regression.Regressor.DimensionTuple): 1