use of org.tribuo.common.tree.Node in project tribuo by oracle.
the class CARTRegressionTrainer method train.
/**
 * Trains an independent regression tree for every element of the output domain
 * and bundles the resulting trees into a single model.
 *
 * @param examples        the training data; must not contain unknown outputs.
 * @param runProvenance   provenance recorded in the returned model.
 * @param invocationCount when different from {@code INCREMENT_INVOCATION_COUNT} it is
 *                        passed to {@code setInvocationCount} before the RNG is split.
 * @return a model holding one converted tree per output dimension name.
 * @throws IllegalArgumentException if the dataset reports any unknown outputs.
 */
@Override
public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }

    // Split off a private RNG, capture the provenance and bump the invocation
    // counter while holding the trainer's monitor.
    final SplittableRandom localRNG;
    final TrainerProvenance trainerProvenance;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }

    final ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
    final ImmutableOutputInfo<Regressor> outputIDInfo = examples.getOutputIDInfo();
    final Set<Regressor> outputDomain = outputIDInfo.getDomain();

    // Number of features offered to each split, capped at the total feature count.
    final int featureCount = featureIDMap.size();
    final int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * featureCount), featureCount);
    final boolean subsampleFeatures = numFeaturesInSplit != featureCount;

    // Identity permutation of the feature ids; reshuffled before each split when subsampling.
    final int[] allFeatureIndices = new int[featureCount];
    for (int featureIdx = 0; featureIdx < allFeatureIndices.length; featureIdx++) {
        allFeatureIndices[featureIdx] = featureIdx;
    }
    // When every feature is used the split indices alias the identity array.
    final int[] splitIndices = subsampleFeatures ? new int[numFeaturesInSplit] : allFeatureIndices;

    // The minimum impurity decrease is scaled by the total example weight.
    float totalWeight = 0.0f;
    for (Example<Regressor> example : examples) {
        totalWeight += example.getWeight();
    }
    final float scaledMinImpurityDecrease = getMinImpurityDecrease() * totalWeight;
    final AbstractTrainingNode.LeafDeterminer leafDeterminer =
            new AbstractTrainingNode.LeafDeterminer(maxDepth, minChildWeight, scaledMinImpurityDecrease);

    final InvertedData invertedData = RegressorTrainingNode.invertData(examples);

    final Map<String, Node<Regressor>> treesByDimension = new HashMap<>();
    for (Regressor dimension : outputDomain) {
        final String dimName = dimension.getNames()[0];
        final int dimIdx = outputIDInfo.getID(dimension);

        final AbstractTrainingNode<Regressor> root = new RegressorTrainingNode(
                impurity, invertedData, dimIdx, dimName, examples.size(), featureIDMap, outputIDInfo, leafDeterminer);

        // The deque is used as a stack: children are pushed on the front and
        // the next node is taken from the front.
        final Deque<AbstractTrainingNode<Regressor>> pending = new ArrayDeque<>();
        pending.push(root);

        AbstractTrainingNode<Regressor> current;
        while ((current = pending.pollFirst()) != null) {
            // Kept as a positive test (rather than inverted comparisons) so that
            // NaN impurity/weight values are treated exactly as before: not split.
            final boolean splittable = (current.getImpurity() > 0.0)
                    && (current.getDepth() < maxDepth)
                    && (current.getWeightSum() >= minChildWeight);
            if (!splittable) {
                continue;
            }

            if (subsampleFeatures) {
                // Draw a fresh random subset of features for this split.
                Util.randpermInPlace(allFeatureIndices, localRNG);
                System.arraycopy(allFeatureIndices, 0, splitIndices, 0, numFeaturesInSplit);
            }

            for (AbstractTrainingNode<Regressor> child : current.buildTree(splitIndices, localRNG, getUseRandomSplitPoints())) {
                pending.push(child);
            }
        }

        treesByDimension.put(dimName, root.convertTree());
    }

    final ModelProvenance provenance = new ModelProvenance(
            TreeModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new IndependentRegressionTreeModel("cart-tree", provenance, featureIDMap, outputIDInfo, false, treesByDimension);
}
use of org.tribuo.common.tree.Node in project tribuo by oracle.
the class IndependentRegressionTreeModel method getExcuse.
/**
 * Builds an excuse for the supplied example by walking each per-dimension tree
 * and recording the names of the features tested on the way to the leaf.
 * <p>
 * Within one tree, the first feature on the path is scored {@code pathLength + 1}
 * and each later feature is scored one less than its predecessor.
 *
 * @param example the example to explain.
 * @return an excuse pairing the combined prediction with the per-dimension feature
 *         scores, or {@link Optional#empty()} when the example has no active
 *         features in this model's feature map.
 */
