Use of org.tribuo.common.tree.LeafNode in the Tribuo project by Oracle.
In class ClassifierTrainingNode, method createLeaf:
/**
 * Makes a classification {@link LeafNode}.
 * <p>
 * The weighted counts (indexed by label id) are passed through
 * {@code Util.normalizeToDistribution}, and every label is wrapped in a {@link Label}
 * carrying its normalized score. The leaf's prediction is the label with the highest
 * normalized score; on ties the label with the lowest id wins. If no score exceeds
 * negative infinity (e.g. empty counts), the prediction handed to the leaf is null.
 * @param impurityScore the impurity score for the node.
 * @param weightedCounts the weighted label counts of the data in the node, indexed by label id.
 * @return a {@link LeafNode}
 */
private LeafNode<Label> createLeaf(double impurityScore, float[] weightedCounts) {
    double[] distribution = Util.normalizeToDistribution(weightedCounts);
    Map<String, Label> scoresByName = new LinkedHashMap<>();
    Label bestLabel = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (int labelId = 0; labelId < weightedCounts.length; labelId++) {
        double score = distribution[labelId];
        String labelName = labelIDMap.getOutput(labelId).getLabel();
        Label scored = new Label(labelName, score);
        // Insertion order of the LinkedHashMap follows label id order.
        scoresByName.put(labelName, scored);
        // Strict comparison keeps the first (lowest id) label on ties.
        if (score > bestScore) {
            bestLabel = scored;
            bestScore = score;
        }
    }
    return new LeafNode<>(impurityScore, bestLabel, scoresByName, true);
}
Use of org.tribuo.common.tree.LeafNode in the Tribuo project by Oracle.
In class IndependentRegressionTreeModel, method getExcuse:
/**
 * Explains a prediction by recording, for each tree in {@code roots}, the features
 * tested by the split nodes on the path from that tree's root to the leaf the
 * example reaches.
 * <p>
 * Features nearer the root get larger scores: the first split on a path of length
 * {@code k} scores {@code k + 1}, and each deeper split scores one less. The leaf
 * predictions of all trees are merged with {@code combine}.
 * @param example the example to explain.
 * @return the excuse, or {@link Optional#empty()} if the example has no features
 *         active in this model's feature map.
 */
@Override
public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
    // Ensures we handle collisions correctly
    SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
    if (vec.numActiveElements() == 0) {
        return Optional.empty();
    }
    List<Prediction<Regressor>> leafPredictions = new ArrayList<>();
    Map<String, List<Pair<String, Double>>> explanation = new HashMap<>();
    for (Map.Entry<String, Node<Regressor>> rootEntry : roots.entrySet()) {
        List<String> pathFeatures = new ArrayList<>();
        Node<Regressor> leaf = rootEntry.getValue();
        for (Node<Regressor> cursor = leaf; cursor != null; cursor = cursor.getNextNode(vec)) {
            leaf = cursor;
            if (cursor instanceof SplitNode) {
                pathFeatures.add(featureIDMap.get(((SplitNode<?>) cursor).getFeatureID()).getName());
            }
        }
        // The last non-null node visited must be a LeafNode.
        leafPredictions.add(((LeafNode<Regressor>) leaf).getPrediction(vec.numActiveElements(), example));
        int pathLength = pathFeatures.size();
        List<Pair<String, Double>> ranked = new ArrayList<>(pathLength);
        for (int depth = 0; depth < pathLength; depth++) {
            ranked.add(new Pair<>(pathFeatures.get(depth), (double) (pathLength + 1 - depth)));
        }
        explanation.put(rootEntry.getKey(), ranked);
    }
    Prediction<Regressor> combinedPrediction = combine(leafPredictions);
    return Optional.of(new Excuse<>(example, combinedPrediction, explanation));
}
Use of org.tribuo.common.tree.LeafNode in the Tribuo project by Oracle.
In class RegressorTrainingNode, method createLeaf:
/**
 * Makes a regression {@link LeafNode} whose prediction is the weighted mean and
 * weighted variance of the targets of the examples placed in the node.
 * <p>
 * Both statistics are accumulated in a single pass using the weighted incremental
 * (West) update. The variance is divided by {@code leafWeightSum - 1}, i.e. the
 * weights are treated as frequency weights. When the total weight is not greater
 * than one, that divisor would be zero or negative. The variance is then reported
 * as 0 rather than as an infinite, NaN or negative value.
 * @param impurityScore the impurity score for the node.
 * @param leafIndices the indices of the examples to be placed in the node.
 * @return A {@link LeafNode}
 */
