use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.
the class MSEImpurityMeasureCalculator method calculate.
/**
 * {@inheritDoc}
 *
 * <p>For each column this builds one step function with {@code rowsCnt + 1} points. Point {@code k}
 * holds an {@link MSEImpurityMeasure} built from the running label sum and squared-label sum of the
 * first {@code k} rows (left part) and of the remaining rows (right part). Returns {@code null}
 * when the filter leaves no rows to work with.
 */
@Override
public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
    TreeDataIndex idx = null;
    boolean hasRows;

    if (useIdx) {
        idx = data.createIndexByFilter(depth, filter);
        hasRows = idx.rowsCount() > 0;
    }
    else {
        data = data.filter(filter);
        hasRows = data.getFeatures().length > 0;
    }

    // Nothing passed the filter: there is no impurity to compute.
    if (!hasRows)
        return null;

    int rowsCnt = rowsCount(data, idx);
    int colsCnt = columnsCount(data, idx);

    @SuppressWarnings("unchecked")
    StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt];

    // Label totals over all rows; every column starts with all of it on the right side.
    double totalY = 0;
    double totalY2 = 0;

    for (int row = 0; row < rowsCnt; row++) {
        double lbl = getLabelValue(data, idx, 0, row);

        totalY += lbl;
        totalY2 += Math.pow(lbl, 2);
    }

    for (int col = 0; col < res.length; col++) {
        // Without an index the data itself is re-sorted by the current column.
        if (!useIdx)
            data.sort(col);

        double[] x = new double[rowsCnt + 1];
        MSEImpurityMeasure[] y = new MSEImpurityMeasure[rowsCnt + 1];

        x[0] = Double.NEGATIVE_INFINITY;

        double leftY = 0;
        double leftY2 = 0;
        double rightY = totalY;
        double rightY2 = totalY2;

        // leftSize is both the number of rows already moved to the left and the current position.
        for (int leftSize = 0; leftSize <= rowsCnt; leftSize++) {
            if (leftSize > 0) {
                // Move the previous row's label from the right accumulators to the left ones.
                double lbl = getLabelValue(data, idx, col, leftSize - 1);
                double lblSq = Math.pow(lbl, 2);

                leftY += lbl;
                leftY2 += lblSq;
                rightY -= lbl;
                rightY2 -= lblSq;
            }

            if (leftSize < rowsCnt)
                x[leftSize + 1] = getFeatureValue(data, idx, col, leftSize);

            y[leftSize] = new MSEImpurityMeasure(leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize);
        }

        res[col] = new StepFunction<>(x, y);
    }

    return res;
}
use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.
the class GiniImpurityMeasureCalculator method calculate.
/**
 * {@inheritDoc}
 *
 * <p>For each column this builds one step function whose values are {@link GiniImpurityMeasure}
 * instances holding per-label-code counters for the left and right parts of the split. A point is
 * emitted only after the last row of a run of equal feature values, so the resulting arrays may be
 * shorter than {@code rowsCnt + 1}. Returns {@code null} when the filter leaves no rows.
 */
@SuppressWarnings("unchecked")
@Override
public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
    TreeDataIndex idx = null;
    boolean hasRows;

    if (useIdx) {
        idx = data.createIndexByFilter(depth, filter);
        hasRows = idx.rowsCount() > 0;
    }
    else {
        data = data.filter(filter);
        hasRows = data.getFeatures().length > 0;
    }

    // Nothing passed the filter: there is no impurity to compute.
    if (!hasRows)
        return null;

    int rowsCnt = rowsCount(data, idx);
    int colsCnt = columnsCount(data, idx);

    StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colsCnt];

    // Per-label-code counters over all rows; each column starts with everything on the right.
    long[] totals = new long[lbEncoder.size()];

    for (int row = 0; row < rowsCnt; row++) {
        double lbl = getLabelValue(data, idx, 0, row);

        totals[getLabelCode(lbl)]++;
    }

    for (int col = 0; col < res.length; col++) {
        // Without an index the data itself is re-sorted by the current column.
        if (!useIdx)
            data.sort(col);

        double[] x = new double[rowsCnt + 1];
        GiniImpurityMeasure[] y = new GiniImpurityMeasure[rowsCnt + 1];

        long[] left = new long[lbEncoder.size()];
        long[] right = totals.clone();

        // Single write cursor: x and y are always filled in lock-step.
        int ptr = 0;

        x[ptr] = Double.NEGATIVE_INFINITY;
        y[ptr] = new GiniImpurityMeasure(left.clone(), right.clone());
        ptr++;

        for (int row = 0; row < rowsCnt; row++) {
            double lbl = getLabelValue(data, idx, col, row);

            left[getLabelCode(lbl)]++;
            right[getLabelCode(lbl)]--;

            double featureVal = getFeatureValue(data, idx, col, row);

            // Emit a point only on the last row or when the next row carries a different feature value.
            if (row >= rowsCnt - 1 || getFeatureValue(data, idx, col, row + 1) != featureVal) {
                x[ptr] = featureVal;
                // Snapshot the counters: the working arrays keep changing on later rows.
                y[ptr] = new GiniImpurityMeasure(left.clone(), right.clone());
                ptr++;
            }
        }

        // Trim the unused tail left by skipped duplicate feature values.
        res[col] = new StepFunction<>(Arrays.copyOf(x, ptr), Arrays.copyOf(y, ptr));
    }

    return res;
}
use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.
the class MSEImpurityMeasureCalculatorTest method testCalculate.
/**
 * Checks the MSE step functions produced for a four-row, two-column dataset: both columns must
 * yield the same X grid, and the impurity at every point must match the expected table.
 */
