Search in sources:

Example 1 with TreeNode

use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache.

The class RandomForestTrainer, method initTrees.

/**
 * Creates one tree model per root node in the queue.
 * <p>
 * Each tree gets its own random feature subspace: the feature ids {@code 0 .. meta.size() - 1} are shuffled with
 * {@code random} and the first {@code featuresPerTree} of them are kept for that tree.
 *
 * @param treesQueue Queue of root nodes; this method only iterates it and does not modify it.
 * @return Tree models, in the iteration order of {@code treesQueue}.
 */
protected ArrayList<RandomForestTreeModel> initTrees(Queue<TreeNode> treesQueue) {
    assert featuresPerTree > 0;

    ArrayList<RandomForestTreeModel> trees = new ArrayList<>();

    int featureCnt = meta.size();
    List<Integer> featureIds = new ArrayList<>(featureCnt);
    for (int featureId = 0; featureId < featureCnt; featureId++)
        featureIds.add(featureId);

    treesQueue.forEach(rootNode -> {
        // The same id list is reshuffled for every tree, so each tree draws a fresh random subset.
        Collections.shuffle(featureIds, random);
        Set<Integer> subspace = featureIds.stream().limit(featuresPerTree).collect(Collectors.toSet());
        trees.add(new RandomForestTreeModel(rootNode, subspace));
    });

    return trees;
}
Also used : TreeNode(org.apache.ignite.ml.tree.randomforest.data.TreeNode) RandomForestTreeModel(org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel) ArrayList(java.util.ArrayList)

Example 2 with TreeNode

use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache.

The class RandomForestTrainer, method fit.

/**
 * Trains the forest on the specified data.
 * <p>
 * Root nodes come from {@link #createRootsQueue()} and are wrapped into tree models by {@link #initTrees}. While
 * the learning queue is not empty, a batch of nodes is taken from it, impurity histograms are aggregated over the
 * dataset for that batch, and every node of the batch is passed to {@link #split}. Finally leaf values are set
 * for all trees.
 *
 * @param dataset Dataset.
 * @return List of decision trees, or an empty list if no histogram meta could be computed for the dataset.
 * @throws IllegalStateException If the number of aggregated histograms does not match the number of nodes taken
 *      for learning in an iteration.
 */
private List<RandomForestTreeModel> fit(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
    Queue<TreeNode> treesQueue = createRootsQueue();
    ArrayList<RandomForestTreeModel> roots = initTrees(treesQueue);
    Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, dataset);
    if (histMeta.isEmpty())
        return Collections.emptyList();

    ImpurityHistogramsComputer<S> histogramsComputer = createImpurityHistogramsComputer();
    while (!treesQueue.isEmpty()) {
        Map<NodeId, TreeNode> nodesToLearn = getNodesToLearn(treesQueue);
        Map<NodeId, ImpurityHistogramsComputer.NodeImpurityHistograms<S>> nodesImpHists =
            histogramsComputer.aggregateImpurityStatistics(roots, histMeta, nodesToLearn, dataset);

        // Sanity check: one set of histograms is expected per node taken for learning.
        if (nodesToLearn.size() != nodesImpHists.size()) {
            throw new IllegalStateException("Impurity histograms were aggregated for " + nodesImpHists.size()
                + " node(s), but " + nodesToLearn.size() + " node(s) were taken for learning");
        }

        for (ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeImpHists : nodesImpHists.values())
            split(treesQueue, nodesToLearn, nodeImpHists);
    }

    createLeafStatisticsAggregator().setValuesForLeaves(roots, dataset);
    return roots;
}
Also used : BucketMeta(org.apache.ignite.ml.dataset.feature.BucketMeta) TreeNode(org.apache.ignite.ml.tree.randomforest.data.TreeNode) RandomForestTreeModel(org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel) NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId)

Example 3 with TreeNode

use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache.

The class RandomForestTrainer, method split.

/**
 * Splits the node described by the given histograms, or turns it into a leaf.
 * <p>
 * If {@link #needSplit} accepts the best split found in the histograms, the node is split and its children are
 * appended to the learning queue. Otherwise the node becomes a leaf. When a best split exists, the leaf is created
 * through it. When no split was found, the node's impurity is set to {@code Double.NEGATIVE_INFINITY} and
 * {@code toLeaf(0.0)} is called on it.
 *
 * @param learningQueue          Learning queue; receives the children of a split node.
 * @param nodesToLearn           Nodes to learn at the current iteration, keyed by node id.
 * @param nodeImpurityHistograms Impurity histograms of the node at the current iteration.
 */
private void split(Queue<TreeNode> learningQueue, Map<NodeId, TreeNode> nodesToLearn, ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeImpurityHistograms) {
    TreeNode node = nodesToLearn.get(nodeImpurityHistograms.getNodeId());
    Optional<NodeSplit> candidate = nodeImpurityHistograms.findBestSplit();

    if (needSplit(node, candidate)) {
        learningQueue.addAll(candidate.get().split(node));
        return;
    }

    if (!candidate.isPresent()) {
        // No split at all: mark the node as a leaf explicitly.
        node.setImpurity(Double.NEGATIVE_INFINITY);
        node.toLeaf(0.0);
        return;
    }

    candidate.get().createLeaf(node);
}
Also used : NodeSplit(org.apache.ignite.ml.tree.randomforest.data.NodeSplit) TreeNode(org.apache.ignite.ml.tree.randomforest.data.TreeNode)

Example 4 with TreeNode

use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache.

The class RandomForestTest, method testNeedSplit.

