Usage of org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator in the Apache Ignite project.
Class GiniSplitCalculatorTest, method testSplitTwoClassesFourPoints.
/**
 * Tests the split of a region consisting of four points with distinct feature values
 * ({@code 0, 1, 2, 3}) and two label classes ({@code 0, 0, 1, 1}).
 *
 * <p>The expected split separates the classes perfectly. The left side holds the two class-0
 * samples and the right side holds the two class-1 samples, and both sides have zero impurity.
 *
 * <p>Checks throw {@link AssertionError} explicitly rather than using the {@code assert}
 * statement. That statement is skipped unless the JVM runs with {@code -ea}, which would let
 * this test pass without verifying anything.
 */
@Test
public void testSplitTwoClassesFourPoints() {
    double[] labels = new double[] {0.0, 0.0, 1.0, 1.0};
    double[] values = new double[] {0.0, 1.0, 2.0, 3.0};
    Integer[] samples = new Integer[] {0, 1, 2, 3};
    int[] cnts = new int[] {2, 2};

    // Statistics of the unsplit region:
    // - first argument: impurity 1 - 0.5^2 - 0.5^2 = 0.5
    // - then 4 samples and the per-class counts
    // - last argument: presumably the sum of squared class counts (2^2 + 2^2) — TODO confirm
    //   against the GiniData constructor.
    GiniSplitCalculator.GiniData data =
        new GiniSplitCalculator.GiniData(0.5, 4, cnts, 2.0 * 2.0 + 2.0 * 2.0);

    SplitInfo<GiniSplitCalculator.GiniData> split =
        new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data);

    // Impurity is computed in floating point, so compare it with a tolerance. The
    // "!(... <= eps)" form also fails when the impurity is NaN.
    final double eps = 1e-12;

    if (!(Math.abs(split.leftData().impurity()) <= eps)) {
        throw new AssertionError("left impurity: " + split.leftData().impurity());
    }
    if (split.leftData().counts()[0] != 2) {
        throw new AssertionError("left counts[0]: " + split.leftData().counts()[0]);
    }
    if (split.leftData().counts()[1] != 0) {
        throw new AssertionError("left counts[1]: " + split.leftData().counts()[1]);
    }
    if (split.leftData().getSize() != 2) {
        throw new AssertionError("left size: " + split.leftData().getSize());
    }

    if (!(Math.abs(split.rightData().impurity()) <= eps)) {
        throw new AssertionError("right impurity: " + split.rightData().impurity());
    }
    if (split.rightData().counts()[0] != 0) {
        throw new AssertionError("right counts[0]: " + split.rightData().counts()[0]);
    }
    if (split.rightData().counts()[1] != 2) {
        throw new AssertionError("right counts[1]: " + split.rightData().counts()[1]);
    }
    if (split.rightData().getSize() != 2) {
        throw new AssertionError("right size: " + split.rightData().getSize());
    }
}
Usage of org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator in the Apache Ignite project.
Class GiniSplitCalculatorTest, method testSplitThreePoints.
/**
 * Tests the split of a region consisting of three points with distinct feature values
 * ({@code 0, 1, 2}), each belonging to its own class ({@code 0, 1, 2}).
 *
 * <p>The expected split isolates the class-0 point on the left, giving zero impurity. The
 * class-1 and class-2 points go to the right, giving impurity 1 - 0.5^2 - 0.5^2 = 0.5.
 *
 * <p>Checks throw {@link AssertionError} explicitly rather than using the {@code assert}
 * statement. That statement is skipped unless the JVM runs with {@code -ea}, which would let
 * this test pass without verifying anything.
 */
