Search in sources:

Example 1 with LeafNode

use of org.tribuo.common.tree.LeafNode in project tribuo by oracle.

the class ClassifierTrainingNode method createLeaf.

/**
 * Builds a {@link LeafNode} from the weighted label counts gathered at this node.
 * <p>
 * The counts are passed through {@code Util.normalizeToDistribution}, one {@link Label}
 * is created per output id carrying its normalised score, and the label whose score is
 * strictly greater than every earlier score is used as the leaf output (so the lowest id
 * wins ties). If no score compares greater than negative infinity (for example when
 * {@code weightedCounts} is empty), the output handed to the leaf is {@code null}.
 * @param impurityScore the impurity score for the node.
 * @param weightedCounts the weighted label counts of the data in the node.
 * @return a {@link LeafNode}
 */
private LeafNode<Label> createLeaf(double impurityScore, float[] weightedCounts) {
    final double[] distribution = Util.normalizeToDistribution(weightedCounts);
    // Insertion-ordered so the scores are stored in output id order.
    final Map<String, Label> scoredLabels = new LinkedHashMap<>();
    Label bestLabel = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (int id = 0; id < weightedCounts.length; id++) {
        final double score = distribution[id];
        final String labelName = labelIDMap.getOutput(id).getLabel();
        final Label scored = new Label(labelName, score);
        scoredLabels.put(labelName, scored);
        // Strict comparison: an equal score never displaces the current best.
        if (score > bestScore) {
            bestLabel = scored;
            bestScore = score;
        }
    }
    return new LeafNode<>(impurityScore, bestLabel, scoredLabels, true);
}
Also used : LeafNode(org.tribuo.common.tree.LeafNode) Label(org.tribuo.classification.Label) LinkedHashMap(java.util.LinkedHashMap)

Example 2 with LeafNode

use of org.tribuo.common.tree.LeafNode in project tribuo by oracle.

the class IndependentRegressionTreeModel method getExcuse.

/**
 * Produces an {@link Excuse} for the supplied example by walking every tree in
 * {@code roots} and recording the split features visited on the way to the leaf.
 * <p>
 * For each tree, the features on the path are scored in visiting order: the first split
 * receives {@code pathLength + 1} and each later split one less. The per-tree leaf
 * predictions are merged with {@code combine}.
 * @param example the example to explain.
 * @return an empty optional if the example yields no active elements in the sparse
 *         vector, otherwise the excuse keyed by the entries of {@code roots}.
 */
@Override
public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
    final SparseVector features = SparseVector.createSparseVector(example, featureIDMap, false);
    if (features.numActiveElements() == 0) {
        return Optional.empty();
    }
    final Map<String, List<Pair<String, Double>>> scoresPerTree = new HashMap<>();
    final List<Prediction<Regressor>> leafPredictions = new ArrayList<>();
    for (Map.Entry<String, Node<Regressor>> tree : roots.entrySet()) {
        // Names of the features split on along this tree's path, in visiting order.
        final List<String> pathFeatures = new ArrayList<>();
        Node<Regressor> lastVisited = tree.getValue();
        for (Node<Regressor> cursor = lastVisited; cursor != null; cursor = cursor.getNextNode(features)) {
            lastVisited = cursor;
            if (cursor instanceof SplitNode) {
                final SplitNode<?> split = (SplitNode<?>) cursor;
                pathFeatures.add(featureIDMap.get(split.getFeatureID()).getName());
            }
        }
        // The walk stops at the node with no successor, which must be a LeafNode.
        final LeafNode<Regressor> leaf = (LeafNode<Regressor>) lastVisited;
        leafPredictions.add(leaf.getPrediction(features.numActiveElements(), example));
        final int pathLength = pathFeatures.size();
        final List<Pair<String, Double>> scored = new ArrayList<>();
        for (int position = 0; position < pathLength; position++) {
            // Scores count down from pathLength + 1 for the first split on the path.
            final double score = (pathLength + 1) - position;
            scored.add(new Pair<>(pathFeatures.get(position), score));
        }
        scoresPerTree.put(tree.getKey(), scored);
    }
    final Prediction<Regressor> merged = combine(leafPredictions);
    return Optional.of(new Excuse<>(example, merged, scoresPerTree));
}
Also used : HashMap(java.util.HashMap) Prediction(org.tribuo.Prediction) SplitNode(org.tribuo.common.tree.SplitNode) LeafNode(org.tribuo.common.tree.LeafNode) Node(org.tribuo.common.tree.Node) ArrayList(java.util.ArrayList) SparseVector(org.tribuo.math.la.SparseVector) SplitNode(org.tribuo.common.tree.SplitNode) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) List(java.util.List) Regressor(org.tribuo.regression.Regressor) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) HashMap(java.util.HashMap) Map(java.util.Map) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 3 with LeafNode

