Search in sources :

Example 1 with BootstrappedDatasetPartition

Use of `org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition` in the Apache Ignite project.

Used in method `aggregateImpurityStatisticsOnPartition` of class `ImpurityHistogramsComputer`:

/**
 * Accumulates impurity histograms for the nodes currently being learned, using the vectors of one
 * bootstrapped partition.
 *
 * <p>Each vector is checked against every tree of the forest:
 * <ul>
 *   <li>It is skipped when its bootstrap counter for that tree is zero.</li>
 *   <li>Otherwise it is routed from the tree's root to the node returned by
 *       {@code predictNextNodeKey}.</li>
 *   <li>If that node is one of the nodes in {@code part}, the vector is added to the node's impurity
 *       computer for each feature the tree uses. The computer for a feature is created on first use via
 *       {@code createImpurityComputerForFeature}.</li>
 * </ul>
 *
 * @param dataset Bootstrapped partition whose vectors are aggregated.
 * @param roots Trees of the forest; the tree at index {@code i} is paired with bootstrap counter {@code i}.
 * @param histMeta Bucket meta per feature id, passed to newly created impurity computers.
 * @param part Nodes to collect statistics for, keyed by node id.
 * @return Impurity histograms for every node id in {@code part}.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(
    BootstrappedDatasetPartition dataset,
    ArrayList<RandomForestTreeModel> roots,
    Map<Integer, BucketMeta> histMeta,
    Map<NodeId, TreeNode> part) {

    // One initially empty histogram holder per node that is being learned in this step.
    Map<NodeId, NodeImpurityHistograms<S>> histsByNode = part.keySet().stream()
        .collect(Collectors.toMap(nodeId -> nodeId, NodeImpurityHistograms::new));

    dataset.forEach(vec -> {
        for (int treeIdx = 0; treeIdx < vec.counters().length; treeIdx++) {
            // A zero counter means the vector does not contribute to this tree's bootstrap sample.
            if (vec.counters()[treeIdx] != 0) {
                RandomForestTreeModel tree = roots.get(treeIdx);
                NodeId nodeKey = tree.getRootNode().predictNextNodeKey(vec.features());

                // The predicted node may be missing from part when not every node of the learning
                // queue was taken for this step; such vectors are ignored here.
                if (part.containsKey(nodeKey)) {
                    NodeImpurityHistograms<S> nodeHists = histsByNode.get(nodeKey);

                    for (Integer featureId : tree.getUsedFeatures()) {
                        BucketMeta bucketMeta = histMeta.get(featureId);

                        // Lazily create the impurity computer for this (node, feature) pair.
                        if (!nodeHists.perFeatureStatistics.containsKey(featureId)) {
                            S created = createImpurityComputerForFeature(treeIdx, bucketMeta);
                            nodeHists.perFeatureStatistics.put(featureId, created);
                        }

                        nodeHists.perFeatureStatistics.get(featureId).addElement(vec);
                    }
                }
            }
        }
    });

    return histsByNode;
}
Also used: TreeNode (org.apache.ignite.ml.tree.randomforest.data.TreeNode), BootstrappedVector (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector), NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit), HashMap (java.util.HashMap), BootstrappedDatasetPartition (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition), Collectors (java.util.stream.Collectors), NodeId (org.apache.ignite.ml.tree.randomforest.data.NodeId), Serializable (java.io.Serializable), ArrayList (java.util.ArrayList), BucketMeta (org.apache.ignite.ml.dataset.feature.BucketMeta), Stream (java.util.stream.Stream), Dataset (org.apache.ignite.ml.dataset.Dataset), Map (java.util.Map), Optional (java.util.Optional), RandomForestTreeModel (org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel), Comparator (java.util.Comparator), EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)

Aggregations (number of usages per type)

Serializable (java.io.Serializable): 1; ArrayList (java.util.ArrayList): 1; Comparator (java.util.Comparator): 1; HashMap (java.util.HashMap): 1; Map (java.util.Map): 1; Optional (java.util.Optional): 1; Collectors (java.util.stream.Collectors): 1; Stream (java.util.stream.Stream): 1; Dataset (org.apache.ignite.ml.dataset.Dataset): 1; BucketMeta (org.apache.ignite.ml.dataset.feature.BucketMeta): 1; BootstrappedDatasetPartition (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition): 1; BootstrappedVector (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector): 1; EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext): 1; NodeId (org.apache.ignite.ml.tree.randomforest.data.NodeId): 1; NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit): 1; RandomForestTreeModel (org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel): 1; TreeNode (org.apache.ignite.ml.tree.randomforest.data.TreeNode): 1