Usage of org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate in the knime-core project by KNIME.
Class TreeNumericColumnData, method calcBestSplitClassification:
/**
 * Searches the best threshold split of this numeric column for a nominal (classification) target, considering only
 * the rows of the current branch as given by {@code dataMemberships}.
 * <p>
 * Rows are visited in column order (NOTE(review): assumed ascending by value, as the backward scan for the largest
 * distinct value implies). A threshold between two distinct values is evaluated while the target has not changed
 * yet, and after every target change; both children must reach the configured minimum child size.
 * <p>
 * Missing values:
 * <ul>
 * <li>With XGBoost missing value handling, every candidate is evaluated twice, once with all missing rows sent
 * left and once with all sent right, and the better direction is stored in the candidate.</li>
 * <li>Otherwise, prior impurity, post-split impurity and gain are computed on the non-missing rows only, and the
 * missing rows are reported via {@code getMissedRows(columnMemberships)}.</li>
 * </ul>
 *
 * @param dataMemberships the rows (and their weights) of the current branch
 * @param targetPriors class distribution, total weight and impurity criterion of the current branch
 * @param targetColumn the nominal target column
 * @param rd random source used to break exact gain ties between candidate splits
 * @return the best split candidate, or {@code null} if the branch only contains missing values or no admissible
 *         split reduces the impurity
 */
