Use of org.apache.ignite.ml.tree.data.TreeDataIndex in the Apache Ignite project.
The class MSEImpurityMeasureCalculator, method calculate.
/**
 * {@inheritDoc}
 * <p>
 * For every feature column this builds a step function whose arguments are candidate split thresholds and whose
 * values are the {@link MSEImpurityMeasure} of the corresponding split. The first point has argument
 * {@code -Infinity} and describes the split with every row on the right. Each following point with argument
 * {@code t} describes the split where all rows visited so far, up to and including the row whose feature value
 * is {@code t}, are on the left.
 * <p>
 * A point is emitted only after the last row of a run of equal feature values. Two rows sharing a feature value
 * cannot be separated by a threshold, so a point between them would report the impurity of a split that can never
 * be realized. This relies on equal feature values being adjacent in the visiting order, which is either the index
 * order or the order produced by {@code data.sort(col)}. That order is presumably ascending (TODO confirm).
 *
 * @param data Tree data. When the index is not used, it is filtered and then sorted by each column in turn.
 * @param filter Filter selecting the rows to take into account.
 * @param depth Tree depth, passed to {@code createIndexByFilter} when the index is used.
 * @return One step function per column, or {@code null} if no rows pass the filter.
 */
@Override
public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
    TreeDataIndex idx = null;
    boolean canCalculate;
    if (useIdx) {
        idx = data.createIndexByFilter(depth, filter);
        canCalculate = idx.rowsCount() > 0;
    } else {
        data = data.filter(filter);
        canCalculate = data.getFeatures().length > 0;
    }

    if (!canCalculate)
        return null;

    int rowsCnt = rowsCount(data, idx);
    int colsCnt = columnsCount(data, idx);
    @SuppressWarnings("unchecked") StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt];

    // Label sum and sum of squares over all rows. The multiset of labels is the same for every column,
    // so this is computed once and used as the "everything on the right" starting state of each scan.
    double rightYOriginal = 0;
    double rightY2Original = 0;
    for (int i = 0; i < rowsCnt; i++) {
        double lbVal = getLabelValue(data, idx, 0, i);
        rightYOriginal += lbVal;
        rightY2Original += lbVal * lbVal;
    }

    for (int col = 0; col < res.length; col++) {
        if (!useIdx)
            data.sort(col);

        // At most one point per row plus the leading -Infinity point; trimmed below.
        double[] x = new double[rowsCnt + 1];
        MSEImpurityMeasure[] y = new MSEImpurityMeasure[rowsCnt + 1];

        double leftY = 0;
        double leftY2 = 0;
        double rightY = rightYOriginal;
        double rightY2 = rightY2Original;

        // Leading point: nothing on the left, every row on the right.
        int ptr = 0;
        x[ptr] = Double.NEGATIVE_INFINITY;
        y[ptr++] = new MSEImpurityMeasure(leftY, leftY2, 0, rightY, rightY2, rowsCnt);

        for (int i = 0; i < rowsCnt; i++) {
            // Move row i from the right part to the left part.
            double lblVal = getLabelValue(data, idx, col, i);
            double lblVal2 = lblVal * lblVal;
            leftY += lblVal;
            leftY2 += lblVal2;
            rightY -= lblVal;
            rightY2 -= lblVal2;

            double featureVal = getFeatureValue(data, idx, col, i);

            // The next row has the same feature value, so no threshold can split here; keep accumulating.
            if (i < rowsCnt - 1 && getFeatureValue(data, idx, col, i + 1) == featureVal)
                continue;

            int leftSize = i + 1;
            x[ptr] = featureVal;
            y[ptr++] = new MSEImpurityMeasure(leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize);
        }

        // Keep only the points actually emitted.
        double[] xs = new double[ptr];
        MSEImpurityMeasure[] ys = new MSEImpurityMeasure[ptr];
        System.arraycopy(x, 0, xs, 0, ptr);
        System.arraycopy(y, 0, ys, 0, ptr);

        res[col] = new StepFunction<>(xs, ys);
    }

    return res;
}
Use of org.apache.ignite.ml.tree.data.TreeDataIndex in the Apache Ignite project.
The class GiniImpurityMeasureCalculator, method calculate.
/**
 * {@inheritDoc}
 * <p>
 * For every feature column this builds a step function of {@link GiniImpurityMeasure} values. Each value is built
 * from the per-label row counts on the left and right side of a candidate split.
 * <p>
 * The first point has argument {@code -Infinity} and places every row on the right. Rows are then moved to the
 * left one at a time in visiting order. A point, whose argument is the moved row's feature value, is emitted only
 * when the next row has a different feature value or there is no next row. As a result, rows sharing a feature
 * value are never reported as split apart.
 *
 * @param data Tree data. When the index is not used, it is filtered and then sorted by each column in turn.
 * @param filter Filter selecting the rows to take into account.
 * @param depth Tree depth, passed to {@code createIndexByFilter} when the index is used.
 * @return One step function per column, or {@code null} if no rows pass the filter.
 */
@Override
public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
    TreeDataIndex idx;
    boolean hasRows;
    if (useIdx) {
        idx = data.createIndexByFilter(depth, filter);
        hasRows = idx.rowsCount() > 0;
    }
    else {
        idx = null;
        data = data.filter(filter);
        hasRows = data.getFeatures().length > 0;
    }

    if (!hasRows)
        return null;

    int rowCnt = rowsCount(data, idx);
    int colCnt = columnsCount(data, idx);

    @SuppressWarnings("unchecked")
    StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colCnt];

    // Per-label counts over all rows: the "everything on the right" state each column scan starts from.
    long[] totalCnts = new long[lbEncoder.size()];
    for (int row = 0; row < rowCnt; row++) {
        double lbl = getLabelValue(data, idx, 0, row);
        totalCnts[getLabelCode(lbl)]++;
    }

    for (int col = 0; col < res.length; col++) {
        if (!useIdx)
            data.sort(col);

        long[] leftCnts = new long[lbEncoder.size()];
        long[] rightCnts = totalCnts.clone();

        // At most one point per row plus the leading -Infinity point; trimmed when building the step function.
        double[] thresholds = new double[rowCnt + 1];
        GiniImpurityMeasure[] measures = new GiniImpurityMeasure[rowCnt + 1];

        thresholds[0] = Double.NEGATIVE_INFINITY;
        measures[0] = new GiniImpurityMeasure(leftCnts.clone(), rightCnts.clone());
        int pointCnt = 1;

        for (int row = 0; row < rowCnt; row++) {
            int code = getLabelCode(getLabelValue(data, idx, col, row));
            leftCnts[code]++;
            rightCnts[code]--;

            double featureVal = getFeatureValue(data, idx, col, row);
            boolean lastOfRun = row == rowCnt - 1 || getFeatureValue(data, idx, col, row + 1) != featureVal;

            if (lastOfRun) {
                // Snapshot the counts: the measure must not observe later updates of the running arrays.
                thresholds[pointCnt] = featureVal;
                measures[pointCnt] = new GiniImpurityMeasure(leftCnts.clone(), rightCnts.clone());
                pointCnt++;
            }
        }

        res[col] = new StepFunction<>(Arrays.copyOf(thresholds, pointCnt), Arrays.copyOf(measures, pointCnt));
    }

    return res;
}
Aggregations