Search in sources:

Example 1 with TreeDataIndex

Use of org.apache.ignite.ml.tree.data.TreeDataIndex in the Apache Ignite project.

Method calculate of class MSEImpurityMeasureCalculator.

/**
 * Computes, for every feature column, a step function that maps a split threshold to the MSE impurity
 * of the left/right partition of the rows selected by {@code filter}.
 *
 * <p>Point {@code k} of a column's step function pairs threshold {@code x[k]} with the impurity of the
 * partition whose left part holds every row visited up to and including the last row with feature value
 * {@code x[k]}. The first point uses {@link Double#NEGATIVE_INFINITY} (empty left part, all rows right).
 * Consecutive rows with equal feature values are collapsed into a single point, exactly as the Gini
 * calculator does. A partition that separates rows sharing one feature value cannot be produced by any
 * threshold, so it must not be offered as a candidate split.
 *
 * <p>NOTE(review): the collapse relies on rows being visited in ascending order of column {@code col}
 * ({@code data.sort(col)} without an index, the index ordering otherwise). Confirm this for the indexed path.
 *
 * @param data Decision tree data; filtered (non-index mode) before use.
 * @param filter Filter selecting the rows of the current node.
 * @param depth Node depth, passed to {@code createIndexByFilter} when the index is used.
 * @return One step function per feature column, or {@code null} when no rows pass the filter.
 */
@Override
public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
    TreeDataIndex idx = null;
    boolean canCalculate;
    if (useIdx) {
        idx = data.createIndexByFilter(depth, filter);
        canCalculate = idx.rowsCount() > 0;
    } else {
        data = data.filter(filter);
        canCalculate = data.getFeatures().length > 0;
    }
    if (!canCalculate)
        return null;

    int rowsCnt = rowsCount(data, idx);
    int colsCnt = columnsCount(data, idx);
    @SuppressWarnings("unchecked") StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt];

    // Sum and sum of squares of all labels: the initial state with every row on the right.
    // The visiting order is irrelevant for a sum, so column 0's ordering is used.
    double rightYOriginal = 0;
    double rightY2Original = 0;
    for (int i = 0; i < rowsCnt; i++) {
        double lbVal = getLabelValue(data, idx, 0, i);
        rightYOriginal += lbVal;
        rightY2Original += lbVal * lbVal;
    }

    for (int col = 0; col < colsCnt; col++) {
        if (!useIdx)
            data.sort(col);

        // At most one point per row plus the leading -infinity point; trimmed below.
        double[] x = new double[rowsCnt + 1];
        MSEImpurityMeasure[] y = new MSEImpurityMeasure[rowsCnt + 1];
        int ptr = 0;

        double leftY = 0;
        double leftY2 = 0;
        double rightY = rightYOriginal;
        double rightY2 = rightY2Original;

        x[ptr] = Double.NEGATIVE_INFINITY;
        y[ptr++] = new MSEImpurityMeasure(leftY, leftY2, 0, rightY, rightY2, rowsCnt);

        for (int i = 0; i < rowsCnt; i++) {
            // Move row i from the right part to the left part.
            double lblVal = getLabelValue(data, idx, col, i);
            double lblVal2 = lblVal * lblVal;
            leftY += lblVal;
            leftY2 += lblVal2;
            rightY -= lblVal;
            rightY2 -= lblVal2;

            double featureVal = getFeatureValue(data, idx, col, i);

            // Only emit a point after the last row of a run of equal feature values.
            if (i < rowsCnt - 1 && getFeatureValue(data, idx, col, i + 1) == featureVal)
                continue;

            int leftSize = i + 1;
            x[ptr] = featureVal;
            y[ptr++] = new MSEImpurityMeasure(leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize);
        }

        res[col] = new StepFunction<>(java.util.Arrays.copyOf(x, ptr), java.util.Arrays.copyOf(y, ptr));
    }
    return res;
}
Also used: TreeDataIndex (org.apache.ignite.ml.tree.data.TreeDataIndex), StepFunction (org.apache.ignite.ml.tree.impurity.util.StepFunction)

