Use of org.knime.base.node.mine.treeensemble2.learner.IImpurity in project knime-core by knime.
Shown below: the method calcBestSplitClassificationBinaryTwoClass of the class TreeNominalColumnData.
/**
 * Searches for the best binary split of this nominal column when the target has exactly two classes.
 *
 * <p>The method works in three phases:
 * <ol>
 * <li>Walk over the rows of the current branch (via {@code columnMemberships}) and, for every non-missing
 * nominal value that has weight in the branch, collect its total weight and the weight belonging to the
 * first target class.</li>
 * <li>Collect the weight (total and first class) of the rows that remain after the non-missing values;
 * these are treated as missing values.</li>
 * <li>Sort the collected values by their natural ordering and evaluate every proper prefix of the sorted
 * list as one side of a binary split, keeping the candidate with the best gain.</li>
 * </ol>
 *
 * @param columnMemberships cursor over the rows of the current branch; {@code next()} is called on it before
 *            any row is read, so it must be positioned before its first element
 * @param targetPriors supplies the prior impurity and the prior class distribution of the branch
 * @param targetColumn the two-class target column; used to look up the class of each row
 * @param impCriterion impurity criterion used to compute partition impurity, post-split impurity and gain
 * @param nomVals nominal values of this column; only the array length is read here
 * @param targetVals not read by this method
 * @param rd random source used to break ties between candidates with equal gain
 * @return the best split candidate, or {@code null} if no candidate is both valid (both children reach the
 *         configured minimum child size) and has a gain larger than zero
 * @throws IllegalArgumentException if the target column does not have exactly two values
 * @throws IllegalStateException if {@code columnMemberships.next()} returns {@code false} on the first call
 */