@Override
public NumericSplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships,
    final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
    final TreeEnsembleLearnerConfiguration config = getConfiguration();
    final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
    final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
    final int minChildNodeSize = config.getMinChildSize();
    final int targetCounts = targetVals.length;
    // class distribution left/right of the current split candidate; initially every row is "right"
    final double[] targetCountsLeftOfSplit = new double[targetCounts];
    final double[] targetCountsRightOfSplit = targetPriors.getDistribution().clone();
    assert targetCountsRightOfSplit.length == targetCounts;
    final double totalSumWeight = targetPriors.getNrRecords();
    final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
    final boolean useXGBoostMissingValueHandling =
        config.getMissingValueHandling() == MissingValueHandling.XGBoost;
    final ColumnMemberships columnMemberships =
        dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());

    // Backward pass. Rows at index >= lengthNonMissing hold missing values: move them from the right-hand counts
    // into missingTargetCounts. Then find the first index of the largest distinct non-missing value (values closer
    // than EPSILON count as equal) so that the split just below that value is always evaluated.
    final int lengthNonMissing = getLengthNonMissing();
    final double[] missingTargetCounts = new double[targetCounts];
    int lastValidSplitPosition = -1;
    double missingWeight = 0;
    columnMemberships.goToLast();
    do {
        final int indexInColumn = columnMemberships.getIndexInColumn();
        if (indexInColumn >= lengthNonMissing) {
            final double weight = columnMemberships.getRowWeight();
            final int classIdx = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            targetCountsRightOfSplit[classIdx] -= weight;
            missingTargetCounts[classIdx] += weight;
            missingWeight += weight;
        } else if (lastValidSplitPosition >= 0
            && (getSorted(lastValidSplitPosition) - getSorted(indexInColumn)) >= EPSILON) {
            break;
        } else {
            lastValidSplitPosition = indexInColumn;
        }
    } while (columnMemberships.previous());
    columnMemberships.reset();

    // the column may contain missing values although the current branch does not
    final boolean branchContainsMissingValues = missingWeight > 0.0;
    final double nonMissingWeight = totalSumWeight - missingWeight;
    if (lastValidSplitPosition < 0 || nonMissingWeight <= 0.0) {
        // all values in the branch are missing: it is impossible to determine a split
        // (checked before any impurity is computed on an empty partition)
        return null;
    }

    // missing directions are only evaluated in XGBoost mode and only if the branch actually has missing rows
    final boolean evaluateMissingDirection = useXGBoostMissingValueHandling && branchContainsMissingValues;
    // Total weight of the two partitions handed to the impurity criterion.
    // - In the XGBoost path the missing rows are part of one child, so the partitions sum to totalSumWeight.
    // - Otherwise the missing rows are excluded from both children (and from the prior impurity), so they must be
    //   excluded from the normalizer as well.
    // Normalizing by totalSumWeight there would shrink the post-split impurity by nonMissingWeight / totalSumWeight
    // (assuming getPostSplitImpurity is the weight-averaged partition impurity) and inflate the gain of columns with
    // many missing values. Without missing values both expressions are exactly totalSumWeight.
    final double splitTotalWeight = evaluateMissingDirection ? totalSumWeight : nonMissingWeight;
    final double priorImpurity = useXGBoostMissingValueHandling || !branchContainsMissingValues
        ? targetPriors.getPriorImpurity()
        : impurityCriterion.getPartitionImpurity(
            TreeNominalColumnData.subtractMissingClassCounts(targetPriors.getDistribution(), missingTargetCounts),
            nonMissingWeight);

    double sumWeightsLeftOfSplit = 0.0;
    double sumWeightsRightOfSplit = nonMissingWeight;
    double bestSplit = Double.NEGATIVE_INFINITY;
    // gain for best split point, unnormalized (not using info gain ratio)
    double bestGain = Double.NEGATIVE_INFINITY;
    // gain for best split, normalized by attribute entropy when info gain ratio is used
    double bestGainValueForSplit = Double.NEGATIVE_INFINITY;
    boolean missingsGoLeft = true;
    // scratch buffers, fully overwritten for every candidate split
    final double[] partitionImpurities = new double[2];
    final double[] weightsWithoutMissingDirection = new double[2];
    final double[] weightsMissingsLeft = new double[2];
    final double[] weightsMissingsRight = new double[2];
    final double[] targetCountsLeftPlusMissing = new double[targetCounts];
    final double[] targetCountsRightPlusMissing = new double[targetCounts];
    double lastSeenValue = Double.NEGATIVE_INFINITY;
    boolean mustTestOnNextValueChange = false;
    boolean testSplitOnStart = true;
    boolean firstIteration = true;
    int lastSeenTarget = -1;
    int indexInCol = -1;
    // iterate over the non-missing rows of the branch instead of the whole data set
    while (columnMemberships.next() && (indexInCol = columnMemberships.getIndexInColumn()) < lengthNonMissing) {
        final double weight = columnMemberships.getRowWeight();
        assert weight >= EPSILON : "Rows with zero row weight should never be seen!";
        final double value = getSorted(indexInCol);
        final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
        final boolean hasValueChanged = (value - lastSeenValue) >= EPSILON;
        // reaching lastValidSplitPosition is treated like a target change so that the split below the largest
        // distinct value is always evaluated
        final boolean hasTargetChanged = lastSeenTarget != target || indexInCol == lastValidSplitPosition;
        if (hasTargetChanged && !firstIteration) {
            mustTestOnNextValueChange = true;
            testSplitOnStart = false;
        }
        if (!firstIteration && hasValueChanged && (mustTestOnNextValueChange || testSplitOnStart)
            && sumWeightsLeftOfSplit >= minChildNodeSize && sumWeightsRightOfSplit >= minChildNodeSize) {
            final double postSplitImpurity;
            final double[] partitionWeights;
            boolean candidateMissingsGoLeft = false;
            if (evaluateMissingDirection) {
                for (int i = 0; i < targetCounts; i++) {
                    targetCountsLeftPlusMissing[i] = targetCountsLeftOfSplit[i] + missingTargetCounts[i];
                    targetCountsRightPlusMissing[i] = targetCountsRightOfSplit[i] + missingTargetCounts[i];
                }
                // send all missing values left
                weightsMissingsLeft[0] = sumWeightsLeftOfSplit + missingWeight;
                weightsMissingsLeft[1] = sumWeightsRightOfSplit;
                partitionImpurities[0] =
                    impurityCriterion.getPartitionImpurity(targetCountsLeftPlusMissing, weightsMissingsLeft[0]);
                partitionImpurities[1] =
                    impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, weightsMissingsLeft[1]);
                final double impurityMissingsLeft =
                    impurityCriterion.getPostSplitImpurity(partitionImpurities, weightsMissingsLeft, splitTotalWeight);
                // send all missing values right
                weightsMissingsRight[0] = sumWeightsLeftOfSplit;
                weightsMissingsRight[1] = sumWeightsRightOfSplit + missingWeight;
                partitionImpurities[0] =
                    impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, weightsMissingsRight[0]);
                partitionImpurities[1] =
                    impurityCriterion.getPartitionImpurity(targetCountsRightPlusMissing, weightsMissingsRight[1]);
                final double impurityMissingsRight = impurityCriterion.getPostSplitImpurity(partitionImpurities,
                    weightsMissingsRight, splitTotalWeight);
                // take the better direction; exact ties send the missing values right
                // TODO random tie breaker
                if (impurityMissingsLeft < impurityMissingsRight) {
                    postSplitImpurity = impurityMissingsLeft;
                    partitionWeights = weightsMissingsLeft;
                    candidateMissingsGoLeft = true;
                } else {
                    postSplitImpurity = impurityMissingsRight;
                    partitionWeights = weightsMissingsRight;
                }
            } else {
                partitionImpurities[0] =
                    impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
                partitionImpurities[1] =
                    impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
                weightsWithoutMissingDirection[0] = sumWeightsLeftOfSplit;
                weightsWithoutMissingDirection[1] = sumWeightsRightOfSplit;
                partitionWeights = weightsWithoutMissingDirection;
                postSplitImpurity =
                    impurityCriterion.getPostSplitImpurity(partitionImpurities, partitionWeights, splitTotalWeight);
            }
            if (postSplitImpurity < priorImpurity) {
                // Use absolute gain (IG) for split calculation even if the split criterion is information gain
                // ratio (IGR). IGR wouldn't work as it favors extreme unfair splits, i.e. 1:9999 would have an
                // attribute entropy (IGR denominator) of
                // 9999/10000*log(9999/10000) + 1/10000*log(1/10000)
                // which is ~0.00148
                final double gain = priorImpurity - postSplitImpurity;
                // exact ties are broken by a coin flip (RandomData.nextInt includes both bounds); the random
                // number is only drawn on a tie
                if (gain > bestGain || (gain == bestGain && rd.nextInt(0, 1) == 1)) {
                    bestGainValueForSplit =
                        impurityCriterion.getGain(priorImpurity, postSplitImpurity, partitionWeights, splitTotalWeight);
                    bestGain = gain;
                    bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
                    // Go with the majority if there are no missing values during training, because a missing
                    // direction must still be provided in case missing values show up during prediction.
                    missingsGoLeft = branchContainsMissingValues ? candidateMissingsGoLeft
                        : sumWeightsLeftOfSplit > sumWeightsRightOfSplit;
                }
            }
            mustTestOnNextValueChange = false;
        }
        targetCountsLeftOfSplit[target] += weight;
        sumWeightsLeftOfSplit += weight;
        targetCountsRightOfSplit[target] -= weight;
        sumWeightsRightOfSplit -= weight;
        lastSeenTarget = target;
        lastSeenValue = value;
        firstIteration = false;
    }
    columnMemberships.reset();
    if (bestGainValueForSplit < 0.0) {
        // no admissible split was found (see info gain ratio implementation)
        return null;
    }
    if (useXGBoostMissingValueHandling) {
        return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, new BitSet(),
            missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
    }
    return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, getMissedRows(columnMemberships),
        NumericSplitCandidate.NO_MISSINGS);
}
Aggregations