Use of org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate in project knime-core by knime.
In the class TreeNumericColumnDataTest, method testCalcBestSplitClassificationMissingValStrategy1:
/**
* This test is outdated and will likely be removed soon.
*
* @throws Exception
*/
// @Test
public void testCalcBestSplitClassificationMissingValStrategy1() throws Exception {
TreeEnsembleLearnerConfiguration config = createConfig();
final double[] data = asDataArray("1, 2, 3, 4, 5, 6, 7, NaN, NaN, NaN");
final String[] target = asStringArray("Y, Y, Y, Y, N, N, N, Y, Y, Y");
Pair<TreeOrdinaryNumericColumnData, TreeTargetNominalColumnData> exampleData = exampleData(config, data, target);
double[] rowWeights = new double[data.length];
Arrays.fill(rowWeights, 1.0);
RandomData rd = config.createRandomData();
TreeNumericColumnData columnData = exampleData.getFirst();
TreeTargetNominalColumnData targetData = exampleData.getSecond();
TreeData treeData = createTreeDataClassification(exampleData);
IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
ClassificationPriors priors = targetData.getDistribution(rowWeights, config);
SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
assertNotNull(splitCandidate);
assertThat(splitCandidate, instanceOf(NumericMissingSplitCandidate.class));
assertTrue(splitCandidate.canColumnBeSplitFurther());
assertEquals(0.42, splitCandidate.getGainValue(), 0.0001);
TreeNodeNumericCondition[] childConditions = ((NumericMissingSplitCandidate) splitCandidate).getChildConditions();
assertEquals(2, childConditions.length);
assertEquals(NumericOperator.LessThanOrEqualOrMissing, childConditions[0].getNumericOperator());
assertEquals(NumericOperator.LargerThan, childConditions[1].getNumericOperator());
assertEquals(4.5, childConditions[0].getSplitValue(), 0.0);
}
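The expected gain of 0.42 can be reproduced by hand. The following standalone sketch (no KNIME dependencies; the class name GiniGainSketch is made up for illustration) assumes the configured impurity criterion is Gini and recomputes the prior impurity and the post-split impurity for the split value 4.5 with the missing rows sent to the left child:

public final class GiniGainSketch {

    /** Gini impurity of a two-class distribution given as absolute counts. */
    static double gini(double yCount, double nCount) {
        double total = yCount + nCount;
        double pY = yCount / total;
        double pN = nCount / total;
        return 1.0 - pY * pY - pN * pN;
    }

    public static void main(String[] args) {
        // prior distribution of the 10 rows: 7 x "Y" (values 1-4 plus the three missing rows) and 3 x "N"
        double prior = gini(7, 3); // 1 - 0.7^2 - 0.3^2 = 0.42

        // split at 4.5 with the missing values sent left:
        // left  = values {1, 2, 3, 4} plus the 3 missing rows -> 7 x "Y", 0 x "N"
        // right = values {5, 6, 7}                            -> 0 x "Y", 3 x "N"
        double leftImpurity = gini(7, 0);  // 0.0, pure child
        double rightImpurity = gini(0, 3); // 0.0, pure child
        double postSplit = 0.7 * leftImpurity + 0.3 * rightImpurity;

        // ~0.42, matching assertEquals(0.42, splitCandidate.getGainValue(), 0.0001)
        System.out.println(prior - postSplit);
    }
}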
Use of org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate in project knime-core by knime.
In the class TreeNumericColumnData, method calcBestSplitRegression:
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
final int minChildNodeSize = config.getMinChildSize();
// get columnMemberships
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
final int lengthNonMissing = getLengthNonMissing();
// missing value handling
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// does the complete column (not only the current branch) contain missing values?
boolean branchContainsMissingValues = containsMissingValues();
boolean missingsGoLeft = true;
double missingWeight = 0.0;
double missingY = 0.0;
// check if there are missing values in this rowsample
if (branchContainsMissingValues) {
columnMemberships.goToLast();
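// missing values are sorted to the end of the column, so walk backwards from the last
// member and accumulate their weight and target sum until a non-missing index is reached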
while (columnMemberships.getIndexInColumn() >= lengthNonMissing) {
missingWeight += columnMemberships.getRowWeight();
missingY += targetColumn.getValueFor(columnMemberships.getOriginalIndex());
if (!columnMemberships.previous()) {
break;
}
}
columnMemberships.reset();
branchContainsMissingValues = missingWeight > 0.0;
}
final double ySumTotal = targetPriors.getYSum() - missingY;
final double nrRecordsTotal = targetPriors.getNrRecords() - missingWeight;
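// squared-sum criterion sum(y)^2 / n of the unsplit node; with XGBoost missing value
// handling the missing rows remain part of the parent node's criterion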
final double criterionTotal = useXGBoostMissingValueHandling ? (ySumTotal + missingY) * (ySumTotal + missingY) / (nrRecordsTotal + missingWeight) : ySumTotal * ySumTotal / nrRecordsTotal;
double ySumLeft = 0.0;
double nrRecordsLeft = 0.0;
double ySumRight = ySumTotal;
double nrRecordsRight = nrRecordsTotal;
// all values in the current branch are missing
if (nrRecordsRight == 0) {
// it is impossible to determine a split
return null;
}
double bestSplit = Double.NEGATIVE_INFINITY;
double bestImprovement = 0.0;
double lastSeenY = Double.NaN;
double lastSeenValue = Double.NEGATIVE_INFINITY;
double lastSeenWeight = -1.0;
// compute the improvement for every candidate split point and keep the best one
while (columnMemberships.next()) {
final double weight = columnMemberships.getRowWeight();
if (weight < EPSILON) {
// ignore record: not in current branch or not in sample
continue;
} else if (Math.floor(weight) != weight) {
throw new UnsupportedOperationException("weighted records (missing values?) not supported, " + "weight is " + weight);
}
final double value = getSorted(columnMemberships.getIndexInColumn());
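// the previously seen row is moved into the left partition before a split between
// lastSeenValue and the current value is evaluated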
if (lastSeenWeight > 0.0) {
ySumLeft += lastSeenWeight * lastSeenY;
ySumRight -= lastSeenWeight * lastSeenY;
nrRecordsLeft += lastSeenWeight;
nrRecordsRight -= lastSeenWeight;
if (nrRecordsLeft >= minChildNodeSize && nrRecordsRight >= minChildNodeSize && lastSeenValue < value) {
boolean tempMissingsGoLeft = true;
double childrenSquaredSum;
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
final double[] tempChildrenSquaredSum = new double[2];
tempChildrenSquaredSum[0] = ((ySumLeft + missingY) * (ySumLeft + missingY) / (nrRecordsLeft + missingWeight)) + (ySumRight * ySumRight / nrRecordsRight);
tempChildrenSquaredSum[1] = (ySumLeft * ySumLeft / nrRecordsLeft) + ((ySumRight + missingY) * (ySumRight + missingY) / (nrRecordsRight + missingWeight));
if (tempChildrenSquaredSum[0] >= tempChildrenSquaredSum[1]) {
childrenSquaredSum = tempChildrenSquaredSum[0];
tempMissingsGoLeft = true;
} else {
childrenSquaredSum = tempChildrenSquaredSum[1];
tempMissingsGoLeft = false;
}
} else {
childrenSquaredSum = (ySumLeft * ySumLeft / nrRecordsLeft) + (ySumRight * ySumRight / nrRecordsRight);
}
double criterion = childrenSquaredSum - criterionTotal;
boolean randomTieBreaker = criterion == bestImprovement ? rd.nextInt(0, 1) == 1 : false;
if (criterion > bestImprovement || randomTieBreaker) {
bestImprovement = criterion;
bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
// if there are no missing values, send missings with the majority child
missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : nrRecordsLeft >= nrRecordsRight;
}
}
}
lastSeenY = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
lastSeenValue = value;
lastSeenWeight = weight;
}
// + " but was " + lastSeenY * lastSeenWeight;
if (bestImprovement > 0.0) {
if (useXGBoostMissingValueHandling) {
// return new NumericMissingSplitCandidate(this, bestSplit, bestImprovement, missingsGoLeft);
return new NumericSplitCandidate(this, bestSplit, bestImprovement, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
}
return new NumericSplitCandidate(this, bestSplit, bestImprovement, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
} else {
return null;
}
}
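The decisive part of the XGBoost-style missing value handling above is the two-way comparison: the squared-sum criterion sum(y)^2 / n is evaluated once with the missing rows assigned to the left child and once with them assigned to the right child, and the assignment with the larger criterion wins (ties go left). A minimal standalone sketch of that comparison, with made-up class and method names and no KNIME dependencies:

public final class RegressionMissingDirectionSketch {

    /**
     * Mirrors the tempChildrenSquaredSum comparison in calcBestSplitRegression:
     * returns true if sending all missing rows to the left child yields the
     * larger (better) squared-sum criterion.
     */
    static boolean missingsGoLeft(double ySumLeft, double nrLeft, double ySumRight, double nrRight,
        double missingY, double missingWeight) {
        // criterion if all missing rows are sent to the left child
        double left = (ySumLeft + missingY) * (ySumLeft + missingY) / (nrLeft + missingWeight)
            + ySumRight * ySumRight / nrRight;
        // criterion if all missing rows are sent to the right child
        double right = ySumLeft * ySumLeft / nrLeft
            + (ySumRight + missingY) * (ySumRight + missingY) / (nrRight + missingWeight);
        // ties go left, matching the >= comparison in the method above
        return left >= right;
    }

    public static void main(String[] args) {
        // left child: 4 rows with y-sum 8, right child: 3 rows with y-sum 21,
        // 2 missing rows with y-sum 13 (numbers chosen for illustration only)
        System.out.println(missingsGoLeft(8, 4, 21, 3, 13, 2)); // prints false: missings go right
    }
}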
Use of org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate in project knime-core by knime.
In the class TreeNumericColumnData, method calcBestSplitClassification:
@Override
public NumericSplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
final int minChildNodeSize = config.getMinChildSize();
// distribution of target for each attribute value
final int targetCounts = targetVals.length;
final double[] targetCountsLeftOfSplit = new double[targetCounts];
final double[] targetCountsRightOfSplit = targetPriors.getDistribution().clone();
assert targetCountsRightOfSplit.length == targetCounts;
final double totalSumWeight = targetPriors.getNrRecords();
final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// get columnMemberships
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
// missing value handling
boolean branchContainsMissingValues = containsMissingValues();
boolean missingsGoLeft = true;
final int lengthNonMissing = getLengthNonMissing();
final double[] missingTargetCounts = new double[targetCounts];
int lastValidSplitPosition = -1;
double missingWeight = 0;
columnMemberships.goToLast();
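// walk backwards from the last member: accumulate the class counts and weight of the missing
// rows (sorted to the end of the column) and determine the last index at which a split is still possible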
do {
final int indexInColumn = columnMemberships.getIndexInColumn();
if (indexInColumn >= lengthNonMissing) {
final double weight = columnMemberships.getRowWeight();
final int classIdx = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
targetCountsRightOfSplit[classIdx] -= weight;
missingTargetCounts[classIdx] += weight;
missingWeight += weight;
} else {
if (lastValidSplitPosition < 0) {
lastValidSplitPosition = indexInColumn;
} else if ((getSorted(lastValidSplitPosition) - getSorted(indexInColumn)) >= EPSILON) {
break;
} else {
lastValidSplitPosition = indexInColumn;
}
}
} while (columnMemberships.previous());
// the column may contain missing values even though none of them fall into the current branch
branchContainsMissingValues = missingWeight > 0.0;
columnMemberships.reset();
double sumWeightsLeftOfSplit = 0.0;
double sumWeightsRightOfSplit = totalSumWeight - missingWeight;
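// prior impurity of this branch; if the branch contains missing values and XGBoost handling
// is not used, the missing rows are excluded from the prior distribution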
final double priorImpurity = useXGBoostMissingValueHandling || !branchContainsMissingValues ? targetPriors.getPriorImpurity() : impurityCriterion.getPartitionImpurity(TreeNominalColumnData.subtractMissingClassCounts(targetPriors.getDistribution(), missingTargetCounts), sumWeightsRightOfSplit);
// all values in branch are missing
if (sumWeightsRightOfSplit == 0) {
// it is impossible to determine a split
return null;
}
double bestSplit = Double.NEGATIVE_INFINITY;
// gain for best split point, unnormalized (not using info gain ratio)
double bestGain = Double.NEGATIVE_INFINITY;
// gain for best split, normalized by attribute entropy when
// info gain ratio is used.
double bestGainValueForSplit = Double.NEGATIVE_INFINITY;
final double[] tempArray1 = new double[2];
double[] tempArray2 = new double[2];
double lastSeenValue = Double.NEGATIVE_INFINITY;
boolean mustTestOnNextValueChange = false;
boolean testSplitOnStart = true;
boolean firstIteration = true;
int lastSeenTarget = -1;
int indexInCol = -1;
// We iterate over the instances in the sample/branch instead of the whole data set
while (columnMemberships.next() && (indexInCol = columnMemberships.getIndexInColumn()) < lengthNonMissing) {
final double weight = columnMemberships.getRowWeight();
assert weight >= EPSILON : "Rows with zero row weight should never be seen!";
final double value = getSorted(indexInCol);
final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
final boolean hasValueChanged = (value - lastSeenValue) >= EPSILON;
final boolean hasTargetChanged = lastSeenTarget != target || indexInCol == lastValidSplitPosition;
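// candidate splits are only evaluated where the attribute value changes and, once the class
// label has changed at least once, only where the label changed since the last evaluated candidate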
if (hasTargetChanged && !firstIteration) {
mustTestOnNextValueChange = true;
testSplitOnStart = false;
}
if (!firstIteration && hasValueChanged && (mustTestOnNextValueChange || testSplitOnStart) && sumWeightsLeftOfSplit >= minChildNodeSize && sumWeightsRightOfSplit >= minChildNodeSize) {
double postSplitImpurity;
boolean tempMissingsGoLeft = false;
// missing value handling
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
final double[] targetCountsLeftPlusMissing = new double[targetCounts];
final double[] targetCountsRightPlusMissing = new double[targetCounts];
for (int i = 0; i < targetCounts; i++) {
targetCountsLeftPlusMissing[i] = targetCountsLeftOfSplit[i] + missingTargetCounts[i];
targetCountsRightPlusMissing[i] = targetCountsRightOfSplit[i] + missingTargetCounts[i];
}
final double[][] temp = new double[2][2];
final double[] postSplitImpurities = new double[2];
// send all missing values left
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftPlusMissing, sumWeightsLeftOfSplit + missingWeight);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
temp[0][0] = sumWeightsLeftOfSplit + missingWeight;
temp[0][1] = sumWeightsRightOfSplit;
postSplitImpurities[0] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[0], totalSumWeight);
// send all missing values right
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightPlusMissing, sumWeightsRightOfSplit + missingWeight);
temp[1][0] = sumWeightsLeftOfSplit;
temp[1][1] = sumWeightsRightOfSplit + missingWeight;
postSplitImpurities[1] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[1], totalSumWeight);
// take better split
if (postSplitImpurities[0] < postSplitImpurities[1]) {
postSplitImpurity = postSplitImpurities[0];
tempArray2 = temp[0];
tempMissingsGoLeft = true;
// TODO random tie breaker
} else {
postSplitImpurity = postSplitImpurities[1];
tempArray2 = temp[1];
tempMissingsGoLeft = false;
}
} else {
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
tempArray2[0] = sumWeightsLeftOfSplit;
tempArray2[1] = sumWeightsRightOfSplit;
postSplitImpurity = impurityCriterion.getPostSplitImpurity(tempArray1, tempArray2, totalSumWeight);
}
if (postSplitImpurity < priorImpurity) {
// Use the absolute gain (IG) for the split calculation even
// if the split criterion is information gain ratio (IGR).
// IGR wouldn't work here as it favors extremely unfair splits:
// a 1:9999 split has an attribute entropy (the IGR denominator) of
// -(9999/10000*log2(9999/10000) + 1/10000*log2(1/10000)) ~ 0.00148,
// which inflates the ratio for such splits.
double gain = (priorImpurity - postSplitImpurity);
boolean randomTieBreaker = gain == bestGain ? rd.nextInt(0, 1) == 1 : false;
if (gain > bestGain || randomTieBreaker) {
bestGainValueForSplit = impurityCriterion.getGain(priorImpurity, postSplitImpurity, tempArray2, totalSumWeight);
bestGain = gain;
bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
// If there are no missing values during training, send missings with the majority child;
// we still have to provide a missing direction in case missing values occur during prediction.
missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : sumWeightsLeftOfSplit > sumWeightsRightOfSplit;
}
}
mustTestOnNextValueChange = false;
}
targetCountsLeftOfSplit[target] += weight;
sumWeightsLeftOfSplit += weight;
targetCountsRightOfSplit[target] -= weight;
sumWeightsRightOfSplit -= weight;
lastSeenTarget = target;
lastSeenValue = value;
firstIteration = false;
}
columnMemberships.reset();
if (bestGainValueForSplit < 0.0) {
// no improvement found (the normalized gain can be negative; see the info gain ratio implementation)
return null;
}
if (useXGBoostMissingValueHandling) {
// return new NumericMissingSplitCandidate(this, bestSplit, bestGainValueForSplit, missingsGoLeft);
return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
}
return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
}
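Like the regression case, calcBestSplitClassification decides the missing direction by a two-way comparison, but with impurities instead of squared sums: the class counts of the missing rows are added to the left and to the right partition in turn, the weighted post-split impurity is computed for both assignments, and the smaller one wins. A compressed standalone sketch of that comparison, using Gini as the impurity measure (the class and method names are made up for illustration and do not reflect the IImpurity API):

public final class ClassificationMissingDirectionSketch {

    /** Gini impurity of a class-count distribution with the given total weight. */
    static double gini(double[] classCounts, double totalWeight) {
        double impurity = 1.0;
        for (double count : classCounts) {
            double p = count / totalWeight;
            impurity -= p * p;
        }
        return impurity;
    }

    /** Returns true if sending the missing rows to the left child yields the lower post-split impurity. */
    static boolean missingsGoLeft(double[] leftCounts, double weightLeft, double[] rightCounts,
        double weightRight, double[] missingCounts, double weightMissing) {
        int numClasses = leftCounts.length;
        double[] leftPlusMissing = new double[numClasses];
        double[] rightPlusMissing = new double[numClasses];
        for (int i = 0; i < numClasses; i++) {
            leftPlusMissing[i] = leftCounts[i] + missingCounts[i];
            rightPlusMissing[i] = rightCounts[i] + missingCounts[i];
        }
        double totalWeight = weightLeft + weightRight + weightMissing;
        // weighted post-split impurity if the missing rows go left
        double goLeft = (weightLeft + weightMissing) / totalWeight * gini(leftPlusMissing, weightLeft + weightMissing)
            + weightRight / totalWeight * gini(rightCounts, weightRight);
        // weighted post-split impurity if the missing rows go right
        double goRight = weightLeft / totalWeight * gini(leftCounts, weightLeft)
            + (weightRight + weightMissing) / totalWeight * gini(rightPlusMissing, weightRight + weightMissing);
        return goLeft < goRight;
    }

    public static void main(String[] args) {
        // two classes; left: 4 x class 0, right: 3 x class 1, missing: 3 x class 0 (illustrative numbers)
        boolean goLeft = missingsGoLeft(new double[]{4, 0}, 4, new double[]{0, 3}, 3, new double[]{3, 0}, 3);
        System.out.println(goLeft); // prints true: both children stay pure when the missings go left
    }
}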