Search in sources :

Example 1 with NodeId

use of org.apache.ignite.ml.tree.randomforest.data.NodeId in project ignite by apache.

the class LeafValuesComputer method mergeLeafStatistics.

/**
 * Combines per-leaf label statistics computed on two partitions into one map.
 * <p>
 * When both arguments are non-null, the result is accumulated into {@code left}: the map is modified in place
 * and returned. Entries present in both maps are combined with {@code mergeLeafStats}.
 *
 * @param left Statistics of the first partition, may be {@code null}.
 * @param right Statistics of the second partition, may be {@code null}.
 * @return {@code right} if {@code left} is {@code null}, {@code left} if {@code right} is {@code null},
 * otherwise {@code left} updated with the entries of {@code right}.
 */
private Map<NodeId, T> mergeLeafStatistics(Map<NodeId, T> left, Map<NodeId, T> right) {
    if (left == null)
        return right;
    if (right == null)
        return left;

    Set<NodeId> allIds = new HashSet<>(left.keySet());
    allIds.addAll(right.keySet());

    for (NodeId id : allIds) {
        boolean inLeft = left.containsKey(id);

        // Present only on the left side: the value is already where it has to be.
        if (inLeft && !right.containsKey(id))
            continue;

        // Either the key is new for the left map (copy the right value) or it is shared (merge both values).
        T combined = inLeft ? mergeLeafStats(left.get(id), right.get(id)) : right.get(id);
        left.put(id, combined);
    }

    return left;
}
Also used : NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId) HashSet(java.util.HashSet)

Example 2 with NodeId

use of org.apache.ignite.ml.tree.randomforest.data.NodeId in project ignite by apache.

the class RandomForestTrainer method fit.

/**
 * Trains model based on the specified data.
 * <p>
 * While the queue of tree nodes is not empty, a batch of nodes is taken from it, impurity histograms are
 * aggregated for that batch over the dataset, and every aggregated histogram is passed to {@code split}
 * together with the queue. After the loop, values for leaves are set for all trees.
 *
 * @param dataset Dataset.
 * @return List of decision trees, or an empty list if the computed histogram meta is empty.
 * @throws IllegalStateException If the number of aggregated histograms differs from the number of nodes
 * selected for learning.
 */
private List<RandomForestTreeModel> fit(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
    Queue<TreeNode> treesQueue = createRootsQueue();
    ArrayList<RandomForestTreeModel> roots = initTrees(treesQueue);
    Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, dataset);
    if (histMeta.isEmpty())
        return Collections.emptyList();

    ImpurityHistogramsComputer<S> histogramsComputer = createImpurityHistogramsComputer();
    while (!treesQueue.isEmpty()) {
        Map<NodeId, TreeNode> nodesToLearn = getNodesToLearn(treesQueue);
        Map<NodeId, ImpurityHistogramsComputer.NodeImpurityHistograms<S>> nodesImpHists = histogramsComputer.aggregateImpurityStatistics(roots, histMeta, nodesToLearn, dataset);

        // Every node selected for learning must have received its histograms; report both sizes on mismatch.
        if (nodesToLearn.size() != nodesImpHists.size()) {
            throw new IllegalStateException("Impurity histograms were aggregated for " + nodesImpHists.size()
                + " node(s), but " + nodesToLearn.size() + " node(s) were selected for learning");
        }

        // Only the histograms are needed here, so iterate the values directly instead of looking up by key.
        for (ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeImpHists : nodesImpHists.values())
            split(treesQueue, nodesToLearn, nodeImpHists);
    }

    createLeafStatisticsAggregator().setValuesForLeaves(roots, dataset);
    return roots;
}
Also used : BucketMeta(org.apache.ignite.ml.dataset.feature.BucketMeta) TreeNode(org.apache.ignite.ml.tree.randomforest.data.TreeNode) RandomForestTreeModel(org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel) NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId)

Example 3 with NodeId

use of org.apache.ignite.ml.tree.randomforest.data.NodeId in project ignite by apache.

the class ImpurityHistogramsComputer method aggregateImpurityStatisticsOnPartition.