@Test
public void testSplitThreePoints() {
    double[] labels = new double[] {0.0, 1.0, 2.0};
    double[] values = new double[] {0.0, 1.0, 2.0};
    Integer[] samples = new Integer[] {0, 1, 2};
    int[] cnts = new int[] {1, 1, 1};

    // Statistics of the unsplit region:
    // - first argument: impurity 1 - 3 * (1/3)^2 = 2/3
    // - then 3 samples and the per-class counts
    // - last argument: presumably the sum of squared class counts (1^2 + 1^2 + 1^2) — TODO
    //   confirm against the GiniData constructor.
    GiniSplitCalculator.GiniData data =
        new GiniSplitCalculator.GiniData(2.0 / 3, 3, cnts, 1.0 * 1.0 + 1.0 * 1.0 + 1.0 * 1.0);

    SplitInfo<GiniSplitCalculator.GiniData> split =
        new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data);

    // Impurity is computed in floating point, so compare it with a tolerance. The
    // "!(... <= eps)" form also fails when the impurity is NaN.
    final double eps = 1e-12;

    if (!(Math.abs(split.leftData().impurity()) <= eps)) {
        throw new AssertionError("left impurity: " + split.leftData().impurity());
    }
    if (split.leftData().counts()[0] != 1) {
        throw new AssertionError("left counts[0]: " + split.leftData().counts()[0]);
    }
    if (split.leftData().counts()[1] != 0) {
        throw new AssertionError("left counts[1]: " + split.leftData().counts()[1]);
    }
    if (split.leftData().counts()[2] != 0) {
        throw new AssertionError("left counts[2]: " + split.leftData().counts()[2]);
    }
    if (split.leftData().getSize() != 1) {
        throw new AssertionError("left size: " + split.leftData().getSize());
    }

    if (!(Math.abs(split.rightData().impurity() - 0.5) <= eps)) {
        throw new AssertionError("right impurity: " + split.rightData().impurity());
    }
    if (split.rightData().counts()[0] != 0) {
        throw new AssertionError("right counts[0]: " + split.rightData().counts()[0]);
    }
    if (split.rightData().counts()[1] != 1) {
        throw new AssertionError("right counts[1]: " + split.rightData().counts()[1]);
    }
    if (split.rightData().counts()[2] != 1) {
        throw new AssertionError("right counts[2]: " + split.rightData().counts()[2]);
    }
    if (split.rightData().getSize() != 2) {
        throw new AssertionError("right size: " + split.rightData().getSize());
    }
}
Usage of org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator in the Apache Ignite project.
Class GiniSplitCalculatorTest, method testSplitTwoClassesTwoPoints.
/**
 * Tests the split of a region consisting of two points with distinct feature values
 * ({@code 0, 1}) and different classes ({@code 0, 1}).
 *
 * <p>The expected split puts one point on each side, so both sides are pure and have zero
 * impurity.
 *
 * <p>Checks throw {@link AssertionError} explicitly rather than using the {@code assert}
 * statement. That statement is skipped unless the JVM runs with {@code -ea}, which would let
 * this test pass without verifying anything.
 */
@Test
public void testSplitTwoClassesTwoPoints() {
    double[] labels = new double[] {0.0, 1.0};
    double[] values = new double[] {0.0, 1.0};
    Integer[] samples = new Integer[] {0, 1};
    int[] cnts = new int[] {1, 1};

    // Statistics of the unsplit region:
    // - first argument: impurity 1 - 0.5^2 - 0.5^2 = 0.5
    // - then 2 samples and the per-class counts
    // - last argument: presumably the sum of squared class counts (1^2 + 1^2) — TODO confirm
    //   against the GiniData constructor.
    GiniSplitCalculator.GiniData data =
        new GiniSplitCalculator.GiniData(0.5, 2, cnts, 1.0 * 1.0 + 1.0 * 1.0);

    SplitInfo<GiniSplitCalculator.GiniData> split =
        new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data);

    // Impurity is computed in floating point, so compare it with a tolerance. The
    // "!(... <= eps)" form also fails when the impurity is NaN.
    final double eps = 1e-12;

    if (!(Math.abs(split.leftData().impurity()) <= eps)) {
        throw new AssertionError("left impurity: " + split.leftData().impurity());
    }
    if (split.leftData().counts()[0] != 1) {
        throw new AssertionError("left counts[0]: " + split.leftData().counts()[0]);
    }
    if (split.leftData().counts()[1] != 0) {
        throw new AssertionError("left counts[1]: " + split.leftData().counts()[1]);
    }
    if (split.leftData().getSize() != 1) {
        throw new AssertionError("left size: " + split.leftData().getSize());
    }

    if (!(Math.abs(split.rightData().impurity()) <= eps)) {
        throw new AssertionError("right impurity: " + split.rightData().impurity());
    }
    if (split.rightData().counts()[0] != 0) {
        throw new AssertionError("right counts[0]: " + split.rightData().counts()[0]);
    }
    if (split.rightData().counts()[1] != 1) {
        throw new AssertionError("right counts[1]: " + split.rightData().counts()[1]);
    }
    if (split.rightData().getSize() != 1) {
        throw new AssertionError("right size: " + split.rightData().getSize());
    }
}
Aggregations