Usage of org.apache.ignite.ml.tree.randomforest.data.NodeSplit in the Apache Ignite project.
Class RandomForestTrainer, method split.
/**
 * Processes one node of the current iteration: either splits it into children or turns it into a leaf.
 * <p>
 * When {@code needSplit} accepts the best split found in the histograms, the node is split and its
 * children are appended to the learning queue. Otherwise the node becomes a leaf: through the best
 * split when one exists, or with impurity {@code Double.NEGATIVE_INFINITY} and leaf value {@code 0.0}
 * when no split was found at all.
 *
 * @param learningQueue Learning queue that receives the children of a split node.
 * @param nodesToLearn Nodes to learn at current iteration, keyed by node id.
 * @param nodeImpurityHistograms Impurity histograms of the node on current iteration.
 */
private void split(Queue<TreeNode> learningQueue, Map<NodeId, TreeNode> nodesToLearn, ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeImpurityHistograms) {
    TreeNode target = nodesToLearn.get(nodeImpurityHistograms.getNodeId());
    Optional<NodeSplit> candidate = nodeImpurityHistograms.findBestSplit();

    // Splitting case: the produced children go back to the queue for further learning.
    if (needSplit(target, candidate)) {
        learningQueue.addAll(candidate.get().split(target));
        return;
    }

    // No split candidate at all: finalize the node with fixed fallback values.
    if (!candidate.isPresent()) {
        target.setImpurity(Double.NEGATIVE_INFINITY);
        target.toLeaf(0.0);
        return;
    }

    // A candidate exists but was rejected: let it finalize the node as a leaf.
    candidate.get().createLeaf(target);
}
Usage of org.apache.ignite.ml.tree.randomforest.data.NodeSplit in the Apache Ignite project.
Class GiniFeatureHistogramTest, method testSplit.
/**
 * Checks {@code findBestSplit()} on Gini histograms built over the fixture's feature metadata.
 * <p>
 * Expects split value {@code 1.0} for the histogram over {@code feature1Meta}, {@code -0.5} for the
 * histogram over {@code feature2Meta}, and no split both for a histogram that received no vectors and
 * for the filled histogram over {@code feature3Meta}.
 */
@Test
public void testSplit() {
    Map<Double, Integer> labelIdx = new HashMap<>();
    labelIdx.put(1.0, 0);
    labelIdx.put(2.0, 1);

    GiniHistogram firstFeatureHist = new GiniHistogram(0, labelIdx, feature1Meta);
    GiniHistogram secondFeatureHist = new GiniHistogram(0, labelIdx, feature2Meta);
    // Never filled with vectors below.
    GiniHistogram untouchedHist = new GiniHistogram(0, labelIdx, feature3Meta);
    GiniHistogram thirdFeatureHist = new GiniHistogram(0, labelIdx, feature3Meta);

    // Metadata is adjusted after the histograms are constructed, as in the original test.
    feature2Meta.setMinVal(-5);
    feature2Meta.setBucketSize(1);

    for (BootstrappedVector sample : toSplitDataset) {
        firstFeatureHist.addElement(sample);
        secondFeatureHist.addElement(sample);
        thirdFeatureHist.addElement(sample);
    }

    NodeSplit firstFeatureSplit = firstFeatureHist.findBestSplit().get();
    NodeSplit secondFeatureSplit = secondFeatureHist.findBestSplit().get();

    assertEquals(1.0, firstFeatureSplit.getVal(), 0.01);
    assertEquals(-0.5, secondFeatureSplit.getVal(), 0.01);

    assertFalse(untouchedHist.findBestSplit().isPresent());
    assertFalse(thirdFeatureHist.findBestSplit().isPresent());
}
Usage of org.apache.ignite.ml.tree.randomforest.data.NodeSplit in the Apache Ignite project.
Class RandomForestTest, method testNeedSplit.
/**
 * Checks the {@code needSplit} decision of the forest under test.
 * <p>
 * For a node with impurity 1000, a split whose impurity is lower by {@code minImpDelta * 1.01} must be
 * accepted, while splits lower by only {@code minImpDelta * 0.5} or not lower at all must be rejected.
 * A child obtained via {@code toConditional} must be rejected even with the {@code 1.01} delta
 * (presumably because of a depth limit configured in the fixture; confirm against the test setup).
 */
@Test
public void testNeedSplit() {
    TreeNode root = new TreeNode(1, 1);
    root.setImpurity(1000);

    Optional<NodeSplit> sufficientGain = Optional.of(new NodeSplit(0, 0, root.getImpurity() - minImpDelta * 1.01));
    assertTrue(rf.needSplit(root, sufficientGain));

    Optional<NodeSplit> insufficientGain = Optional.of(new NodeSplit(0, 0, root.getImpurity() - minImpDelta * 0.5));
    assertFalse(rf.needSplit(root, insufficientGain));

    Optional<NodeSplit> zeroGain = Optional.of(new NodeSplit(0, 0, root.getImpurity()));
    assertFalse(rf.needSplit(root, zeroGain));

    // Same impurity and same delta as the accepted case above, but on a child node.
    TreeNode firstChild = root.toConditional(0, 0).get(0);
    firstChild.setImpurity(1000);

    Optional<NodeSplit> childGain = Optional.of(new NodeSplit(0, 0, firstChild.getImpurity() - minImpDelta * 1.01));
    assertFalse(rf.needSplit(firstChild, childGain));
}
Aggregations