/**
 * Checks {@code needSplit} against impurity gains around {@code minImpDelta}.
 * <p>
 * For the root node, a gain of {@code 1.01 * minImpDelta} is accepted, while half of it or no gain at all is
 * rejected. A child of a conditional node is rejected even with the larger gain. Presumably this is due to a depth
 * limit configured for {@code rf} in this test class (TODO confirm).
 */
@Test
public void testNeedSplit() {
    TreeNode root = new TreeNode(1, 1);
    root.setImpurity(1000);

    Optional<NodeSplit> bigGain = Optional.of(new NodeSplit(0, 0, root.getImpurity() - minImpDelta * 1.01));
    assertTrue(rf.needSplit(root, bigGain));

    Optional<NodeSplit> smallGain = Optional.of(new NodeSplit(0, 0, root.getImpurity() - minImpDelta * 0.5));
    assertFalse(rf.needSplit(root, smallGain));

    Optional<NodeSplit> noGain = Optional.of(new NodeSplit(0, 0, root.getImpurity()));
    assertFalse(rf.needSplit(root, noGain));

    TreeNode firstChild = root.toConditional(0, 0).get(0);
    firstChild.setImpurity(1000);

    Optional<NodeSplit> childBigGain = Optional.of(new NodeSplit(0, 0, firstChild.getImpurity() - minImpDelta * 1.01));
    assertFalse(rf.needSplit(firstChild, childBigGain));
}
Also used : NodeSplit(org.apache.ignite.ml.tree.randomforest.data.NodeSplit) TreeNode(org.apache.ignite.ml.tree.randomforest.data.TreeNode) Test(org.junit.Test)

Example 5 with TreeNode

use of org.apache.ignite.ml.tree.randomforest.data.TreeNode in project ignite by apache.

The class ImpurityHistogramsComputer, method aggregateImpurityStatisticsOnPartition.

/**
 * Aggregates statistics for impurity computing on one partition, for every node being learned in every tree of
 * the random forest.
 * <p>
 * Each bootstrapped vector is considered for every tree whose bootstrap counter for it is non-zero. The vector is
 * routed through that tree to the node it currently reaches. If that node is among the nodes being learned, the
 * vector is added to the node's per-feature impurity computers for every feature used by the tree.
 *
 * @param dataset Dataset partition.
 * @param roots Trees; the tree at index {@code i} corresponds to bootstrap counter {@code i} of each vector.
 * @param histMeta Histogram buckets meta, keyed by feature id.
 * @param part Nodes being learned at this iteration, keyed by node id.
 * @return Impurity histograms per node id; contains an entry for every key of {@code part}.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots, Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {
    Map<NodeId, NodeImpurityHistograms<S>> res = part.keySet().stream().collect(Collectors.toMap(n -> n, NodeImpurityHistograms::new));

    dataset.forEach(vector -> {
        int[] counters = vector.counters();

        for (int sampleId = 0; sampleId < counters.length; sampleId++) {
            // Skip trees whose bootstrap sample does not contain this vector.
            if (counters[sampleId] == 0)
                continue;

            RandomForestTreeModel root = roots.get(sampleId);
            NodeId key = root.getRootNode().predictNextNodeKey(vector.features());

            // The reached node is not learned at this iteration (not all nodes were taken from the learning
            // queue), so the vector contributes nothing here.
            if (!part.containsKey(key))
                continue;

            NodeImpurityHistograms<S> statistics = res.get(key);
            // Effectively-final copy of the loop index so it can be captured by the lambda below.
            final int treeId = sampleId;
            for (Integer featureId : root.getUsedFeatures()) {
                S impurityComputer = statistics.perFeatureStatistics.computeIfAbsent(featureId,
                    id -> createImpurityComputerForFeature(treeId, histMeta.get(id)));
                impurityComputer.addElement(vector);
            }
        }
    });

    return res;
}
Also used : TreeNode(org.apache.ignite.ml.tree.randomforest.data.TreeNode) BootstrappedVector(org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector) NodeSplit(org.apache.ignite.ml.tree.randomforest.data.NodeSplit) HashMap(java.util.HashMap) BootstrappedDatasetPartition(org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition) Collectors(java.util.stream.Collectors) NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId) Serializable(java.io.Serializable) ArrayList(java.util.ArrayList) BucketMeta(org.apache.ignite.ml.dataset.feature.BucketMeta) Stream(java.util.stream.Stream) Dataset(org.apache.ignite.ml.dataset.Dataset) Map(java.util.Map) Optional(java.util.Optional) RandomForestTreeModel(org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel) Comparator(java.util.Comparator) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext)

Aggregations

TreeNode (org.apache.ignite.ml.tree.randomforest.data.TreeNode): 5; NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit): 3; RandomForestTreeModel (org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel): 3; ArrayList (java.util.ArrayList): 2; BucketMeta (org.apache.ignite.ml.dataset.feature.BucketMeta): 2; NodeId (org.apache.ignite.ml.tree.randomforest.data.NodeId): 2; Serializable (java.io.Serializable): 1; Comparator (java.util.Comparator): 1; HashMap (java.util.HashMap): 1; Map (java.util.Map): 1; Optional (java.util.Optional): 1; Collectors (java.util.stream.Collectors): 1; Stream (java.util.stream.Stream): 1; Dataset (org.apache.ignite.ml.dataset.Dataset): 1; BootstrappedDatasetPartition (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition): 1; BootstrappedVector (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector): 1; EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext): 1; Test (org.junit.Test): 1