Use of org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition in project ignite by apache.
In the class ImpurityHistogramsComputer, method aggregateImpurityStatisticsOnPartition.
/**
 * Aggregates impurity statistics on a single dataset partition for the nodes listed in {@code part}.
 *
 * <p>For every vector of the dataset and every sample index whose counter is non-zero, the tree stored at that
 * index in {@code roots} predicts the next node key for the vector's features. If that key is present in
 * {@code part}, the vector is added to the impurity computer of each feature used by the tree; the computer
 * for a feature is created the first time that feature is seen for the node.
 *
 * @param dataset Dataset.
 * @param roots Trees, indexed by the same sample id as the vector counters.
 * @param histMeta Histogram buckets meta, keyed by feature id.
 * @param part Partition: the nodes for which statistics are collected.
 * @return Per-node impurity histograms; holds an entry for every key of {@code part}.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(
    BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots,
    Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {
    // Pre-create an (initially empty) histogram holder for every node of interest.
    Map<NodeId, NodeImpurityHistograms<S>> res = part.keySet().stream()
        .collect(Collectors.toMap(n -> n, NodeImpurityHistograms::new));

    dataset.forEach(vector -> {
        for (int sampleId = 0; sampleId < vector.counters().length; sampleId++) {
            // A zero counter means the vector does not contribute to this sample.
            if (vector.counters()[sampleId] == 0)
                continue;

            RandomForestTreeModel root = roots.get(sampleId);
            NodeId key = root.getRootNode().predictNextNodeKey(vector.features());

            // The predicted node may be missing if we didn't take all nodes from the learning queue.
            if (!part.containsKey(key))
                continue;

            // Lambdas may only capture effectively final locals, so copy the loop index.
            final int treeIdx = sampleId;
            NodeImpurityHistograms<S> statistics = res.get(key);

            for (Integer featureId : root.getUsedFeatures()) {
                // Single lookup: create the computer (and fetch its bucket meta) only on first use.
                S impurityComputer = statistics.perFeatureStatistics.computeIfAbsent(
                    featureId, id -> createImpurityComputerForFeature(treeIdx, histMeta.get(id)));

                impurityComputer.addElement(vector);
            }
        }
    });

    return res;
}
Aggregations