use of org.tribuo.common.tree.LeafNode in project tribuo by oracle.

the class RegressorTrainingNode method createLeaf.

/**
 * Builds a {@link LeafNode} whose prediction is the weighted mean and variance of the
 * targets of the examples assigned to it.
 * <p>
 * Both statistics are accumulated in a single pass using an incremental weighted update.
 * The variance is divided by {@code (sum of weights - 1)} when the leaf holds more than
 * one example and is set to zero otherwise.
 * @param impurityScore the impurity score for the node.
 * @param leafIndices the indices of the examples to be placed in the node.
 * @return A {@link LeafNode}
 */
private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) {
    double runningMean = 0.0;
    double squaredDeviation = 0.0;
    double weightTotal = 0.0;
    for (int exampleIdx : leafIndices) {
        final float target = targets[exampleIdx];
        final float exampleWeight = weights[exampleIdx];
        weightTotal += exampleWeight;
        final double previousMean = runningMean;
        final double delta = target - previousMean;
        runningMean = previousMean + (exampleWeight / weightTotal) * delta;
        // Uses the deviation from both the old and the updated mean.
        squaredDeviation += exampleWeight * delta * (target - runningMean);
    }
    // NOTE(review): the divisor is (weightTotal - 1), which is zero or negative when the
    // weights sum to one or less - confirm whether fractional weights can reach here.
    final double variance;
    if (leafIndices.length > 1) {
        variance = squaredDeviation / (weightTotal - 1);
    } else {
        variance = 0;
    }
    final DimensionTuple prediction = new DimensionTuple(dimName, runningMean, variance);
    return new LeafNode<>(impurityScore, prediction, Collections.emptyMap(), false);
}
Also used : LeafNode(org.tribuo.common.tree.LeafNode) DimensionTuple(org.tribuo.regression.Regressor.DimensionTuple)

Example 4 with LeafNode

use of org.tribuo.common.tree.LeafNode in project tribuo by oracle.

the class IndependentRegressionTreeModel method predict.

/**
 * Predicts the output for the supplied example by descending every tree in
 * {@code roots} to its leaf and merging the per-tree leaf predictions with
 * {@code combine}.
 * @param example the example to predict.
 * @return the combined prediction across all trees.
 * @throws IllegalArgumentException if the example yields no active elements in the
 *         sparse vector.
 */
@Override
public Prediction<Regressor> predict(Example<Regressor> example) {
    // Building the sparse vector this way ensures we handle collisions correctly.
    final SparseVector features = SparseVector.createSparseVector(example, featureIDMap, false);
    if (features.numActiveElements() == 0) {
        throw new IllegalArgumentException("No features found in Example " + example.toString());
    }
    final List<Prediction<Regressor>> leafPredictions = new ArrayList<>();
    for (Map.Entry<String, Node<Regressor>> tree : roots.entrySet()) {
        Node<Regressor> lastVisited = tree.getValue();
        for (Node<Regressor> cursor = lastVisited; cursor != null; cursor = cursor.getNextNode(features)) {
            lastVisited = cursor;
        }
        // The walk stops at the node with no successor, which must be a LeafNode.
        final LeafNode<Regressor> leaf = (LeafNode<Regressor>) lastVisited;
        leafPredictions.add(leaf.getPrediction(features.numActiveElements(), example));
    }
    return combine(leafPredictions);
}
Also used : Prediction(org.tribuo.Prediction) SplitNode(org.tribuo.common.tree.SplitNode) LeafNode(org.tribuo.common.tree.LeafNode) Node(org.tribuo.common.tree.Node) ArrayList(java.util.ArrayList) Regressor(org.tribuo.regression.Regressor) SparseVector(org.tribuo.math.la.SparseVector) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) HashMap(java.util.HashMap) Map(java.util.Map)

