Search in sources:

Example 1 with StepFunction

use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.

the class MSEImpurityMeasureCalculator method calculate.

/**
 * {@inheritDoc}
 * <p>
 * Builds one step function per feature column. For a column, the first argument value is
 * {@link Double#NEGATIVE_INFINITY} and the remaining ones are the feature values read row by row;
 * the measure stored at position {@code k} is created from the label sum and squared-label sum of
 * the first {@code k} rows ("left") and of the remaining rows ("right").
 *
 * @return Array of step functions (one per column), or {@code null} when no rows pass the filter.
 */
@Override
public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
    TreeDataIndex idx = null;
    boolean hasRows;

    if (!useIdx) {
        data = data.filter(filter);
        hasRows = data.getFeatures().length > 0;
    } else {
        idx = data.createIndexByFilter(depth, filter);
        hasRows = idx.rowsCount() > 0;
    }

    // Nothing passed the filter, so there is nothing to build a step function from.
    if (!hasRows)
        return null;

    int rowsCnt = rowsCount(data, idx);
    int colsCnt = columnsCount(data, idx);

    @SuppressWarnings("unchecked")
    StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt];

    // Label sum and squared-label sum over all rows: the "right" side state before any row is moved left.
    double totalY = 0;
    double totalY2 = 0;

    for (int row = 0; row < rowsCnt; row++) {
        double lbl = getLabelValue(data, idx, 0, row);

        totalY += lbl;
        totalY2 += Math.pow(lbl, 2);
    }

    for (int col = 0; col < colsCnt; col++) {
        // Without an index the data itself is re-sorted for the current column
        // (presumably the index already provides per-column ordering; confirm against TreeDataIndex).
        if (!useIdx)
            data.sort(col);

        double[] x = new double[rowsCnt + 1];
        MSEImpurityMeasure[] y = new MSEImpurityMeasure[rowsCnt + 1];

        x[0] = Double.NEGATIVE_INFINITY;

        double leftY = 0;
        double leftY2 = 0;
        double rightY = totalY;
        double rightY2 = totalY2;

        // leftSize is the number of rows currently assigned to the left side.
        for (int leftSize = 0; leftSize <= rowsCnt; leftSize++) {
            if (leftSize > 0) {
                // Move the row that has just joined the left side from the right sums to the left sums.
                double lbl = getLabelValue(data, idx, col, leftSize - 1);
                double lblSquared = Math.pow(lbl, 2);

                leftY += lbl;
                leftY2 += lblSquared;
                rightY -= lbl;
                rightY2 -= lblSquared;
            }

            if (leftSize < rowsCnt)
                x[leftSize + 1] = getFeatureValue(data, idx, col, leftSize);

            y[leftSize] = new MSEImpurityMeasure(leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize);
        }

        res[col] = new StepFunction<>(x, y);
    }

    return res;
}
Also used : TreeDataIndex(org.apache.ignite.ml.tree.data.TreeDataIndex) StepFunction(org.apache.ignite.ml.tree.impurity.util.StepFunction)

Example 2 with StepFunction

use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.

the class GiniImpurityMeasureCalculator method calculate.

/**
 * {@inheritDoc}
 * <p>
 * Builds one step function per feature column. Per-label row counts are kept for the "left" and
 * "right" sides; rows are moved from right to left one at a time, and a new point is emitted only
 * after the last row of a run of equal feature values, so repeated feature values yield a single
 * point. The first point of every function is {@link Double#NEGATIVE_INFINITY} with an empty left side.
 *
 * @return Array of step functions (one per column), or {@code null} when no rows pass the filter.
 */
