Use of org.apache.ignite.ml.tree.randomforest.data.NodeId in project ignite by apache.
The class LeafValuesComputer, method mergeLeafStatistics.
/**
 * Merges statistics on labels from several partitions.
 *
 * <p>The {@code left} map is mutated in place and returned when both arguments are non-null.
 *
 * @param left First partition (may be {@code null}).
 * @param right Second partition (may be {@code null}).
 * @return Merged statistics.
 */
private Map<NodeId, T> mergeLeafStatistics(Map<NodeId, T> left, Map<NodeId, T> right) {
    // A null side means the partition produced no statistics - the other side wins as-is.
    if (left == null)
        return right;
    if (right == null)
        return left;

    // Fold the right map into the left one: keys present only in "right" are copied,
    // keys present on both sides are combined via mergeLeafStats.
    for (Map.Entry<NodeId, T> entry : right.entrySet())
        left.merge(entry.getKey(), entry.getValue(), (l, r) -> mergeLeafStats(l, r));

    return left;
}
Use of org.apache.ignite.ml.tree.randomforest.data.NodeId in project ignite by apache.
The class RandomForestTrainer, method fit.
/**
 * Trains model based on the specified data.
 *
 * @param dataset Dataset.
 * @return List of learned decision trees.
 */
private List<RandomForestTreeModel> fit(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
    Queue<TreeNode> treesQueue = createRootsQueue();
    ArrayList<RandomForestTreeModel> roots = initTrees(treesQueue);

    Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, dataset);
    if (histMeta.isEmpty())
        return Collections.emptyList();

    ImpurityHistogramsComputer<S> histogramsComputer = createImpurityHistogramsComputer();

    // Level-by-level training: the queue holds the frontier of nodes still to be split.
    while (!treesQueue.isEmpty()) {
        Map<NodeId, TreeNode> nodesToLearn = getNodesToLearn(treesQueue);

        Map<NodeId, ImpurityHistogramsComputer.NodeImpurityHistograms<S>> nodesImpHists =
            histogramsComputer.aggregateImpurityStatistics(roots, histMeta, nodesToLearn, dataset);

        if (nodesToLearn.size() != nodesImpHists.size())
            throw new IllegalStateException("Impurity statistics were not computed for all nodes to learn " +
                "[expected=" + nodesToLearn.size() + ", actual=" + nodesImpHists.size() + "]");

        // Split every node of the current frontier; split(...) enqueues resulting children.
        for (ImpurityHistogramsComputer.NodeImpurityHistograms<S> impHists : nodesImpHists.values())
            split(treesQueue, nodesToLearn, impHists);
    }

    createLeafStatisticsAggregator().setValuesForLeaves(roots, dataset);

    return roots;
}
Use of org.apache.ignite.ml.tree.randomforest.data.NodeId in project ignite by apache.
The class ImpurityHistogramsComputer, method aggregateImpurityStatisticsOnPartition.
/**
 * Aggregates statistics for impurity computation for each corner node of each tree in the random forest.
 * For every learning vector the algorithm predicts the corner node in each decision tree and adds the
 * vector to the corresponding per-feature histograms.
 *
 * @param dataset Dataset partition.
 * @param roots Trees.
 * @param histMeta Histogram buckets meta.
 * @param part Nodes taken from the learning queue for the current step.
 * @return Node statistics for impurity computation.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(
    BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots,
    Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {

    // One empty histogram holder per node to learn.
    Map<NodeId, NodeImpurityHistograms<S>> res = part.keySet().stream()
        .collect(Collectors.toMap(n -> n, NodeImpurityHistograms::new));

    dataset.forEach(vector -> {
        for (int sampleId = 0; sampleId < vector.counters().length; sampleId++) {
            // Zero counter: vector is not part of the bootstrapped sample for this tree.
            if (vector.counters()[sampleId] == 0)
                continue;

            RandomForestTreeModel root = roots.get(sampleId);
            NodeId key = root.getRootNode().predictNextNodeKey(vector.features());

            // Skip nodes that were not taken from the learning queue on this step.
            if (!part.containsKey(key))
                continue;

            NodeImpurityHistograms<S> statistics = res.get(key);

            // Effectively-final copy of the loop variable for capture in the lambda below.
            final int sample = sampleId;

            for (Integer featureId : root.getUsedFeatures()) {
                BucketMeta meta = histMeta.get(featureId);

                S impurityComputer = statistics.perFeatureStatistics
                    .computeIfAbsent(featureId, id -> createImpurityComputerForFeature(sample, meta));

                impurityComputer.addElement(vector);
            }
        }
    });

    return res;
}
Use of org.apache.ignite.ml.tree.randomforest.data.NodeId in project ignite by apache.
The class LeafValuesComputer, method computeLeafsStatisticsInPartition.
/**
 * Aggregates statistics on labels from the learning dataset for each leaf node.
 *
 * @param roots Learned trees.
 * @param leafs All leaf nodes keyed by their ids.
 * @param data Data partition.
 * @return Statistics on labels for each leaf node.
 */
private Map<NodeId, T> computeLeafsStatisticsInPartition(ArrayList<RandomForestTreeModel> roots,
    Map<NodeId, TreeNode> leafs, BootstrappedDatasetPartition data) {

    Map<NodeId, T> res = new HashMap<>();

    for (int sampleId = 0; sampleId < roots.size(); sampleId++) {
        final int sampleIdConst = sampleId;

        data.forEach(vec -> {
            NodeId leafId = roots.get(sampleIdConst).getRootNode().predictNextNodeKey(vec.features());

            // Every vector must land in a known leaf: trees are fully built at this point.
            if (!leafs.containsKey(leafId))
                throw new IllegalStateException("Unexpected leaf node id [leafId=" + leafId + "]");

            T aggregator = res.computeIfAbsent(leafId, id -> createLeafStatsAggregator(sampleIdConst));

            addElementToLeafStatistic(aggregator, vec, sampleIdConst);
        });
    }

    return res;
}
Aggregations