Example 2 with TreeDataIndex

Use of org.apache.ignite.ml.tree.data.TreeDataIndex in the Apache Ignite project.

Method calculate of class GiniImpurityMeasureCalculator.

/**
 * Computes, for every feature column, a step function that maps a split threshold to the Gini impurity
 * of the left/right partition of the rows selected by {@code filter}.
 *
 * <p>Point {@code k} of a column's step function pairs threshold {@code x[k]} with the impurity of the
 * partition whose left part holds every row visited up to and including the last row with feature value
 * {@code x[k]}. The first point uses {@link Double#NEGATIVE_INFINITY} (empty left part, all rows right).
 * Consecutive rows with equal feature values are collapsed into a single point, so rows sharing a value
 * are never separated.
 *
 * <p>NOTE(review): the collapse relies on rows being visited in ascending order of column {@code col}
 * ({@code data.sort(col)} without an index, the index ordering otherwise). Confirm this for the indexed path.
 *
 * @param data Decision tree data; filtered (non-index mode) before use.
 * @param filter Filter selecting the rows of the current node.
 * @param depth Node depth, passed to {@code createIndexByFilter} when the index is used.
 * @return One step function per feature column, or {@code null} when no rows pass the filter.
 */
@Override
public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
    TreeDataIndex idx = null;
    boolean canCalculate;
    if (useIdx) {
        idx = data.createIndexByFilter(depth, filter);
        canCalculate = idx.rowsCount() > 0;
    } else {
        data = data.filter(filter);
        canCalculate = data.getFeatures().length > 0;
    }
    if (!canCalculate)
        return null;

    int rowsCnt = rowsCount(data, idx);
    int colsCnt = columnsCount(data, idx);
    // Generic array creation is the only unchecked operation, so the suppression is scoped to it.
    @SuppressWarnings("unchecked") StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colsCnt];

    // Row count per label code over all rows: the initial state with every row on the right.
    long[] right = new long[lbEncoder.size()];
    for (int i = 0; i < rowsCnt; i++)
        right[getLabelCode(getLabelValue(data, idx, 0, i))]++;

    for (int col = 0; col < colsCnt; col++) {
        if (!useIdx)
            data.sort(col);

        // At most one point per row plus the leading -infinity point; trimmed below.
        double[] x = new double[rowsCnt + 1];
        GiniImpurityMeasure[] y = new GiniImpurityMeasure[rowsCnt + 1];
        long[] left = new long[right.length];
        long[] rightCp = right.clone();
        int ptr = 0;

        // Each measure gets its own snapshot, since left/rightCp keep mutating.
        x[ptr] = Double.NEGATIVE_INFINITY;
        y[ptr++] = new GiniImpurityMeasure(left.clone(), rightCp.clone());

        for (int i = 0; i < rowsCnt; i++) {
            // Move row i from the right part to the left part.
            int lbCode = getLabelCode(getLabelValue(data, idx, col, i));
            left[lbCode]++;
            rightCp[lbCode]--;

            double featureVal = getFeatureValue(data, idx, col, i);

            // Only emit a point after the last row of a run of equal feature values.
            boolean lastOfRun = i == rowsCnt - 1 || getFeatureValue(data, idx, col, i + 1) != featureVal;
            if (lastOfRun) {
                x[ptr] = featureVal;
                y[ptr++] = new GiniImpurityMeasure(left.clone(), rightCp.clone());
            }
        }

        res[col] = new StepFunction<>(Arrays.copyOf(x, ptr), Arrays.copyOf(y, ptr));
    }
    return res;
}
Also used: TreeDataIndex (org.apache.ignite.ml.tree.data.TreeDataIndex), StepFunction (org.apache.ignite.ml.tree.impurity.util.StepFunction)

Aggregations

TreeDataIndex (org.apache.ignite.ml.tree.data.TreeDataIndex): 2 uses; StepFunction (org.apache.ignite.ml.tree.impurity.util.StepFunction): 2 uses