@Override
public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
    TreeDataIndex idx = null;
    boolean hasRows;

    if (!useIdx) {
        data = data.filter(filter);
        hasRows = data.getFeatures().length > 0;
    } else {
        idx = data.createIndexByFilter(depth, filter);
        hasRows = idx.rowsCount() > 0;
    }

    // Nothing passed the filter, so there is nothing to build a step function from.
    if (!hasRows)
        return null;

    int rowsCnt = rowsCount(data, idx);
    int colsCnt = columnsCount(data, idx);

    @SuppressWarnings("unchecked")
    StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colsCnt];

    // Per-label counts over all rows: the "right" side state before any row is moved left.
    long[] totals = new long[lbEncoder.size()];

    for (int row = 0; row < rowsCnt; row++) {
        double lb = getLabelValue(data, idx, 0, row);

        totals[getLabelCode(lb)]++;
    }

    for (int col = 0; col < colsCnt; col++) {
        // Without an index the data itself is re-sorted for the current column
        // (presumably the index already provides per-column ordering; confirm against TreeDataIndex).
        if (!useIdx)
            data.sort(col);

        double[] x = new double[rowsCnt + 1];
        GiniImpurityMeasure[] y = new GiniImpurityMeasure[rowsCnt + 1];

        long[] left = new long[lbEncoder.size()];
        long[] right = totals.clone();

        // Number of points written so far; x and y are always filled in lockstep.
        int size = 0;

        x[size] = Double.NEGATIVE_INFINITY;
        y[size] = new GiniImpurityMeasure(left.clone(), right.clone());
        size++;

        for (int row = 0; row < rowsCnt; row++) {
            double lb = getLabelValue(data, idx, col, row);

            left[getLabelCode(lb)]++;
            right[getLabelCode(lb)]--;

            double featureVal = getFeatureValue(data, idx, col, row);

            // A point is emitted only for the last row of a run of equal feature values.
            boolean endOfRun = row >= rowsCnt - 1 || getFeatureValue(data, idx, col, row + 1) != featureVal;

            if (endOfRun) {
                x[size] = featureVal;
                // The counters keep mutating, so each measure receives its own snapshot.
                y[size] = new GiniImpurityMeasure(left.clone(), right.clone());
                size++;
            }
        }

        // Trim the unused tail left by collapsed duplicates.
        res[col] = new StepFunction<>(Arrays.copyOf(x, size), Arrays.copyOf(y, size));
    }

    return res;
}
Also used : TreeDataIndex(org.apache.ignite.ml.tree.data.TreeDataIndex) StepFunction(org.apache.ignite.ml.tree.impurity.util.StepFunction)

Example 3 with StepFunction

use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.

the class MSEImpurityMeasureCalculatorTest method testCalculate.

/**
 * Checks the step functions produced by the MSE calculator for a four-row, two-column dataset:
 * both the argument values and the impurity at every step are verified for each column.
 */
@Test
public void testCalculate() {
    double[][] features = { { 0, 2 }, { 1, 1 }, { 2, 0 }, { 3, 3 } };
    double[] lbls = { 1, 2, 2, 1 };

    MSEImpurityMeasureCalculator calc = new MSEImpurityMeasureCalculator(useIdx);
    DecisionTreeData treeData = new DecisionTreeData(features, lbls, useIdx);

    StepFunction<MSEImpurityMeasure>[] impurity = calc.calculate(treeData, fs -> true, 0);

    assertEquals(2, impurity.length);

    // Both columns contain the values 0..3, so the expected arguments are the same for each.
    double[] expX = { Double.NEGATIVE_INFINITY, 0, 1, 2, 3 };

    // Expected MSE per step; the outer index is the column.
    double[][] expImpurity = {
        { 1.000, 0.666, 1.000, 0.666, 1.000 },
        { 1.000, 0.666, 0.000, 0.666, 1.000 }
    };

    for (int col = 0; col < expImpurity.length; col++) {
        assertArrayEquals(expX, impurity[col].getX(), 1e-10);

        for (int step = 0; step < expImpurity[col].length; step++)
            assertEquals(expImpurity[col][step], impurity[col].getY()[step].impurity(), 1e-3);
    }
}
Also used : StepFunction(org.apache.ignite.ml.tree.impurity.util.StepFunction) DecisionTreeData(org.apache.ignite.ml.tree.data.DecisionTreeData) Test(org.junit.Test)

Example 4 with StepFunction

use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.

