
Example 1 with SplitNode

Use of org.tribuo.common.tree.SplitNode in project tribuo by oracle.

From class IndependentRegressionTreeModel, method getExcuse.

@Override
public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
    SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
    if (vec.numActiveElements() == 0) {
        return Optional.empty();
    }
    List<String> list = new ArrayList<>();
    List<Prediction<Regressor>> predList = new ArrayList<>();
    Map<String, List<Pair<String, Double>>> map = new HashMap<>();
    for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
        list.clear();
        // oldNode trails curNode, so when curNode becomes null oldNode holds the last node visited.
        Node<Regressor> oldNode = e.getValue();
        Node<Regressor> curNode = e.getValue();
        while (curNode != null) {
            oldNode = curNode;
            if (oldNode instanceof SplitNode) {
                // Record the name of the feature this split tests.
                SplitNode<?> node = (SplitNode<?>) oldNode;
                list.add(featureIDMap.get(node.getFeatureID()).getName());
            }
            curNode = oldNode.getNextNode(vec);
        }
        // The last node visited (oldNode) must be a LeafNode.
        predList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example));
        // Score features by depth: the root split gets list.size() + 1, each deeper split one less.
        List<Pair<String, Double>> pairs = new ArrayList<>();
        int i = list.size() + 1;
        for (String s : list) {
            pairs.add(new Pair<>(s, (double) i));
            i--;
        }
        map.put(e.getKey(), pairs);
    }
    Prediction<Regressor> combinedPrediction = combine(predList);
    return Optional.of(new Excuse<>(example, combinedPrediction, map));
}
Also used: HashMap(java.util.HashMap) Prediction(org.tribuo.Prediction) SplitNode(org.tribuo.common.tree.SplitNode) LeafNode(org.tribuo.common.tree.LeafNode) Node(org.tribuo.common.tree.Node) ArrayList(java.util.ArrayList) SparseVector(org.tribuo.math.la.SparseVector) LinkedList(java.util.LinkedList) List(java.util.List) Regressor(org.tribuo.regression.Regressor) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) Map(java.util.Map) Pair(com.oracle.labs.mlrg.olcut.util.Pair)
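The tree walk in getExcuse can be tried outside Tribuo. The following is a minimal sketch, assuming hypothetical stand-in types: `Node`, `Split`, `Leaf`, `walk`, `score`, and `explain` are illustrative names, not Tribuo's API, and the `value <= threshold` goes-left rule and the treatment of missing features as 0.0 are assumptions. It mirrors the two steps of the method: follow each tree from its root to a leaf while recording the features tested by splits, then give features nearer the root higher scores (the root split gets `path.size() + 1`, each deeper split one less). It does not reproduce the combining of per-dimension predictions that `combine(predList)` performs.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, minimal decision-tree types; NOT Tribuo's Node/SplitNode/LeafNode.
public class TreeExcuseSketch {

    interface Node {
        // Returns the child to visit next, or null when this node is a leaf.
        Node next(Map<String, Double> features);
    }

    static final class Split implements Node {
        final String feature;
        final double threshold;
        final Node left;  // taken when value <= threshold (assumed rule)
        final Node right; // taken when value > threshold

        Split(String feature, double threshold, Node left, Node right) {
            this.feature = feature;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        public Node next(Map<String, Double> features) {
            // Missing features are treated as 0.0, as in a sparse vector (assumption).
            double v = features.getOrDefault(feature, 0.0);
            return v <= threshold ? left : right;
        }
    }

    static final class Leaf implements Node {
        final double value;

        Leaf(double value) {
            this.value = value;
        }

        public Node next(Map<String, Double> features) {
            return null;
        }
    }

    // Walks root to leaf, appending each split's feature name to path; returns the leaf reached.
    static Leaf walk(Node root, Map<String, Double> features, List<String> path) {
        Node prev = root;
        Node cur = root;
        while (cur != null) {
            prev = cur;
            if (cur instanceof Split) {
                path.add(((Split) cur).feature);
            }
            cur = cur.next(features);
        }
        return (Leaf) prev; // the last non-null node visited is always a leaf
    }

    // Depth-based scores: the first (root) feature gets path.size() + 1, each later one 1 less.
    static List<Map.Entry<String, Double>> score(List<String> path) {
        List<Map.Entry<String, Double>> pairs = new ArrayList<>();
        int i = path.size() + 1;
        for (String s : path) {
            pairs.add(Map.entry(s, (double) i));
            i--;
        }
        return pairs;
    }

    // One scored feature list per output dimension, keyed like the roots map.
    static Map<String, List<Map.Entry<String, Double>>> explain(Map<String, Node> roots,
                                                                Map<String, Double> features) {
        Map<String, List<Map.Entry<String, Double>>> excuse = new LinkedHashMap<>();
        for (Map.Entry<String, Node> e : roots.entrySet()) {
            List<String> path = new ArrayList<>();
            walk(e.getValue(), features, path);
            excuse.put(e.getKey(), score(path));
        }
        return excuse;
    }

    public static void main(String[] args) {
        // Split on "x" at 0.5; the right branch splits again on "y" at 1.0.
        Node root = new Split("x", 0.5,
                new Leaf(-1.0),
                new Split("y", 1.0, new Leaf(2.0), new Leaf(3.0)));
        List<String> path = new ArrayList<>();
        Leaf leaf = walk(root, Map.of("x", 0.9, "y", 0.2), path);
        System.out.println(leaf.value + " " + score(path)); // prints 2.0 [x=3.0, y=2.0]
    }
}
```

As in the original method, the score depends only on a feature's position along the decision path, not on how much the split changed the prediction; `Map.entry` requires Java 9 or later.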
