Search in sources:

Example 1 with GiniSplitCalculator

Use of org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator in project Ignite by Apache.

In class GiniSplitCalculatorTest, method testSplitTwoClassesFourPoints.

/**
 * Checks the split of a four-sample region with distinct feature values {@code 0..3}, where the first two
 * samples carry label {@code 0} and the last two carry label {@code 1}. The expected result is two pure
 * halves of two samples each.
 */
@Test
public void testSplitTwoClassesFourPoints() {
    // Sample i has label labels[i] and feature value values[i].
    double[] labels = {0.0, 0.0, 1.0, 1.0};
    double[] values = {0.0, 1.0, 2.0, 3.0};
    Integer[] samples = {0, 1, 2, 3};

    // Whole-region statistics: two samples per class. 0.5 equals the Gini impurity 1 - (0.5^2 + 0.5^2)
    // for these counts, and the last argument is the sum of the squared class counts.
    int[] classCounts = {2, 2};
    double sqrCntSum = 2.0 * 2.0 + 2.0 * 2.0;
    GiniSplitCalculator.GiniData regionData = new GiniSplitCalculator.GiniData(0.5, 4, classCounts, sqrCntSum);

    GiniSplitCalculator calc = new GiniSplitCalculator(labels);
    SplitInfo<GiniSplitCalculator.GiniData> split = calc.splitRegion(samples, values, labels, 0, regionData);

    // Left side: both class-0 samples, pure.
    GiniSplitCalculator.GiniData left = split.leftData();
    assert left.impurity() == 0;
    assert left.counts()[0] == 2;
    assert left.counts()[1] == 0;
    assert left.getSize() == 2;

    // Right side: both class-1 samples, pure.
    GiniSplitCalculator.GiniData right = split.rightData();
    assert right.impurity() == 0;
    assert right.counts()[0] == 0;
    assert right.counts()[1] == 2;
    assert right.getSize() == 2;
}
Also used : GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) Test(org.junit.Test)

Example 2 with GiniSplitCalculator

Use of org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator in project Ignite by Apache.

In class GiniSplitCalculatorTest, method testSplitThreePoints.

/**
 * Checks the split of a three-sample region in which each sample has its own label ({@code 0}, {@code 1},
 * {@code 2}) and a distinct feature value. The first sample is expected to be isolated on the left (pure),
 * leaving the remaining two on the right with impurity {@code 0.5}.
 */
@Test
public void testSplitThreePoints() {
    // Sample i has label labels[i] and feature value values[i].
    double[] labels = {0.0, 1.0, 2.0};
    double[] values = {0.0, 1.0, 2.0};
    Integer[] samples = {0, 1, 2};

    // Whole-region statistics: one sample per class. 2.0 / 3 matches the Gini impurity 1 - 3 * (1/3)^2
    // for these counts, and the last argument is the sum of the squared class counts.
    int[] classCounts = {1, 1, 1};
    double sqrCntSum = 1.0 * 1.0 + 1.0 * 1.0 + 1.0 * 1.0;
    GiniSplitCalculator.GiniData regionData = new GiniSplitCalculator.GiniData(2.0 / 3, 3, classCounts, sqrCntSum);

    GiniSplitCalculator calc = new GiniSplitCalculator(labels);
    SplitInfo<GiniSplitCalculator.GiniData> split = calc.splitRegion(samples, values, labels, 0, regionData);

    // Left side: only the class-0 sample, pure.
    GiniSplitCalculator.GiniData left = split.leftData();
    assert left.impurity() == 0.0;
    assert left.counts()[0] == 1;
    assert left.counts()[1] == 0;
    assert left.counts()[2] == 0;
    assert left.getSize() == 1;

    // Right side: one class-1 and one class-2 sample, so impurity 1 - (0.5^2 + 0.5^2) = 0.5.
    GiniSplitCalculator.GiniData right = split.rightData();
    assert right.impurity() == 0.5;
    assert right.counts()[0] == 0;
    assert right.counts()[1] == 1;
    assert right.counts()[2] == 1;
    assert right.getSize() == 2;
}
Also used : GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) Test(org.junit.Test)

Example 3 with GiniSplitCalculator

Use of org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator in project Ignite by Apache.

In class GiniSplitCalculatorTest, method testSplitTwoClassesTwoPoints.

/**
 * Checks the split of the smallest non-trivial region: two samples with distinct feature values and
 * different labels. Each sample is expected to end up alone, and therefore pure, on its own side.
 */
@Test
public void testSplitTwoClassesTwoPoints() {
    // Sample i has label labels[i] and feature value values[i].
    double[] labels = {0.0, 1.0};
    double[] values = {0.0, 1.0};
    Integer[] samples = {0, 1};

    // Whole-region statistics: one sample per class. 0.5 equals the Gini impurity 1 - (0.5^2 + 0.5^2)
    // for these counts, and the last argument is the sum of the squared class counts.
    int[] classCounts = {1, 1};
    double sqrCntSum = 1.0 * 1.0 + 1.0 * 1.0;
    GiniSplitCalculator.GiniData regionData = new GiniSplitCalculator.GiniData(0.5, 2, classCounts, sqrCntSum);

    GiniSplitCalculator calc = new GiniSplitCalculator(labels);
    SplitInfo<GiniSplitCalculator.GiniData> split = calc.splitRegion(samples, values, labels, 0, regionData);

    // Left side: the class-0 sample alone.
    GiniSplitCalculator.GiniData left = split.leftData();
    assert left.impurity() == 0;
    assert left.counts()[0] == 1;
    assert left.counts()[1] == 0;
    assert left.getSize() == 1;

    // Right side: the class-1 sample alone.
    GiniSplitCalculator.GiniData right = split.rightData();
    assert right.impurity() == 0;
    assert right.counts()[0] == 0;
    assert right.counts()[1] == 1;
    assert right.getSize() == 1;
}
Also used : GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) Test(org.junit.Test)

Aggregations

GiniSplitCalculator (org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator): 3 usages; Test (org.junit.Test): 3 usages