private NominalBinarySplitCandidate calcBestSplitClassificationBinaryTwoClass(final ColumnMemberships columnMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final IImpurity impCriterion, final NominalValueRepresentation[] nomVals, final NominalValueRepresentation[] targetVals, final RandomData rd) {
if (targetColumn.getMetaData().getValues().length != 2) {
throw new IllegalArgumentException("This method can only be used for two class problems.");
}
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final int minChildSize = config.getMinChildSize();
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// first column index of the rows belonging to the nominal value currently processed
int start = 0;
// integer code of the first target class; every other code counts as "second class"
final int firstClass = targetColumn.getMetaData().getValues()[0].getAssignedInteger();
// both totals only accumulate rows visited in the non-missing loop below
double totalWeight = 0.0;
double totalFirstClassWeight = 0.0;
final ArrayList<NomValProbabilityPair> nomValProbabilities = new ArrayList<NomValProbabilityPair>();
// advance the cursor onto the first row of the branch
if (!columnMemberships.next()) {
throw new IllegalStateException("The columnMemberships has not been reset or is empty.");
}
// if the column contains missing values, the last entry of nomVals is excluded from the loop below
// (presumably it is the representation of the missing value - confirm against the column meta data)
final int lengthNonMissing = containsMissingValues() ? nomVals.length - 1 : nomVals.length;
// final int attToConsider = useXGBoostMissingValueHandling ? nomVals.length : lengthNonMissing;
// starts out as "the column contains missing values"; set to true further down if missing weight is found
boolean branchContainsMissingValues = containsMissingValues();
// calculate probabilities for first class in each nominal value
for (int att = 0; att < /*attToConsider*/
lengthNonMissing; att++) {
// rows with a column index in [start, end) are attributed to nominal value att
int end = start + m_nominalValueCounts[att];
double attFirstClassWeight = 0;
double attWeight = 0;
boolean reachedEnd = false;
// the cursor is advanced inside the body, so the index is re-read from it after every iteration
for (int index = columnMemberships.getIndexInColumn(); index < end; index = columnMemberships.getIndexInColumn()) {
double weight = columnMemberships.getRowWeight();
assert weight > EPSILON : "Instances in columnMemberships must have weights larger than EPSILON.";
final int instanceClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
if (instanceClass == firstClass) {
attFirstClassWeight += weight;
totalFirstClassWeight += weight;
}
attWeight += weight;
totalWeight += weight;
if (!columnMemberships.next()) {
// reached end of columnMemberships
reachedEnd = true;
if (att == nomVals.length - 1) {
// if the column contains no missing values, the last possible nominal value is
// not the missing value and therefore branchContainsMissingValues needs to be false
// NOTE(review): the assignment below is a no-op (x && true == x) although the comment above asks for
// false. This branch is only reachable when lengthNonMissing == nomVals.length, i.e. when
// containsMissingValues() returned false above; assuming it returns the same value on both calls,
// the flag is already false here - confirm and consider removing the statement.
branchContainsMissingValues = branchContainsMissingValues && true;
}
break;
}
}
// nominal values without any weight in this branch are not considered for the split
if (attWeight > 0) {
final double firstClassProbability = attFirstClassWeight / attWeight;
final NominalValueRepresentation nomVal = getMetaData().getValues()[att];
nomValProbabilities.add(new NomValProbabilityPair(nomVal, firstClassProbability, attWeight, attFirstClassWeight));
}
start = end;
if (reachedEnd) {
break;
}
}
// account for missing values and their weight
double missingWeight = 0.0;
double missingWeightFirstClass = 0.0;
// otherwise the current indexInColumn won't be larger than start
// (if the cursor ran out inside the loop above, the last visited index was smaller than end and start
// has been advanced to end afterwards; presumably getIndexInColumn() still reports that last index)
if (columnMemberships.getIndexInColumn() >= start) {
// all remaining rows are counted as missing values
do {
final double recordWeight = columnMemberships.getRowWeight();
missingWeight += recordWeight;
final int recordClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
if (recordClass == firstClass) {
missingWeightFirstClass += recordWeight;
}
} while (columnMemberships.next());
}
if (missingWeight > EPSILON) {
branchContainsMissingValues = true;
}
// a null comparator sorts by the natural ordering of NomValProbabilityPair
// (presumably ordered by the first class probability - confirm in NomValProbabilityPair#compareTo)
nomValProbabilities.sort(null);
// bit position of the last non-missing nominal value
int highestBitPosition = getMetaData().getValues().length - 1;
if (containsMissingValues()) {
highestBitPosition--;
}
// scratch arrays that are overwritten for every candidate:
// targetCounts*[0] = first class weight, targetCounts*[1] = weight of the other class
final double[] targetCountsSplitPartition = new double[2];
final double[] targetCountsSplitRemaining = new double[2];
// index 0 = values not in the current partition ("remaining"), index 1 = the current partition
final double[] binaryImpurityValues = new double[2];
final double[] binaryPartitionWeights = new double[2];
BigInteger partitionMask = BigInteger.ZERO;
double bestPartitionGain = Double.NEGATIVE_INFINITY;
BigInteger bestPartitionMask = null;
boolean isBestSplitValid = false;
// running sums over the prefix of nomValProbabilities that forms the current partition
double sumWeightsPartitionTotal = 0.0;
double sumWeightsPartitionFirstClass = 0.0;
boolean missingsGoLeft = false;
// with XGBoost handling the prior impurity of the branch is used as is; otherwise the impurity is
// recomputed on the prior distribution with the missing value class counts subtracted, using totalWeight
final double priorImpurity = useXGBoostMissingValueHandling ? targetPriors.getPriorImpurity() : impCriterion.getPartitionImpurity(subtractMissingClassCounts(targetPriors.getDistribution(), createMissingClassCountsTwoClass(missingWeight, missingWeightFirstClass)), totalWeight);
// we don't need to iterate over the full list because we always need some value on the other side
for (int i = 0; i < nomValProbabilities.size() - 1; i++) {
// grow the partition by the next value in sorted order
NomValProbabilityPair nomVal = nomValProbabilities.get(i);
sumWeightsPartitionTotal += nomVal.m_sumWeights;
sumWeightsPartitionFirstClass += nomVal.m_firstClassSumWeights;
partitionMask = partitionMask.or(nomVal.m_bitMask);
// check if split represented by currentSplitList is in the right branch
// by convention a split goes towards the right branch if the highest possible bit is set to 1
final boolean isRightBranch = partitionMask.testBit(highestBitPosition);
double gain;
boolean isValidSplit;
boolean tempMissingsGoLeft = true;
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
// send missing values both ways and take the better direction
// NOTE(review): "left" in this block means "together with the current partition" and "right" means
// "together with the remaining values", independent of isRightBranch, whereas the no-missing case
// further down does take isRightBranch into account - confirm this matches how the mask is consumed.
// send missings left
targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass + missingWeightFirstClass;
targetCountsSplitPartition[1] = sumWeightsPartitionTotal + missingWeight - targetCountsSplitPartition[0];
binaryPartitionWeights[1] = sumWeightsPartitionTotal + missingWeight;
// totalFirstClassWeight and totalWeight only include non missing values
targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass;
targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal - targetCountsSplitRemaining[0];
binaryPartitionWeights[0] = totalWeight - sumWeightsPartitionTotal;
boolean isValidSplitLeft = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
// the overall weight passed here includes the missing values
double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
double gainLeft = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
// send missings right
// the scratch arrays are overwritten here, the results of the first evaluation live on in
// isValidSplitLeft and gainLeft only
targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass;
targetCountsSplitPartition[1] = sumWeightsPartitionTotal - sumWeightsPartitionFirstClass;
binaryPartitionWeights[1] = sumWeightsPartitionTotal;
targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass + missingWeightFirstClass;
targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal + missingWeight - targetCountsSplitRemaining[0];
binaryPartitionWeights[0] = totalWeight + missingWeight - sumWeightsPartitionTotal;
boolean isValidSplitRight = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
double gainRight = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
// decide which is better (better gain)
// on equal gains the "left" variant wins
if (gainLeft >= gainRight) {
gain = gainLeft;
isValidSplit = isValidSplitLeft;
tempMissingsGoLeft = true;
} else {
gain = gainRight;
isValidSplit = isValidSplitRight;
tempMissingsGoLeft = false;
}
} else {
// missing values take no part in the evaluation: only the non-missing weights are distributed
// assign weights to branches
targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass;
targetCountsSplitPartition[1] = sumWeightsPartitionTotal - sumWeightsPartitionFirstClass;
binaryPartitionWeights[1] = sumWeightsPartitionTotal;
targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass;
targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal - targetCountsSplitRemaining[0];
binaryPartitionWeights[0] = totalWeight - sumWeightsPartitionTotal;
isValidSplit = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight);
gain = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight);
}
// use random tie breaker if gains are equal
// (the random source is only queried when the gains are exactly equal; presumably nextInt(0, 1) has
// inclusive bounds and therefore yields 0 or 1 - confirm against the RandomData implementation in use)
boolean randomTieBreaker = gain == bestPartitionGain ? rd.nextInt(0, 1) == 1 : false;
// store if better than before or first valid split
if (gain > bestPartitionGain || (!isBestSplitValid && isValidSplit) || randomTieBreaker) {
// never replace a valid best split by an invalid candidate
if (isValidSplit || !isBestSplitValid) {
bestPartitionGain = gain;
// if the partition does not contain the highest bit, its complement over the bits
// 0..highestBitPosition is stored instead, so the stored mask always has the highest bit set
bestPartitionMask = isRightBranch ? partitionMask : BigInteger.ZERO.setBit(highestBitPosition + 1).subtract(BigInteger.ONE).xor(partitionMask);
isBestSplitValid = isValidSplit;
// missingsGoLeft is only used later on if XGBoost Missing Value Handling is used
if (branchContainsMissingValues) {
// missingsGoLeft = isRightBranch;
missingsGoLeft = tempMissingsGoLeft;
} else {
// no missing values in this branch
// send missing values with the majority
// (the partition is the right branch if isRightBranch is set, otherwise it is the left branch)
missingsGoLeft = isRightBranch ? sumWeightsPartitionTotal < 0.5 * totalWeight : sumWeightsPartitionTotal >= 0.5 * totalWeight;
}
}
}
}
// only a valid split that actually improves the impurity is reported
if (isBestSplitValid && bestPartitionGain > 0.0) {
if (useXGBoostMissingValueHandling) {
return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, NO_MISSED_ROWS, missingsGoLeft ? NominalBinarySplitCandidate.MISSINGS_GO_LEFT : NominalBinarySplitCandidate.MISSINGS_GO_RIGHT);
}
return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, getMissedRows(columnMemberships), NominalBinarySplitCandidate.NO_MISSINGS);
}
return null;
}
Aggregations