Search in sources:

Example 1 with NodeSplit

Use of org.apache.ignite.ml.tree.randomforest.data.NodeSplit in the Apache Ignite project.

Method split of class RandomForestTrainer.

/**
 * Handles the impurity histograms computed for one node on the current iteration. If {@code needSplit} approves
 * the best split, the node is split and its children are queued for further learning. Otherwise the node is
 * turned into a leaf.
 *
 * @param learningQueue          Queue that receives the children of a node that gets split.
 * @param nodesToLearn           Nodes being learned at the current iteration, keyed by node id.
 * @param nodeImpurityHistograms Impurity histograms of a single node on the current iteration.
 */
private void split(Queue<TreeNode> learningQueue, Map<NodeId, TreeNode> nodesToLearn,
    ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeImpurityHistograms) {
    TreeNode node = nodesToLearn.get(nodeImpurityHistograms.getNodeId());
    Optional<NodeSplit> candidate = nodeImpurityHistograms.findBestSplit();

    if (needSplit(node, candidate)) {
        // NOTE(review): relies on needSplit returning true only when a split is present — confirm.
        learningQueue.addAll(candidate.get().split(node));
        return;
    }

    if (!candidate.isPresent()) {
        // No split was found at all: mark impurity as -inf and make the node a leaf with value 0.0.
        node.setImpurity(Double.NEGATIVE_INFINITY);
        node.toLeaf(0.0);
        return;
    }

    // A split exists but was rejected: let it build the leaf for this node.
    candidate.get().createLeaf(node);
}
Also used: NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit), TreeNode (org.apache.ignite.ml.tree.randomforest.data.TreeNode)

Example 2 with NodeSplit

Use of org.apache.ignite.ml.tree.randomforest.data.NodeSplit in the Apache Ignite project.

Method testSplit of class GiniFeatureHistogramTest.

/**
 * Checks {@code GiniHistogram.findBestSplit()}. The filled histograms over {@code feature1Meta} and
 * {@code feature2Meta} must yield splits at 1.0 and -0.5 (within 0.01). The empty histogram and the filled
 * histogram over {@code feature3Meta} must yield no split.
 */
@Test
public void testSplit() {
    Map<Double, Integer> labelToIdx = new HashMap<>();
    labelToIdx.put(1.0, 0);
    labelToIdx.put(2.0, 1);

    GiniHistogram catHist = new GiniHistogram(0, labelToIdx, feature1Meta);
    GiniHistogram contHist = new GiniHistogram(0, labelToIdx, feature2Meta);
    GiniHistogram emptyHist = new GiniHistogram(0, labelToIdx, feature3Meta);
    GiniHistogram noSplitHist = new GiniHistogram(0, labelToIdx, feature3Meta);

    // Bucketing of feature2 is configured after contHist is created. Presumably the histogram reads
    // feature2Meta when elements are added — TODO confirm.
    feature2Meta.setMinVal(-5);
    feature2Meta.setBucketSize(1);

    // emptyHist is intentionally left without elements.
    GiniHistogram[] filled = {catHist, contHist, noSplitHist};
    for (BootstrappedVector vec : toSplitDataset) {
        for (GiniHistogram hist : filled)
            hist.addElement(vec);
    }

    NodeSplit bestCatSplit = catHist.findBestSplit().get();
    NodeSplit bestContSplit = contHist.findBestSplit().get();

    assertEquals(1.0, bestCatSplit.getVal(), 0.01);
    assertEquals(-0.5, bestContSplit.getVal(), 0.01);

    assertFalse(emptyHist.findBestSplit().isPresent());
    assertFalse(noSplitHist.findBestSplit().isPresent());
}
Also used: NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit), HashMap (java.util.HashMap), BootstrappedVector (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector), Test (org.junit.Test)

Example 3 with NodeSplit

Use of org.apache.ignite.ml.tree.randomforest.data.NodeSplit in the Apache Ignite project.

Method testNeedSplit of class RandomForestTest.

/**
 * Checks {@code needSplit} for a node with impurity 1000:
 * <ul>
 *     <li>a candidate split whose impurity is lower by {@code 1.01 * minImpDelta} is accepted;</li>
 *     <li>candidate splits lower by {@code 0.5 * minImpDelta}, or not lower at all, are rejected.</li>
 * </ul>
 * The first child produced by {@code toConditional} is rejected even with a {@code 1.01 * minImpDelta} drop.
 * Presumably this is due to the depth limit configured for {@code rf} — confirm against the fixture.
 */
@Test
public void testNeedSplit() {
    TreeNode root = new TreeNode(1, 1);
    root.setImpurity(1000);

    // Impurity drop slightly above minImpDelta: the split is needed.
    assertTrue(rf.needSplit(root, Optional.of(new NodeSplit(0, 0, root.getImpurity() - 1.01 * minImpDelta))));

    // Drop below minImpDelta, then no drop at all: no split.
    assertFalse(rf.needSplit(root, Optional.of(new NodeSplit(0, 0, root.getImpurity() - 0.5 * minImpDelta))));
    assertFalse(rf.needSplit(root, Optional.of(new NodeSplit(0, 0, root.getImpurity()))));

    TreeNode firstChild = root.toConditional(0, 0).get(0);
    firstChild.setImpurity(1000);

    // Same sufficient drop as above, but for a child node: no split.
    assertFalse(rf.needSplit(firstChild,
        Optional.of(new NodeSplit(0, 0, firstChild.getImpurity() - 1.01 * minImpDelta))));
}
Also used: NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit), TreeNode (org.apache.ignite.ml.tree.randomforest.data.TreeNode), Test (org.junit.Test)

Aggregations

- NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit): 3
- TreeNode (org.apache.ignite.ml.tree.randomforest.data.TreeNode): 2
- Test (org.junit.Test): 2
- HashMap (java.util.HashMap): 1
- BootstrappedVector (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector): 1