Use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache:
class RandomForestTrainer, method initTrees.
/**
 * Builds one tree model per queued root node, assigning each tree its own
 * randomly drawn feature subspace of size {@code featuresPerTree}.
 *
 * @param treesQueue Queue of root nodes, one per tree.
 * @return List of tree models paired with their feature subspaces.
 */
protected ArrayList<RandomForestTreeModel> initTrees(Queue<TreeNode> treesQueue) {
    assert featuresPerTree > 0;

    ArrayList<RandomForestTreeModel> roots = new ArrayList<>(treesQueue.size());

    List<Integer> allFeatureIds = IntStream.range(0, meta.size())
        .boxed()
        .collect(Collectors.toList());

    for (TreeNode root : treesQueue) {
        // Shuffling and taking a prefix draws a uniform random feature subset.
        Collections.shuffle(allFeatureIds, random);

        Set<Integer> subspace = IntStream.range(0, Math.min(featuresPerTree, allFeatureIds.size()))
            .mapToObj(allFeatureIds::get)
            .collect(Collectors.toSet());

        roots.add(new RandomForestTreeModel(root, subspace));
    }

    return roots;
}
Use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache:
class RandomForestTrainer, method fit.
/**
 * Trains the forest on the specified dataset: initializes tree roots with their
 * feature subspaces, then iteratively splits corner nodes until the learning
 * queue is drained, and finally assigns values to the leaves.
 *
 * @param dataset Bootstrapped dataset.
 * @return List of trained decision trees; an empty list if histogram meta could not be computed.
 */
private List<RandomForestTreeModel> fit(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
    Queue<TreeNode> treesQueue = createRootsQueue();
    ArrayList<RandomForestTreeModel> roots = initTrees(treesQueue);

    Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, dataset);
    if (histMeta.isEmpty())
        return Collections.emptyList();

    ImpurityHistogramsComputer<S> histogramsComputer = createImpurityHistogramsComputer();

    while (!treesQueue.isEmpty()) {
        Map<NodeId, TreeNode> nodesToLearn = getNodesToLearn(treesQueue);

        Map<NodeId, ImpurityHistogramsComputer.NodeImpurityHistograms<S>> nodesImpHists =
            histogramsComputer.aggregateImpurityStatistics(roots, histMeta, nodesToLearn, dataset);

        // Every node scheduled for learning must have produced impurity statistics.
        if (nodesToLearn.size() != nodesImpHists.size())
            throw new IllegalStateException("Impurity statistics count [" + nodesImpHists.size() +
                "] does not match nodes-to-learn count [" + nodesToLearn.size() + "]");

        // Iterate over values directly instead of keySet() + get() to avoid redundant lookups.
        for (ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeHist : nodesImpHists.values())
            split(treesQueue, nodesToLearn, nodeHist);
    }

    createLeafStatisticsAggregator().setValuesForLeaves(roots, dataset);

    return roots;
}
Use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache:
class RandomForestTrainer, method split.
/**
 * Splits the corner node identified by the histograms' node id if the best found
 * split is good enough; otherwise converts the node to a leaf.
 *
 * @param learningQueue Learning queue; children of a split node are enqueued here.
 * @param nodesToLearn Nodes to learn at the current iteration, keyed by node id.
 * @param nodeImpurityHistograms Impurity histograms for the node on the current iteration.
 */
private void split(Queue<TreeNode> learningQueue, Map<NodeId, TreeNode> nodesToLearn,
    ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeImpurityHistograms) {
    TreeNode cornerNode = nodesToLearn.get(nodeImpurityHistograms.getNodeId());
    Optional<NodeSplit> bestSplit = nodeImpurityHistograms.findBestSplit();

    if (needSplit(cornerNode, bestSplit)) {
        // Split the node and schedule its children for learning.
        learningQueue.addAll(bestSplit.get().split(cornerNode));
        return;
    }

    if (bestSplit.isPresent()) {
        bestSplit.get().createLeaf(cornerNode);
        return;
    }

    // No split was found at all — finalize the node as a degenerate leaf.
    cornerNode.setImpurity(Double.NEGATIVE_INFINITY);
    cornerNode.toLeaf(0.0);
}
Use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache:
class RandomForestTest, method testNeedSplit.
/**
 * Verifies {@code needSplit}: a root node is split only when the best split reduces
 * impurity by more than {@code minImpDelta}, and a child node (depth beyond the
 * trainer's limit in this test setup) is never split even for a sufficient gain.
 */
@Test
public void testNeedSplit() {
    TreeNode node = new TreeNode(1, 1);
    node.setImpurity(1000);
    // Impurity drop strictly greater than minImpDelta -> split is needed.
    assertTrue(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity() - minImpDelta * 1.01))));
    // Drop smaller than minImpDelta -> not worth splitting.
    assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity() - minImpDelta * 0.5))));
    // No impurity improvement at all -> no split.
    assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity()))));
    TreeNode child = node.toConditional(0, 0).get(0);
    child.setImpurity(1000);
    // Same sufficient gain is rejected for the child node — presumably because the
    // depth limit is reached; TODO confirm against the trainer's configuration in setup.
    assertFalse(rf.needSplit(child, Optional.of(new NodeSplit(0, 0, child.getImpurity() - minImpDelta * 1.01))));
}
Use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache:
class ImpurityHistogramsComputer, method aggregateImpurityStatisticsOnPartition.
/**
 * Aggregates statistics for impurity computation over one dataset partition.
 * For each vector and each tree whose bootstrap sample contains it, the current
 * corner node is predicted and the vector is added to that node's per-feature
 * histograms (restricted to the tree's feature subspace).
 *
 * @param dataset Dataset partition with bootstrapped sample counters.
 * @param roots Trees, indexed by sample id.
 * @param histMeta Histogram buckets meta, keyed by feature id.
 * @param part Nodes to learn at the current iteration, keyed by node id.
 * @return Per-node impurity histograms for the nodes present in {@code part}.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(
    BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots,
    Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {

    Map<NodeId, NodeImpurityHistograms<S>> res = part.keySet().stream()
        .collect(Collectors.toMap(n -> n, NodeImpurityHistograms::new));

    dataset.forEach(vector -> {
        for (int sampleId = 0; sampleId < vector.counters().length; sampleId++) {
            // Vector is not in the bootstrap sample of this tree.
            if (vector.counters()[sampleId] == 0)
                continue;

            RandomForestTreeModel root = roots.get(sampleId);
            NodeId key = root.getRootNode().predictNextNodeKey(vector.features());

            // Skip nodes that weren't taken from the learning queue this iteration.
            if (!part.containsKey(key))
                continue;

            NodeImpurityHistograms<S> statistics = res.get(key);

            // Effectively-final copy of the loop variable for capture in the lambda below.
            int smpl = sampleId;

            for (Integer featureId : root.getUsedFeatures()) {
                BucketMeta meta = histMeta.get(featureId);

                // computeIfAbsent replaces the original containsKey + put pair:
                // lazily create the impurity computer for this feature.
                S impurityComputer = statistics.perFeatureStatistics
                    .computeIfAbsent(featureId, id -> createImpurityComputerForFeature(smpl, meta));

                impurityComputer.addElement(vector);
            }
        }
    });

    return res;
}
Aggregations