/**
 * Aggregates statistics for impurity computing for each corner node of each tree in the random forest. For every
 * learning vector the algorithm predicts the corner node in the decision tree and adds the vector to the
 * corresponding per-feature histograms of that node.
 *
 * @param dataset Dataset.
 * @param roots Trees.
 * @param histMeta Histogram buckets meta.
 * @param part Partition.
 * @return Leaf statistics for impurity computing, one entry per key of {@code part}.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots, Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {
    // Start with an empty histogram holder for every node id of the partition.
    Map<NodeId, NodeImpurityHistograms<S>> histsByNode = part.keySet().stream()
        .collect(Collectors.toMap(id -> id, NodeImpurityHistograms::new));

    dataset.forEach(vec -> {
        for (int treeIdx = 0; treeIdx < vec.counters().length; treeIdx++) {
            // Vectors with a zero counter for this tree are skipped.
            if (vec.counters()[treeIdx] != 0) {
                RandomForestTreeModel tree = roots.get(treeIdx);
                NodeId nodeId = tree.getRootNode().predictNextNodeKey(vec.features());

                // The predicted node may be absent if we didn't take all nodes from learning queue.
                if (part.containsKey(nodeId)) {
                    NodeImpurityHistograms<S> nodeHists = histsByNode.get(nodeId);

                    for (Integer featureId : tree.getUsedFeatures()) {
                        BucketMeta bucketMeta = histMeta.get(featureId);

                        // Create the impurity computer for the feature on first use.
                        if (!nodeHists.perFeatureStatistics.containsKey(featureId))
                            nodeHists.perFeatureStatistics.put(featureId, createImpurityComputerForFeature(treeIdx, bucketMeta));

                        nodeHists.perFeatureStatistics.get(featureId).addElement(vec);
                    }
                }
            }
        }
    });

    return histsByNode;
}
Also used : TreeNode(org.apache.ignite.ml.tree.randomforest.data.TreeNode) BootstrappedVector(org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector) NodeSplit(org.apache.ignite.ml.tree.randomforest.data.NodeSplit) HashMap(java.util.HashMap) BootstrappedDatasetPartition(org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition) Collectors(java.util.stream.Collectors) NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId) Serializable(java.io.Serializable) ArrayList(java.util.ArrayList) BucketMeta(org.apache.ignite.ml.dataset.feature.BucketMeta) Stream(java.util.stream.Stream) Dataset(org.apache.ignite.ml.dataset.Dataset) Map(java.util.Map) Optional(java.util.Optional) RandomForestTreeModel(org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel) Comparator(java.util.Comparator) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) RandomForestTreeModel(org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel) NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId) BucketMeta(org.apache.ignite.ml.dataset.feature.BucketMeta)

Example 4 with NodeId

use of org.apache.ignite.ml.tree.randomforest.data.NodeId in project ignite by apache.

the class LeafValuesComputer method computeLeafsStatisticsInPartition.

/**
 * Aggregates statistics on labels from learning dataset for each leaf node.
 * <p>
 * For every tree the whole partition is traversed: each vector is routed through the tree with
 * {@code predictNextNodeKey} and added to the statistics aggregator of the resulting node. The aggregator is
 * created on first use of a node id.
 *
 * @param roots Learned trees.
 * @param leafs List of all leafs.
 * @param data Data.
 * @return Statistics on labels for each leaf node that received at least one vector.
 * @throws IllegalStateException If a vector is routed to a node id that is not contained in {@code leafs}.
 */
private Map<NodeId, T> computeLeafsStatisticsInPartition(ArrayList<RandomForestTreeModel> roots, Map<NodeId, TreeNode> leafs, BootstrappedDatasetPartition data) {
    Map<NodeId, T> res = new HashMap<>();
    for (int sampleId = 0; sampleId < roots.size(); sampleId++) {
        // Lambdas can capture only effectively final locals, so the loop index is copied.
        final int sampleIdConst = sampleId;
        data.forEach(vec -> {
            NodeId leafId = roots.get(sampleIdConst).getRootNode().predictNextNodeKey(vec.features());

            // Name the tree and node in the message so a failure can be diagnosed.
            if (!leafs.containsKey(leafId)) {
                throw new IllegalStateException("Tree #" + sampleIdConst + " routed a vector to node " + leafId
                    + ", which is not among the known leafs");
            }

            // Single lookup: create the aggregator on first use of the leaf, reuse it afterwards.
            T leafStat = res.computeIfAbsent(leafId, id -> createLeafStatsAggregator(sampleIdConst));
            addElementToLeafStatistic(leafStat, vec, sampleIdConst);
        });
    }
    return res;
}
Also used : HashMap(java.util.HashMap) NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId)

Aggregations

NodeId (org.apache.ignite.ml.tree.randomforest.data.NodeId)4 HashMap (java.util.HashMap)2 BucketMeta (org.apache.ignite.ml.dataset.feature.BucketMeta)2 RandomForestTreeModel (org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel)2 TreeNode (org.apache.ignite.ml.tree.randomforest.data.TreeNode)2 Serializable (java.io.Serializable)1 ArrayList (java.util.ArrayList)1 Comparator (java.util.Comparator)1 HashSet (java.util.HashSet)1 Map (java.util.Map)1 Optional (java.util.Optional)1 Collectors (java.util.stream.Collectors)1 Stream (java.util.stream.Stream)1 Dataset (org.apache.ignite.ml.dataset.Dataset)1 BootstrappedDatasetPartition (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition)1 BootstrappedVector (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector)1 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)1 NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit)1