Search in sources:

Example 1 with Node

Use of org.tribuo.common.tree.Node in project tribuo by Oracle.

The class CARTRegressionTrainer, method train.

/**
 * Trains one CART regression tree per output dimension and bundles them into an
 * {@code IndependentRegressionTreeModel}.
 * <p>
 * A separate tree is grown for each {@link Regressor} in the dataset's output domain, keyed by
 * the first entry of that regressor's {@code getNames()}. All trees share a single inverted copy
 * of the data and a single split RNG, so the sequence of random draws depends on the iteration
 * order of the output domain.
 *
 * @param examples        The training data. Must not contain unknown outputs.
 * @param runProvenance   Run-level provenance passed into the resulting {@link ModelProvenance}.
 * @param invocationCount If not equal to {@code INCREMENT_INVOCATION_COUNT}, the trainer's
 *                        invocation count is set to this value before the RNG is split.
 * @return The trained model (constructed as an {@code IndependentRegressionTreeModel}).
 * @throws IllegalArgumentException If {@code examples} contains unknown outputs.
 */
@Override
public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // This trainer is supervised, so every example must carry a known output.
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Creates a new RNG, adds one to the invocation count.
    // Done while holding this trainer's monitor so the invocation-count update, RNG split and
    // provenance snapshot are not interleaved with other code synchronized on this trainer.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<Regressor> outputIDInfo = examples.getOutputIDInfo();
    Set<Regressor> domain = outputIDInfo.getDomain();
    // Number of candidate features examined at each split, clamped so it never exceeds the
    // total feature count.
    int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * featureIDMap.size()), featureIDMap.size());
    // originalIndices holds every feature id 0..size-1. When only a subset of features is used
    // per split, indices is a separate buffer refilled before each split; otherwise it simply
    // aliases originalIndices.
    int[] indices;
    int[] originalIndices = new int[featureIDMap.size()];
    for (int i = 0; i < originalIndices.length; i++) {
        originalIndices[i] = i;
    }
    if (numFeaturesInSplit != featureIDMap.size()) {
        indices = new int[numFeaturesInSplit];
    } else {
        indices = originalIndices;
    }
    // Total example weight; the minimum impurity decrease is scaled by it.
    // NOTE(review): accumulated in float, so very large datasets may lose precision here.
    float weightSum = 0.0f;
    for (Example<Regressor> e : examples) {
        weightSum += e.getWeight();
    }
    float scaledMinImpurityDecrease = getMinImpurityDecrease() * weightSum;
    AbstractTrainingNode.LeafDeterminer leafDeterminer = new AbstractTrainingNode.LeafDeterminer(maxDepth, minChildWeight, scaledMinImpurityDecrease);
    // Inverted once, then handed to the root of every per-dimension tree.
    InvertedData data = RegressorTrainingNode.invertData(examples);
    Map<String, Node<Regressor>> nodeMap = new HashMap<>();
    // Grow one independent tree per output dimension.
    for (Regressor r : domain) {
        String dimName = r.getNames()[0];
        int dimIdx = outputIDInfo.getID(r);
        AbstractTrainingNode<Regressor> root = new RegressorTrainingNode(impurity, data, dimIdx, dimName, examples.size(), featureIDMap, outputIDInfo, leafDeterminer);
        Deque<AbstractTrainingNode<Regressor>> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            AbstractTrainingNode<Regressor> node = queue.poll();
            // Only attempt to split nodes that are impure, shallower than maxDepth, and carry
            // at least minChildWeight of example weight.
            if ((node.getImpurity() > 0.0) && (node.getDepth() < maxDepth) && (node.getWeightSum() >= minChildWeight)) {
                // Feature subsampling: randpermInPlace presumably permutes originalIndices in
                // place using localRNG; the first numFeaturesInSplit ids become this split's
                // candidate features.
                if (numFeaturesInSplit != featureIDMap.size()) {
                    Util.randpermInPlace(originalIndices, localRNG);
                    System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
                }
                List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices, localRNG, getUseRandomSplitPoints());
                // Use the queue as a stack to improve cache locality.
                for (AbstractTrainingNode<Regressor> newNode : nodes) {
                    queue.addFirst(newNode);
                }
            }
        }
        // Convert the training-time node structure into the Node tree stored in the model.
        nodeMap.put(dimName, root.convertTree());
    }
    // NOTE(review): the provenance records TreeModel as the model class, while the returned
    // instance is an IndependentRegressionTreeModel. Confirm this is intended.
    ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new IndependentRegressionTreeModel("cart-tree", provenance, featureIDMap, outputIDInfo, false, nodeMap);
}
Also used: InvertedData(org.tribuo.regression.rtree.impl.RegressorTrainingNode.InvertedData) HashMap(java.util.HashMap) AbstractTrainingNode(org.tribuo.common.tree.AbstractTrainingNode) RegressorTrainingNode(org.tribuo.regression.rtree.impl.RegressorTrainingNode) Node(org.tribuo.common.tree.Node) TreeModel(org.tribuo.common.tree.TreeModel) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) Regressor(org.tribuo.regression.Regressor) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) ArrayDeque(java.util.ArrayDeque) SplittableRandom(java.util.SplittableRandom)

Example 2 with Node

Use of org.tribuo.common.tree.Node in project tribuo by Oracle.

The class IndependentRegressionTreeModel, method getExcuse.