private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) {
    double mean = 0.0;
    double leafWeightSum = 0.0;
    double variance = 0.0;
    for (int idx : leafIndices) {
        float weight = weights[idx];
        if (weight == 0.0f) {
            // A zero-weight example contributes nothing to either statistic, but if it
            // came first, weight / leafWeightSum would be 0/0 = NaN and poison the mean.
            continue;
        }
        float value = targets[idx];
        leafWeightSum += weight;
        double oldMean = mean;
        mean += (weight / leafWeightSum) * (value - oldMean);
        variance += weight * (value - oldMean) * (value - mean);
    }
    // A sample variance needs more than one example and a strictly positive divisor.
    boolean hasSampleVariance = leafIndices.length > 1 && leafWeightSum > 1.0;
    variance = hasSampleVariance ? variance / (leafWeightSum - 1) : 0;
    DimensionTuple leafPred = new DimensionTuple(dimName, mean, variance);
    return new LeafNode<>(impurityScore, leafPred, Collections.emptyMap(), false);
}
Use of org.tribuo.common.tree.LeafNode in the Tribuo project by Oracle.
In class IndependentRegressionTreeModel, method predict:
/**
 * Predicts by walking every tree in {@code roots} down to the leaf the example
 * reaches, then merging the leaf predictions with {@code combine}.
 * @param example the example to predict.
 * @return the combined prediction.
 * @throws IllegalArgumentException if the example has no features active in this
 *         model's feature map.
 */
@Override
public Prediction<Regressor> predict(Example<Regressor> example) {
    //
    // Ensures we handle collisions correctly
    SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
    int numActive = vec.numActiveElements();
    if (numActive == 0) {
        throw new IllegalArgumentException("No features found in Example " + example.toString());
    }
    List<Prediction<Regressor>> leafPredictions = new ArrayList<>(roots.size());
    for (Node<Regressor> root : roots.values()) {
        Node<Regressor> leaf = root;
        for (Node<Regressor> cursor = root; cursor != null; cursor = cursor.getNextNode(vec)) {
            leaf = cursor;
        }
        // The last non-null node visited must be a LeafNode.
        leafPredictions.add(((LeafNode<Regressor>) leaf).getPrediction(numActive, example));
    }
    return combine(leafPredictions);
}
Use of org.tribuo.common.tree.LeafNode in the Tribuo project by Oracle.
In class JointRegressorTrainingNode, method createLeaf:
/**
 * Makes a multi-dimensional regression {@link LeafNode} from the examples placed in
 * the node.
 * <p>
 * For every output dimension ({@code targets} is indexed {@code [dimension][example]}),
 * the weighted mean and weighted variance are accumulated in one pass with the weighted
 * incremental (West) update. Zero-weight examples are skipped; they cannot change
 * either statistic, and a leading one would otherwise make the mean NaN.
 * <ul>
 * <li>If {@code normalize} is set, the means are rescaled to sum to 1.0 and no
 * variance is passed to the {@link Regressor}. If the means sum to exactly zero,
 * they are left unscaled instead of being divided by zero (which would make every
 * output NaN).</li>
 * <li>Otherwise each variance is divided by {@code leafWeightSum - 1}, i.e. the
 * weights are treated as frequency weights. When that divisor would be zero or
 * negative (total weight not greater than one), the variance is reported as 0.</li>
 * </ul>
 * @param impurityScore the impurity score for the node.
 * @param leafIndices the indices of the examples to be placed in the node.
 * @return A {@link LeafNode}
 */
private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) {
    int numDims = targets.length;
    double[] mean = new double[numDims];
    double[] variance = new double[numDims];
    double leafWeightSum = 0.0;
    for (int idx : leafIndices) {
        float weight = weights[idx];
        if (weight == 0.0f) {
            // Contributes nothing, but as the first example it would give 0/0 = NaN below.
            continue;
        }
        leafWeightSum += weight;
        double ratio = weight / leafWeightSum;
        for (int j = 0; j < numDims; j++) {
            float value = targets[j][idx];
            double oldMean = mean[j];
            mean[j] += ratio * (value - oldMean);
            variance[j] += weight * (value - oldMean) * (value - mean[j]);
        }
    }
    String[] names = new String[numDims];
    for (int i = 0; i < numDims; i++) {
        names[i] = labelIDMap.getOutput(i).getNames()[0];
    }
    Regressor leafPred;
    if (normalize) {
        double sum = 0.0;
        for (int i = 0; i < numDims; i++) {
            sum += mean[i];
        }
        // Normalize all the outputs so that they sum to 1.0, unless that would divide by zero.
        if (sum != 0.0) {
            for (int i = 0; i < numDims; i++) {
                mean[i] /= sum;
            }
        }
        // Both names and mean are in id order, so the regressor constructor
        // will convert them to natural order if they are different.
        leafPred = new Regressor(names, mean);
    } else {
        // A sample variance needs more than one example and a strictly positive divisor.
        boolean hasSampleVariance = leafIndices.length > 1 && leafWeightSum > 1.0;
        for (int i = 0; i < numDims; i++) {
            variance[i] = hasSampleVariance ? variance[i] / (leafWeightSum - 1) : 0;
        }
        // Both names, mean and variance are in id order, so the regressor constructor
        // will convert them to natural order if they are different.
        leafPred = new Regressor(names, mean, variance);
    }
    return new LeafNode<>(impurityScore, leafPred, Collections.emptyMap(), false);
}
Aggregations