Usage of org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel in project Ignite by Apache.
In class RandomForestTrainer, method initTrees:
/**
 * Builds one tree model per root node found in {@code treesQueue}.
 * <p>
 * Each tree receives its own random feature subspace. Before each tree is created, the list of all feature ids
 * {@code [0, meta.size())} is reshuffled with {@code random}. The first (at most) {@code featuresPerTree} ids of the
 * shuffled list become that tree's subspace.
 *
 * @param treesQueue Queue of root nodes, one per tree.
 * @return Tree models in the iteration order of {@code treesQueue}.
 */
protected ArrayList<RandomForestTreeModel> initTrees(Queue<TreeNode> treesQueue) {
    assert featuresPerTree > 0;

    int featuresCnt = meta.size();
    List<Integer> featureIds = new ArrayList<>(featuresCnt);
    for (int id = 0; id < featuresCnt; id++)
        featureIds.add(id);

    ArrayList<RandomForestTreeModel> models = new ArrayList<>();
    treesQueue.forEach(rootNode -> {
        // Reshuffle for every tree so each tree gets an independent subspace.
        Collections.shuffle(featureIds, random);
        Set<Integer> subspace = featureIds.stream()
            .limit(featuresPerTree)
            .collect(Collectors.toSet());
        models.add(new RandomForestTreeModel(rootNode, subspace));
    });

    return models;
}
Usage of org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel in project Ignite by Apache.
In class RandomForestTrainer, method fit:
/**
 * Trains the random forest on the specified data.
 * <p>
 * The method first creates the tree roots. It then repeats the following steps until the learning queue is empty:
 * <ol>
 *   <li>Take a batch of nodes from the queue.</li>
 *   <li>Aggregate impurity histograms for those nodes over the dataset.</li>
 *   <li>Split each node. {@code split} receives the queue, presumably so it can enqueue child nodes.</li>
 * </ol>
 * Finally, the method sets the values of the leaves.
 *
 * @param dataset Dataset.
 * @return List of decision trees, or an empty list if the computed histogram meta is empty.
 * @throws IllegalStateException If histograms were not aggregated for exactly the set of nodes selected for learning.
 */
private List<RandomForestTreeModel> fit(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
    Queue<TreeNode> treesQueue = createRootsQueue();
    ArrayList<RandomForestTreeModel> roots = initTrees(treesQueue);
    Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, dataset);
    if (histMeta.isEmpty())
        return Collections.emptyList();

    ImpurityHistogramsComputer<S> histogramsComputer = createImpurityHistogramsComputer();
    while (!treesQueue.isEmpty()) {
        Map<NodeId, TreeNode> nodesToLearn = getNodesToLearn(treesQueue);
        Map<NodeId, ImpurityHistogramsComputer.NodeImpurityHistograms<S>> nodesImpHists = histogramsComputer
            .aggregateImpurityStatistics(roots, histMeta, nodesToLearn, dataset);

        if (nodesToLearn.size() != nodesImpHists.size()) {
            throw new IllegalStateException("Impurity histograms were aggregated for " + nodesImpHists.size()
                + " node(s), but " + nodesToLearn.size() + " node(s) were selected for learning");
        }

        // Iterate values directly: same order as keySet(), without a second lookup per node.
        for (ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeImpHists : nodesImpHists.values())
            split(treesQueue, nodesToLearn, nodeImpHists);
    }

    createLeafStatisticsAggregator().setValuesForLeaves(roots, dataset);
    return roots;
}
Usage of org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel in project Ignite by Apache.
In class ImpurityHistogramsComputer, method aggregateImpurityStatisticsOnPartition:
/**
 * Aggregates, on a single dataset partition, the statistics needed to compute impurity for each node being learned,
 * across all trees of the random forest.
 * <p>
 * For each vector, and for each tree in which the vector's counter is non-zero, the vector is routed from that
 * tree's root with {@code predictNextNodeKey}. If the resulting node is one of the nodes being learned, the vector
 * is added to that node's per-feature impurity computer for every feature used by the tree.
 *
 * @param dataset Dataset partition whose vectors are aggregated.
 * @param roots Trees; the tree at index {@code i} corresponds to counter {@code i} of each vector.
 * @param histMeta Histogram buckets meta, keyed by feature id.
 * @param part Nodes to learn in this iteration, keyed by node id.
 * @return Impurity histograms for every node id in {@code part}.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(
    BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots,
    Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {

    Map<NodeId, NodeImpurityHistograms<S>> res = part.keySet().stream()
        .collect(Collectors.toMap(n -> n, NodeImpurityHistograms::new));

    dataset.forEach(vector -> {
        int[] counters = vector.counters();
        for (int sampleId = 0; sampleId < counters.length; sampleId++) {
            // A zero counter means the vector is not part of this tree's sample.
            if (counters[sampleId] == 0)
                continue;

            RandomForestTreeModel root = roots.get(sampleId);
            NodeId key = root.getRootNode().predictNextNodeKey(vector.features());

            // The vector lands in a node that is not learned in this iteration
            // (not all nodes were taken from the learning queue).
            if (!part.containsKey(key))
                continue;

            NodeImpurityHistograms<S> statistics = res.get(key);
            for (Integer featureId : root.getUsedFeatures()) {
                // Single lookup on the hot path; create the computer lazily on first use.
                S impurityComputer = statistics.perFeatureStatistics.get(featureId);
                if (impurityComputer == null) {
                    impurityComputer = createImpurityComputerForFeature(sampleId, histMeta.get(featureId));
                    statistics.perFeatureStatistics.put(featureId, impurityComputer);
                }

                impurityComputer.addElement(vector);
            }
        }
    });

    return res;
}
Aggregations