the class GiniImpurityMeasureCalculatorTest method testCalculate.

/**
 * Checks the step functions produced by the Gini calculator for a four-row, two-column dataset
 * with two label classes: both the argument values and the impurity at every step are verified
 * for each column.
 */
@Test
public void testCalculate() {
    double[][] features = { { 0, 1 }, { 1, 0 }, { 2, 2 }, { 3, 3 } };
    double[] lbls = { 0, 1, 1, 1 };

    // Label value -> label code mapping handed to the calculator.
    Map<Double, Integer> lblCodes = new HashMap<>();
    lblCodes.put(0.0, 0);
    lblCodes.put(1.0, 1);

    GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(lblCodes, useIdx);
    DecisionTreeData treeData = new DecisionTreeData(features, lbls, useIdx);

    StepFunction<GiniImpurityMeasure>[] impurity = calc.calculate(treeData, fs -> true, 0);

    assertEquals(2, impurity.length);

    // Both columns contain the values 0..3, so the expected arguments are the same for each.
    double[] expX = { Double.NEGATIVE_INFINITY, 0, 1, 2, 3 };

    // Expected Gini impurity per step; the outer index is the column.
    double[][] expImpurity = {
        { -2.500, -4.000, -3.000, -2.666, -2.500 },
        { -2.500, -2.666, -3.000, -2.666, -2.500 }
    };

    for (int col = 0; col < expImpurity.length; col++) {
        assertArrayEquals(expX, impurity[col].getX(), 1e-10);

        for (int step = 0; step < expImpurity[col].length; step++)
            assertEquals(expImpurity[col][step], impurity[col].getY()[step].impurity(), 1e-3);
    }
}
Also used : HashMap(java.util.HashMap) StepFunction(org.apache.ignite.ml.tree.impurity.util.StepFunction) DecisionTreeData(org.apache.ignite.ml.tree.data.DecisionTreeData) Test(org.junit.Test)

Example 5 with StepFunction

use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.

the class GiniImpurityMeasureCalculatorTest method testCalculateWithRepeatedData.

/**
 * Checks the Gini calculator on a single-column dataset in which the feature value 2 occurs twice:
 * the resulting step function is expected to contain that value only once.
 */
@Test
public void testCalculateWithRepeatedData() {
    double[][] features = { { 0 }, { 1 }, { 2 }, { 2 }, { 3 } };
    double[] lbls = { 0, 1, 1, 1, 1 };

    // Label value -> label code mapping handed to the calculator.
    Map<Double, Integer> lblCodes = new HashMap<>();
    lblCodes.put(0.0, 0);
    lblCodes.put(1.0, 1);

    GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(lblCodes, useIdx);
    DecisionTreeData treeData = new DecisionTreeData(features, lbls, useIdx);

    StepFunction<GiniImpurityMeasure>[] impurity = calc.calculate(treeData, fs -> true, 0);

    assertEquals(1, impurity.length);

    // Five rows but only four distinct feature values, hence five points including -infinity.
    double[] expX = { Double.NEGATIVE_INFINITY, 0, 1, 2, 3 };
    double[] expImpurity = { -3.400, -5.000, -4.000, -3.500, -3.400 };

    assertArrayEquals(expX, impurity[0].getX(), 1e-10);

    for (int step = 0; step < expImpurity.length; step++)
        assertEquals(expImpurity[step], impurity[0].getY()[step].impurity(), 1e-3);
}
Also used : HashMap(java.util.HashMap) StepFunction(org.apache.ignite.ml.tree.impurity.util.StepFunction) DecisionTreeData(org.apache.ignite.ml.tree.data.DecisionTreeData) Test(org.junit.Test)

Aggregations

StepFunction (org.apache.ignite.ml.tree.impurity.util.StepFunction): 5, DecisionTreeData (org.apache.ignite.ml.tree.data.DecisionTreeData): 3, Test (org.junit.Test): 3, HashMap (java.util.HashMap): 2, TreeDataIndex (org.apache.ignite.ml.tree.data.TreeDataIndex): 2