@Test
public void testCalculate() {
    double[][] features = {{0, 2}, {1, 1}, {2, 0}, {3, 3}};
    double[] lbls = {1, 2, 2, 1};

    MSEImpurityMeasureCalculator calc = new MSEImpurityMeasureCalculator(useIdx);

    StepFunction<MSEImpurityMeasure>[] impurity =
        calc.calculate(new DecisionTreeData(features, lbls, useIdx), row -> true, 0);

    assertEquals(2, impurity.length);

    double[] expX = {Double.NEGATIVE_INFINITY, 0, 1, 2, 3};

    // One row of expected impurities per feature column.
    double[][] expImpurity = {
        {1.000, 0.666, 1.000, 0.666, 1.000},
        {1.000, 0.666, 0.000, 0.666, 1.000}
    };

    for (int col = 0; col < expImpurity.length; col++) {
        assertArrayEquals(expX, impurity[col].getX(), 1e-10);

        for (int pnt = 0; pnt < expImpurity[col].length; pnt++)
            assertEquals(expImpurity[col][pnt], impurity[col].getY()[pnt].impurity(), 1e-3);
    }
}
use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.
the class GiniImpurityMeasureCalculatorTest method testCalculate.
/**
 * Checks the Gini step functions produced for a four-row, two-column dataset with two label
 * classes: both columns must yield the same X grid, and the impurity at every point must match
 * the expected table.
 */
@Test
public void testCalculate() {
    double[][] features = {{0, 1}, {1, 0}, {2, 2}, {3, 3}};
    double[] lbls = {0, 1, 1, 1};

    // Label value -> label code.
    Map<Double, Integer> lblCodes = new HashMap<>();
    lblCodes.put(0.0, 0);
    lblCodes.put(1.0, 1);

    GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(lblCodes, useIdx);

    StepFunction<GiniImpurityMeasure>[] impurity =
        calc.calculate(new DecisionTreeData(features, lbls, useIdx), row -> true, 0);

    assertEquals(2, impurity.length);

    double[] expX = {Double.NEGATIVE_INFINITY, 0, 1, 2, 3};

    // One row of expected impurities per feature column.
    double[][] expImpurity = {
        {-2.500, -4.000, -3.000, -2.666, -2.500},
        {-2.500, -2.666, -3.000, -2.666, -2.500}
    };

    for (int col = 0; col < expImpurity.length; col++) {
        assertArrayEquals(expX, impurity[col].getX(), 1e-10);

        for (int pnt = 0; pnt < expImpurity[col].length; pnt++)
            assertEquals(expImpurity[col][pnt], impurity[col].getY()[pnt].impurity(), 1e-3);
    }
}
use of org.apache.ignite.ml.tree.impurity.util.StepFunction in project ignite by apache.
the class GiniImpurityMeasureCalculatorTest method testCalculateWithRepeatedData.
/**
 * Checks the Gini step function for a single column containing a repeated feature value
 * ({@code 2} appears twice): the X grid must contain that value only once, and the impurity at
 * every point must match the expected values.
 */
@Test
public void testCalculateWithRepeatedData() {
    double[][] features = {{0}, {1}, {2}, {2}, {3}};
    double[] lbls = {0, 1, 1, 1, 1};

    // Label value -> label code.
    Map<Double, Integer> lblCodes = new HashMap<>();
    lblCodes.put(0.0, 0);
    lblCodes.put(1.0, 1);

    GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(lblCodes, useIdx);

    StepFunction<GiniImpurityMeasure>[] impurity =
        calc.calculate(new DecisionTreeData(features, lbls, useIdx), row -> true, 0);

    assertEquals(1, impurity.length);

    // Five rows collapse to four distinct feature values plus the leading -infinity point.
    double[] expX = {Double.NEGATIVE_INFINITY, 0, 1, 2, 3};
    double[] expImpurity = {-3.400, -5.000, -4.000, -3.500, -3.400};

    assertArrayEquals(expX, impurity[0].getX(), 1e-10);

    for (int pnt = 0; pnt < expImpurity.length; pnt++)
        assertEquals(expImpurity[pnt], impurity[0].getY()[pnt].impurity(), 1e-3);
}
Aggregations