/**
 * Explains a prediction by listing, for every output dimension, the features split on along
 * the path the example takes through that dimension's tree.
 * <p>
 * Within one dimension the features are kept in path order (root first) and scored from
 * {@code pathLength + 1} down to {@code 2}, so splits closer to the root score higher.
 *
 * @param example The example to explain.
 * @return The excuse holding the combined prediction and the per-dimension feature scores, or
 *         {@link Optional#empty()} if the example's sparse vector has no active elements.
 */
@Override
public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
    // Build the sparse feature vector against this model's feature map
    // (original note: ensures we handle collisions correctly).
    SparseVector features = SparseVector.createSparseVector(example, featureIDMap, false);
    if (features.numActiveElements() == 0) {
        return Optional.empty();
    }
    List<String> pathFeatures = new ArrayList<>();
    List<Prediction<Regressor>> perDimension = new ArrayList<>();
    Map<String, List<Pair<String, Double>>> explanation = new HashMap<>();
    for (Map.Entry<String, Node<Regressor>> tree : roots.entrySet()) {
        pathFeatures.clear();
        // Descend until getNextNode returns null; the last visited node is assumed to be a leaf.
        Node<Regressor> leaf = tree.getValue();
        for (Node<Regressor> next = tree.getValue(); next != null; next = leaf.getNextNode(features)) {
            leaf = next;
            if (leaf instanceof SplitNode) {
                SplitNode<?> split = (SplitNode<?>) leaf;
                pathFeatures.add(featureIDMap.get(split.getFeatureID()).getName());
            }
        }
        perDimension.add(((LeafNode<Regressor>) leaf).getPrediction(features.numActiveElements(), example));
        // Earlier (shallower) splits receive larger scores.
        List<Pair<String, Double>> scored = new ArrayList<>(pathFeatures.size());
        double score = pathFeatures.size() + 1;
        for (String featureName : pathFeatures) {
            scored.add(new Pair<>(featureName, score));
            score -= 1.0;
        }
        explanation.put(tree.getKey(), scored);
    }
    Prediction<Regressor> combinedPrediction = combine(perDimension);
    return Optional.of(new Excuse<>(example, combinedPrediction, explanation));
}
Also used: HashMap(java.util.HashMap) Prediction(org.tribuo.Prediction) SplitNode(org.tribuo.common.tree.SplitNode) LeafNode(org.tribuo.common.tree.LeafNode) Node(org.tribuo.common.tree.Node) ArrayList(java.util.ArrayList) SparseVector(org.tribuo.math.la.SparseVector) LinkedList(java.util.LinkedList) List(java.util.List) Regressor(org.tribuo.regression.Regressor) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) Map(java.util.Map) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 3 with Node

Use of org.tribuo.common.tree.Node in project tribuo by Oracle.

The class IndependentRegressionTreeModel, method predict.

/**
 * Predicts every output dimension by walking each dimension's tree to a leaf, then combines
 * the per-dimension leaf predictions.
 *
 * @param example The example to predict.
 * @return The combined prediction across all output dimensions.
 * @throws IllegalArgumentException If the sparse vector built from the example against this
 *                                  model's feature map has no active elements.
 */
@Override
public Prediction<Regressor> predict(Example<Regressor> example) {
    // Build the sparse feature vector against this model's feature map
    // (original note: ensures we handle collisions correctly).
    SparseVector features = SparseVector.createSparseVector(example, featureIDMap, false);
    if (features.numActiveElements() == 0) {
        throw new IllegalArgumentException("No features found in Example " + example.toString());
    }
    List<Prediction<Regressor>> perDimension = new ArrayList<>();
    for (Node<Regressor> root : roots.values()) {
        // Descend until getNextNode returns null; the last visited node is assumed to be a leaf.
        Node<Regressor> leaf = root;
        for (Node<Regressor> next = root; next != null; next = leaf.getNextNode(features)) {
            leaf = next;
        }
        perDimension.add(((LeafNode<Regressor>) leaf).getPrediction(features.numActiveElements(), example));
    }
    return combine(perDimension);
}
Also used: Prediction(org.tribuo.Prediction) SplitNode(org.tribuo.common.tree.SplitNode) LeafNode(org.tribuo.common.tree.LeafNode) Node(org.tribuo.common.tree.Node) ArrayList(java.util.ArrayList) Regressor(org.tribuo.regression.Regressor) SparseVector(org.tribuo.math.la.SparseVector) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) HashMap(java.util.HashMap) Map(java.util.Map)

Aggregations

HashMap (java.util.HashMap): 3; ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 3; Node (org.tribuo.common.tree.Node): 3; Regressor (org.tribuo.regression.Regressor): 3; ArrayList (java.util.ArrayList): 2; Map (java.util.Map): 2; Prediction (org.tribuo.Prediction): 2; LeafNode (org.tribuo.common.tree.LeafNode): 2; SplitNode (org.tribuo.common.tree.SplitNode): 2; SparseVector (org.tribuo.math.la.SparseVector): 2; Pair (com.oracle.labs.mlrg.olcut.util.Pair): 1; ArrayDeque (java.util.ArrayDeque): 1; LinkedList (java.util.LinkedList): 1; List (java.util.List): 1; SplittableRandom (java.util.SplittableRandom): 1; AbstractTrainingNode (org.tribuo.common.tree.AbstractTrainingNode): 1; TreeModel (org.tribuo.common.tree.TreeModel): 1; ModelProvenance (org.tribuo.provenance.ModelProvenance): 1; TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 1; RegressorTrainingNode (org.tribuo.regression.rtree.impl.RegressorTrainingNode): 1