@Override
public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
    SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
    if (vec.numActiveElements() == 0) {
        return Optional.empty();
    }

    Map<String, List<Pair<String, Double>>> scoresByDimension = new HashMap<>();
    List<Prediction<Regressor>> leafPredictions = new ArrayList<>();
    // Reused across trees; cleared at the start of every iteration.
    List<String> pathFeatures = new ArrayList<>();

    for (Map.Entry<String, Node<Regressor>> rootEntry : roots.entrySet()) {
        pathFeatures.clear();

        // Descend until getNextNode returns null, remembering the last node visited.
        Node<Regressor> lastVisited = rootEntry.getValue();
        for (Node<Regressor> cursor = lastVisited; cursor != null; cursor = cursor.getNextNode(vec)) {
            lastVisited = cursor;
            if (cursor instanceof SplitNode) {
                SplitNode<?> split = (SplitNode<?>) cursor;
                pathFeatures.add(featureIDMap.get(split.getFeatureID()).getName());
            }
        }

        // The last node visited must be a LeafNode.
        LeafNode<Regressor> leaf = (LeafNode<Regressor>) lastVisited;
        leafPredictions.add(leaf.getPrediction(vec.numActiveElements(), example));

        // Earlier splits receive larger scores.
        List<Pair<String, Double>> scored = new ArrayList<>();
        final int pathLength = pathFeatures.size();
        for (int position = 0; position < pathLength; position++) {
            double score = (pathLength + 1) - position;
            scored.add(new Pair<>(pathFeatures.get(position), score));
        }

        scoresByDimension.put(rootEntry.getKey(), scored);
    }

    Prediction<Regressor> combinedPrediction = combine(leafPredictions);
    return Optional.of(new Excuse<>(example, combinedPrediction, scoresByDimension));
}
use of org.tribuo.common.tree.Node in project tribuo by oracle.
the class IndependentRegressionTreeModel method predict.
/**
 * Predicts every output dimension for the example by walking each tree down to
 * a leaf, then merges the per-dimension leaf predictions with {@code combine}.
 *
 * @param example the example to predict.
 * @return the combined prediction across all trees.
 * @throws IllegalArgumentException if the example has no active features in this
 *                                  model's feature map.
 */
@Override
public Prediction<Regressor> predict(Example<Regressor> example) {
    // Original note on this step: "Ensures we handle collisions correctly".
    SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
    if (vec.numActiveElements() == 0) {
        throw new IllegalArgumentException("No features found in Example " + example.toString());
    }

    List<Prediction<Regressor>> leafPredictions = new ArrayList<>();
    for (Node<Regressor> root : roots.values()) {
        // Follow getNextNode until it returns null; the last node seen is the leaf.
        Node<Regressor> lastVisited = root;
        for (Node<Regressor> cursor = root; cursor != null; cursor = cursor.getNextNode(vec)) {
            lastVisited = cursor;
        }

        // The last node visited must be a LeafNode.
        LeafNode<Regressor> leaf = (LeafNode<Regressor>) lastVisited;
        leafPredictions.add(leaf.getPrediction(vec.numActiveElements(), example));
    }

    return combine(leafPredictions);
}
Aggregations