Example 5 with LeafNode

use of org.tribuo.common.tree.LeafNode in project tribuo by oracle.

the class JointRegressorTrainingNode method createLeaf.

/**
 * Builds a {@link LeafNode} holding a joint prediction over every output dimension for
 * the examples assigned to it.
 * <p>
 * When {@code normalize} is set, the per-dimension weighted means are divided by their
 * sum so that they add up to 1.0 and no variances are supplied. Otherwise the weighted
 * means are used as-is together with per-dimension variances, each divided by
 * {@code (sum of weights - 1)} when the leaf holds more than one example and set to zero
 * otherwise.
 * @param impurityScore the impurity score for the node.
 * @param leafIndices the indices of the examples to be placed in the node.
 * @return A {@link LeafNode}
 */
private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) {
    double weightTotal = 0.0;
    final int numOutputs = targets.length;
    final double[] means = new double[numOutputs];
    final Regressor prediction;
    if (!normalize) {
        final double[] variances = new double[numOutputs];
        for (int exampleIdx : leafIndices) {
            final float exampleWeight = weights[exampleIdx];
            weightTotal += exampleWeight;
            // Same fraction for every dimension of this example.
            final double fraction = exampleWeight / weightTotal;
            for (int dim = 0; dim < numOutputs; dim++) {
                final float target = targets[dim][exampleIdx];
                final double previousMean = means[dim];
                final double delta = target - previousMean;
                means[dim] = previousMean + fraction * delta;
                variances[dim] += exampleWeight * delta * (target - means[dim]);
            }
        }
        final boolean multipleExamples = leafIndices.length > 1;
        final String[] outputNames = new String[numOutputs];
        for (int dim = 0; dim < numOutputs; dim++) {
            outputNames[dim] = labelIDMap.getOutput(dim).getNames()[0];
            if (multipleExamples) {
                variances[dim] = variances[dim] / (weightTotal - 1);
            } else {
                variances[dim] = 0;
            }
        }
        // Names, means and variances are all in id order, so the regressor constructor
        // will convert them to natural order if they are different.
        prediction = new Regressor(outputNames, means, variances);
    } else {
        for (int exampleIdx : leafIndices) {
            final float exampleWeight = weights[exampleIdx];
            weightTotal += exampleWeight;
            // Same fraction for every dimension of this example.
            final double fraction = exampleWeight / weightTotal;
            for (int dim = 0; dim < numOutputs; dim++) {
                final float target = targets[dim][exampleIdx];
                final double previousMean = means[dim];
                means[dim] = previousMean + fraction * (target - previousMean);
            }
        }
        final String[] outputNames = new String[numOutputs];
        double meanTotal = 0.0;
        for (int dim = 0; dim < numOutputs; dim++) {
            outputNames[dim] = labelIDMap.getOutput(dim).getNames()[0];
            meanTotal += means[dim];
        }
        // Rescale the outputs so that they sum to 1.0.
        for (int dim = 0; dim < numOutputs; dim++) {
            means[dim] /= meanTotal;
        }
        // Names and means are both in id order, so the regressor constructor
        // will convert them to natural order if they are different.
        prediction = new Regressor(outputNames, means);
    }
    return new LeafNode<>(impurityScore, prediction, Collections.emptyMap(), false);
}
Also used : LeafNode(org.tribuo.common.tree.LeafNode) Regressor(org.tribuo.regression.Regressor)

Aggregations

LeafNode (org.tribuo.common.tree.LeafNode): 5, Regressor (org.tribuo.regression.Regressor): 3, ArrayList (java.util.ArrayList): 2, HashMap (java.util.HashMap): 2, Map (java.util.Map): 2, ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 2, Prediction (org.tribuo.Prediction): 2, Node (org.tribuo.common.tree.Node): 2, SplitNode (org.tribuo.common.tree.SplitNode): 2, SparseVector (org.tribuo.math.la.SparseVector): 2, Pair (com.oracle.labs.mlrg.olcut.util.Pair): 1, LinkedHashMap (java.util.LinkedHashMap): 1, LinkedList (java.util.LinkedList): 1, List (java.util.List): 1, Label (org.tribuo.classification.Label): 1, DimensionTuple (org.tribuo.regression.Regressor.DimensionTuple): 1