use of org.knime.base.node.mine.treeensemble2.learner.IImpurity in project knime-core by knime.
the class TreeNominalColumnData method calcBestSplitClassificationBinaryPCA.
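/**
 * Implements the approach proposed by Coppersmith et al. (1999) in their paper
 * "Partitioning Nominal Attributes in Decision Trees".
 *
 * @param columnMemberships the rows of this column to consider, together with their weights
 * @param targetPriors provides the prior impurity and the class distribution
 * @param targetColumn the nominal target column
 * @param impCriterion the impurity criterion used to score partitions
 * @param nomVals the nominal values of this column
 * @param targetVals the nominal values of the target column
 * @param rd random data used to break ties between equally good partitions
 * @return the best binary split candidate or null if there is no valid split with positive gain
 */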
// Searches for the best binary split of this nominal column for a classification target.
// Attribute values with identical (truncated) class probabilities are merged, the merged values are
// ordered by BinaryNominalSplitsPCA.calculatePCAOrdering, and every prefix of that ordering is
// scored as one side of a binary partition.
//
// columnMemberships - row cursor; read via getIndexInColumn(), getRowWeight(), getOriginalIndex()
// targetPriors      - supplies getPriorImpurity() and getDistribution()
// targetColumn      - getValueFor(originalIndex) yields the class index of a row
// impCriterion      - scores partitions (getPartitionImpurity / getPostSplitImpurity / getGain)
// nomVals           - values of this column; the last entry is skipped when containsMissingValues()
// targetVals        - class values; only the length is used here
// rd                - random source used for tie-breaking between equal gains
// Returns a split candidate, or null when the PCA ordering fails or no valid split has positive gain.
//
// NOTE(review): this listing shows no closing braces and several statements/expressions are not
// visible (flagged individually below); block nesting has to be inferred from statement order.
private NominalBinarySplitCandidate calcBestSplitClassificationBinaryPCA(final ColumnMemberships columnMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final IImpurity impCriterion, final NominalValueRepresentation[] nomVals, final NominalValueRepresentation[] targetVals, final RandomData rd) {
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final int minChildSize = config.getMinChildSize();
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// The algorithm combines attribute values with the same class probabilities into a single attribute
// therefore it is necessary to track the known classProbabilities
// (a LinkedHashMap, so values() iterates in insertion order)
final LinkedHashMap<ClassProbabilityVector, CombinedAttributeValues> combinedAttValsMap = new LinkedHashMap<ClassProbabilityVector, CombinedAttributeValues>();;
double totalWeight = 0.0;
boolean branchContainsMissingValues = containsMissingValues();
// 'start'/'end' delimit the index range of the current attribute value inside the column
int start = 0;
final int lengthNonMissing = containsMissingValues() ? nomVals.length - 1 : nomVals.length;
// NOTE(review): attToConsider is computed but only mentioned in the comment inside the loop header below
final int attToConsider = useXGBoostMissingValueHandling ? nomVals.length : lengthNonMissing;
// Phase 1: per non-missing attribute value, accumulate class frequencies and weight
for (int att = 0; att < lengthNonMissing; /*attToConsider*/
att++) {
int end = start + m_nominalValueCounts[att];
double attWeight = 0.0;
final double[] classFrequencies = new double[targetVals.length];
boolean reachedEnd = false;
// NOTE(review): the update clause is empty; the call that advances columnMemberships is not
// visible in this listing (presumably part of the truncated condition a few lines below)
for (int index = columnMemberships.getIndexInColumn(); index < end; index = columnMemberships.getIndexInColumn()) {
double weight = columnMemberships.getRowWeight();
assert weight > EPSILON : "Instances in columnMemberships must have weights larger than EPSILON.";
int instanceClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
classFrequencies[instanceClass] += weight;
attWeight += weight;
totalWeight += weight;
// NOTE(review): the condition of this 'if' is truncated in the listing
if (! {
// reached end of columnMemberships
reachedEnd = true;
if (att == nomVals.length - 1) {
// if the column contains no missing values, the last possible nominal value is
// not the missing value and therefore branchContainsMissingValues needs to be false
// NOTE(review): 'x && true' leaves the flag unchanged, which does not match the comment above;
// the flag is reassigned from missingWeight further down
branchContainsMissingValues = branchContainsMissingValues && true;
start = end;
// NOTE(review): no skip statement is visible after this check; as listed, the division by
// attWeight below would also run for values that did not occur — presumably a 'continue' is missing
if (attWeight < EPSILON) {
// attribute value did not occur in this branch or sample
final double[] classProbabilities = new double[targetVals.length];
for (int i = 0; i < classProbabilities.length; i++) {
// relative class frequency, passed through truncateDouble(8, ...) so that values with
// (nearly) identical distributions produce equal map keys — presumably 8 decimal places, TODO confirm
classProbabilities[i] = truncateDouble(8, classFrequencies[i] / attWeight);
CombinedAttributeValues attVal = new CombinedAttributeValues(classFrequencies, classProbabilities, attWeight, nomVals[att]);
ClassProbabilityVector classProbabilityVector = new ClassProbabilityVector(classProbabilities);
CombinedAttributeValues knownAttVal = combinedAttValsMap.get(classProbabilityVector);
if (knownAttVal == null) {
combinedAttValsMap.put(classProbabilityVector, attVal);
// NOTE(review): the else body is not visible; presumably attVal is merged into knownAttVal here
} else {
// NOTE(review): the body of this check is not visible in the listing
if (reachedEnd) {
// account for missing values and their weight
double missingWeight = 0.0;
double[] missingClassCounts = null;
// otherwise the current indexInColumn won't be larger than start
if (columnMemberships.getIndexInColumn() >= start) {
missingClassCounts = new double[targetVals.length];
// every remaining row is counted as a missing-value row
do {
final double recordWeight = columnMemberships.getRowWeight();
final int recordClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
missingWeight += recordWeight;
missingClassCounts[recordClass] += recordWeight;
// NOTE(review): the loop condition is truncated in the listing
} while (;
// the branch contains missing values only if some missing-value weight was actually seen
if (missingWeight > EPSILON) {
branchContainsMissingValues = true;
} else {
branchContainsMissingValues = false;
// snapshot of the merged values in insertion order; used below to sum up all class frequencies
ArrayList<CombinedAttributeValues> attValList = Lists.newArrayList(combinedAttValsMap.values());
CombinedAttributeValues[] attVals = combinedAttValsMap.values().toArray(new CombinedAttributeValues[combinedAttValsMap.size()]);
// Phase 2: reorder the merged values; a null result means no ordering could be computed
attVals = BinaryNominalSplitsPCA.calculatePCAOrdering(attVals, totalWeight, targetVals.length);
// EigenDecomposition failed
if (attVals == null) {
return null;
// Start searching for split candidates
// bit index of the last non-missing nominal value
final int highestBitPosition = containsMissingValues() ? nomVals.length - 2 : nomVals.length - 1;
final double[] binaryImpurityValues = new double[2];
final double[] binaryPartitionWeights = new double[2];
double sumRemainingWeights = totalWeight;
double sumCurrPartitionWeight = 0.0;
RealVector targetFrequenciesCurrentPartition = MatrixUtils.createRealVector(new double[targetVals.length]);
RealVector targetFrequenciesRemaining = MatrixUtils.createRealVector(new double[targetVals.length]);
// initially every value is on the "remaining" side
for (CombinedAttributeValues attVal : attValList) {
targetFrequenciesRemaining = targetFrequenciesRemaining.add(attVal.m_classFrequencyVector);
BigInteger currPartitionBitMask = BigInteger.ZERO;
double bestPartitionGain = Double.NEGATIVE_INFINITY;
BigInteger bestPartitionMask = null;
boolean isBestSplitValid = false;
boolean missingsGoLeft = false;
// Without XGBoost handling the prior impurity is recomputed on the non-missing rows only.
// NOTE(review): missingClassCounts can still be null here; presumably subtractMissingClassCounts
// tolerates that — confirm
final double priorImpurity = useXGBoostMissingValueHandling ? targetPriors.getPriorImpurity() : impCriterion.getPartitionImpurity(subtractMissingClassCounts(targetPriors.getDistribution(), missingClassCounts), totalWeight);
// Phase 3: grow the current partition one value at a time along the computed ordering
// no need to iterate over full list because at least one value must remain on the other side of the split
for (int i = 0; i < attVals.length - 1; i++) {
CombinedAttributeValues currAttVal = attVals[i];
sumCurrPartitionWeight += currAttVal.m_totalWeight;
sumRemainingWeights -= currAttVal.m_totalWeight;
// NOTE(review): exact == on accumulated doubles; may trip under -ea purely from rounding
assert sumCurrPartitionWeight + sumRemainingWeights == totalWeight : "The weights of the partitions do not sum up to the total weight.";
targetFrequenciesCurrentPartition = targetFrequenciesCurrentPartition.add(currAttVal.m_classFrequencyVector);
targetFrequenciesRemaining = targetFrequenciesRemaining.subtract(currAttVal.m_classFrequencyVector);
currPartitionBitMask = currPartitionBitMask.or(currAttVal.m_bitMask);
// the side that holds the highest non-missing value is treated as the right branch
boolean partitionIsRightBranch = currPartitionBitMask.testBit(highestBitPosition);
boolean isValidSplit;
double gain;
boolean tempMissingsGoLeft = true;
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
// both placements of the missing-value rows are scored and the better one is kept
// send missing values with partition
boolean isValidSplitFirst = sumCurrPartitionWeight + missingWeight >= minChildSize && sumRemainingWeights >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(addMissingClassCounts(targetFrequenciesCurrentPartition.toArray(), missingClassCounts), sumCurrPartitionWeight + missingWeight);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetFrequenciesRemaining.toArray(), sumRemainingWeights);
binaryPartitionWeights[0] = sumCurrPartitionWeight + missingWeight;
binaryPartitionWeights[1] = sumRemainingWeights;
double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
double gainFirst = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
// send missing values with remaining
boolean isValidSplitSecond = sumCurrPartitionWeight >= minChildSize && sumRemainingWeights + missingWeight >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetFrequenciesCurrentPartition.toArray(), sumCurrPartitionWeight);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(addMissingClassCounts(targetFrequenciesRemaining.toArray(), missingClassCounts), sumRemainingWeights + missingWeight);
binaryPartitionWeights[0] = sumCurrPartitionWeight;
binaryPartitionWeights[1] = sumRemainingWeights + missingWeight;
postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
double gainSecond = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
// choose alternative with better gain
if (gainFirst >= gainSecond) {
gain = gainFirst;
isValidSplit = isValidSplitFirst;
// missing values travel with the current partition, which is left unless it is the right branch
tempMissingsGoLeft = !partitionIsRightBranch;
} else {
gain = gainSecond;
isValidSplit = isValidSplitSecond;
// missing values travel with the remaining values, i.e. opposite to the current partition
tempMissingsGoLeft = partitionIsRightBranch;
} else {
// no missing-value rows to place: score the plain two-way partition
// TODO if invalid splits should not be considered skip partition
isValidSplit = sumCurrPartitionWeight >= minChildSize && sumRemainingWeights >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetFrequenciesCurrentPartition.toArray(), sumCurrPartitionWeight);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetFrequenciesRemaining.toArray(), sumRemainingWeights);
binaryPartitionWeights[0] = sumCurrPartitionWeight;
binaryPartitionWeights[1] = sumRemainingWeights;
double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight);
gain = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight);
// use random tie breaker if gains are equal
// (presumably nextInt(0, 1) has inclusive bounds, making this a coin flip — confirm against RandomData)
boolean randomTieBreaker = gain == bestPartitionGain ? rd.nextInt(0, 1) == 1 : false;
// store if better than before or first valid split
if (gain > bestPartitionGain || (!isBestSplitValid && isValidSplit) || randomTieBreaker) {
// never replace a valid best split by an invalid one
if (isValidSplit || !isBestSplitValid) {
bestPartitionGain = gain;
// the stored mask always describes the right branch: if the current partition is the left one,
// all bits 0..highestBitPosition are flipped to obtain its complement
bestPartitionMask = partitionIsRightBranch ? currPartitionBitMask : BigInteger.ZERO.setBit(highestBitPosition + 1).subtract(BigInteger.ONE).xor(currPartitionBitMask);
isBestSplitValid = isValidSplit;
if (branchContainsMissingValues) {
missingsGoLeft = tempMissingsGoLeft;
// missing values are encountered during the search for the best split
// missingsGoLeft = partitionIsRightBranch;
} else {
// no missing values were encountered during the search for the best split
// missing values should be sent with the majority
missingsGoLeft = partitionIsRightBranch ? sumCurrPartitionWeight < sumRemainingWeights : sumCurrPartitionWeight >= sumRemainingWeights;
// only a valid split with strictly positive gain is reported
if (isBestSplitValid && bestPartitionGain > 0.0) {
if (useXGBoostMissingValueHandling) {
// with XGBoost handling the candidate carries a direction for missing values and NO_MISSED_ROWS
return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, NO_MISSED_ROWS, missingsGoLeft ? NominalBinarySplitCandidate.MISSINGS_GO_LEFT : NominalBinarySplitCandidate.MISSINGS_GO_RIGHT);
// otherwise the candidate carries getMissedRows(columnMemberships) and no missing-value direction
return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, getMissedRows(columnMemberships), NominalBinarySplitCandidate.NO_MISSINGS);
return null;
use of org.knime.base.node.mine.treeensemble2.learner.IImpurity in project knime-core by knime.
the class TreeNominalColumnData method calcBestSplitClassificationBinary.
/**
 * Searches for the best binary split of this nominal column by enumerating partitions of the
 * attribute values through a BinarySplitEnumeration and scoring each partition with the
 * given impurity criterion.
 *
 * NOTE(review): this listing shows no closing braces and some expressions are truncated
 * (flagged below); block nesting has to be inferred from statement order.
 *
 * @param columnMemberships row cursor read via getIndexInColumn(), getRowWeight(), getOriginalIndex()
 * @param targetPriors supplies the prior impurity used for the gain
 * @param targetColumn getValueFor(originalIndex) yields the class index of a row
 * @param impCriterion impurity criterion used to score partitions
 * @param nomVals values of this column; the last entry is skipped when containsMissingValues()
 * @param targetVals class values; only the length is used here
 * @param rd random source for the random enumeration and for tie-breaking
 * @return a split candidate with NO_MISSINGS, or null if nomVals has at most one entry or the
 *         best gain is not positive
 */
NominalBinarySplitCandidate calcBestSplitClassificationBinary(final ColumnMemberships columnMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final IImpurity impCriterion, final NominalValueRepresentation[] nomVals, final NominalValueRepresentation[] targetVals, final RandomData rd) {
// a single value cannot be partitioned
if (nomVals.length <= 1) {
return null;
final int minChildSize = getConfiguration().getMinChildSize();
final int lengthNonMissing = containsMissingValues() ? nomVals.length - 1 : nomVals.length;
// distribution of target for each attribute value
final double[][] targetCountsSplitPerAttribute = new double[lengthNonMissing][targetVals.length];
// number of valid records for each attribute value
final double[] attWeights = new double[lengthNonMissing];
// number (sum) of total valid values
double totalWeight = 0.0;
// 'start'/'end' delimit the index range of the current attribute value inside the column
int start = 0;;
for (int att = 0; att < lengthNonMissing; att++) {
final int end = start + m_nominalValueCounts[att];
double currentAttValWeight = 0.0;
// NOTE(review): the loop header is malformed in this listing ("; ," with nothing in between);
// presumably the call that advances columnMemberships was lost
for (int index = columnMemberships.getIndexInColumn(); index < end;, index = columnMemberships.getIndexInColumn()) {
final double weight = columnMemberships.getRowWeight();
assert weight > EPSILON : "The usage of datamemberships should ensure that no rows with zero weight are encountered";
int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
targetCountsSplitPerAttribute[att][target] += weight;
currentAttValWeight += weight;
// NOTE(review): without braces it is not visible whether the next three statements run per row
// or once per attribute value; presumably once per attribute value — confirm
totalWeight += currentAttValWeight;
attWeights[att] = currentAttValWeight;
start = end;
BinarySplitEnumeration splitEnumeration;
// few values: enumerate all partitions; many values: sample a bounded number of them
if (nomVals.length <= 10) {
splitEnumeration = new FullBinarySplitEnumeration(nomVals.length);
} else {
// Java parses "1 << 10 - 2" as "1 << (10 - 2)", i.e. 256, because '-' binds tighter than '<<'
// NOTE(review): confirm that 256 is the intended search budget
int maxSearch = (1 << 10 - 2);
splitEnumeration = new RandomBinarySplitEnumeration(nomVals.length, maxSearch, rd);
BigInteger bestPartitionMask = null;
boolean isBestSplitValid = false;
double bestPartitionGain = Double.NEGATIVE_INFINITY;
final double[] targetCountsSplitLeft = new double[targetVals.length];
final double[] targetCountsSplitRight = new double[targetVals.length];
final double[] binaryImpurityValues = new double[2];
final double[] binaryPartitionWeights = new double[2];
// one iteration per partition proposed by splitEnumeration
do {
Arrays.fill(targetCountsSplitLeft, 0.0);
Arrays.fill(targetCountsSplitRight, 0.0);
double weightLeft = 0.0;
double weightRight = 0.0;
// NOTE(review): this loop runs to nomVals.length, but targetCountsSplitPerAttribute and attWeights
// have length lengthNonMissing; when containsMissingValues() is true the last index is out of range
for (int i = 0; i < nomVals.length; i++) {
final boolean isAttributeInRightBranch = splitEnumeration.isInRightBranch(i);
double[] targetCountsCurrentAttribute = targetCountsSplitPerAttribute[i];
for (int targetVal = 0; targetVal < targetVals.length; targetVal++) {
if (isAttributeInRightBranch) {
targetCountsSplitRight[targetVal] += targetCountsCurrentAttribute[targetVal];
} else {
targetCountsSplitLeft[targetVal] += targetCountsCurrentAttribute[targetVal];
if (isAttributeInRightBranch) {
weightRight += attWeights[i];
} else {
weightLeft += attWeights[i];
// slot 0 holds the right branch, slot 1 the left branch (weights and impurities alike)
binaryPartitionWeights[0] = weightRight;
binaryPartitionWeights[1] = weightLeft;
boolean isValidSplit = weightRight >= minChildSize && weightLeft >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRight, weightRight);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitLeft, weightLeft);
double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight);
double gain = impCriterion.getGain(targetPriors.getPriorImpurity(), postSplitImpurity, binaryPartitionWeights, totalWeight);
// use random tie breaker if gains are equal
// (presumably nextInt(0, 1) has inclusive bounds, making this a coin flip — confirm against RandomData)
boolean randomTieBreaker = gain == bestPartitionGain ? rd.nextInt(0, 1) == 1 : false;
// store if better than before or first valid split
if (gain > bestPartitionGain || (!isBestSplitValid && isValidSplit) || randomTieBreaker) {
// never replace a valid best split by an invalid one
if (isValidSplit || !isBestSplitValid) {
bestPartitionGain = gain;
bestPartitionMask = splitEnumeration.getValueMask();
isBestSplitValid = isValidSplit;
// NOTE(review): the loop condition is truncated in the listing
} while (;
// NOTE(review): only the gain is checked here; isBestSplitValid is not consulted before returning
if (bestPartitionGain > 0.0) {
return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, getMissedRows(columnMemberships), NominalBinarySplitCandidate.NO_MISSINGS);
return null;
use of org.knime.base.node.mine.treeensemble2.learner.IImpurity in project knime-core by knime.
the class TreeNominalColumnData method calcBestSplitClassification.
* {@inheritDoc}
public SplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
IImpurity impCriterion = targetPriors.getImpurityCriterion();
// distribution of target for each attribute value
final NominalValueRepresentation[] nomVals = getMetaData().getValues();
final boolean useBinaryNominalSplits = getConfiguration().isUseBinaryNominalSplits();
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
if (useBinaryNominalSplits) {
if (targetVals.length == 2) {
return calcBestSplitClassificationBinaryTwoClass(columnMemberships, targetPriors, targetColumn, impCriterion, nomVals, targetVals, rd);
} else {
return calcBestSplitClassificationBinaryPCA(columnMemberships, targetPriors, targetColumn, impCriterion, nomVals, targetVals, rd);
// return calcBestSplitClassificationBinary(membershipController, rowWeights, targetPriors, targetColumn,
// impCriterion, nomVals, targetVals, originalIndexInColumnList, rd);
} else {
return calcBestSplitClassificationMultiway(columnMemberships, targetPriors, targetColumn, impCriterion, nomVals, targetVals, rd);
use of org.knime.base.node.mine.treeensemble2.learner.IImpurity in project knime-core by knime.
the class TreeBitVectorColumnData method calcBestSplitClassification.
* {@inheritDoc}
/**
 * Scores the single possible split of a bit column: rows whose bit is set versus rows whose
 * bit is not set.
 *
 * NOTE(review): this listing shows no closing braces and the loop condition is truncated;
 * block nesting has to be inferred from statement order.
 *
 * @param dataMemberships provides the column memberships for this column's attribute index
 * @param targetPriors provides the impurity criterion and the prior impurity
 * @param targetColumn getValueFor(originalIndex) yields the class index of a row
 * @param rd not used by the visible code
 * @return a BitSplitCandidate carrying the gain, or null if either side weighs less than the
 *         configured minimum child size
 */
public SplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
final int minChildSize = getConfiguration().getMinChildSize();
// distribution of target for On ('1') and Off ('0') bits
final double[] onTargetWeights = new double[targetVals.length];
final double[] offTargetWeights = new double[targetVals.length];
double onWeights = 0.0;
double offWeights = 0.0;
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
// NOTE(review): the loop condition is truncated in the listing; presumably it advances columnMemberships
while ( {
final double weight = columnMemberships.getRowWeight();
if (weight < EPSILON) {
// ignore record: not in current branch or not in sample
assert false : "This code should never be reached!";
} else {
final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
// the row's bit decides which side its weight is added to
if (m_columnBitSet.get(columnMemberships.getIndexInColumn())) {
onWeights += weight;
onTargetWeights[target] += weight;
} else {
offWeights += weight;
offTargetWeights[target] += weight;
// both sides must reach the minimum child size, otherwise there is no split
if (onWeights < minChildSize || offWeights < minChildSize) {
return null;
final double weightSum = onWeights + offWeights;
final double onImpurity = impurityCriterion.getPartitionImpurity(onTargetWeights, onWeights);
final double offImpurity = impurityCriterion.getPartitionImpurity(offTargetWeights, offWeights);
// slot 0 is the "on" side, slot 1 the "off" side (weights and impurities alike)
final double[] partitionWeights = new double[] { onWeights, offWeights };
final double postSplitImpurity = impurityCriterion.getPostSplitImpurity(new double[] { onImpurity, offImpurity }, partitionWeights, weightSum);
final double gainValue = impurityCriterion.getGain(targetPriors.getPriorImpurity(), postSplitImpurity, partitionWeights, weightSum);
// the candidate is returned whatever the sign of the gain
return new BitSplitCandidate(this, gainValue);
use of org.knime.base.node.mine.treeensemble2.learner.IImpurity in project knime-core by knime.
the class TreeNumericColumnData method calcBestSplitClassification.
/**
 * Searches for the best threshold split of this numeric column for a classification target.
 * Rows are visited in the order given by the column memberships, class weights are moved from
 * the right side to the left side one row at a time, and a split point is scored whenever the
 * value changes and both sides reach the minimum child size.
 *
 * NOTE(review): this listing shows no closing braces and some statements/expressions are not
 * visible (flagged below); block nesting has to be inferred from statement order.
 *
 * @param dataMemberships provides the column memberships for this column's attribute index
 * @param targetPriors provides the class distribution, record count, impurity criterion and
 *            prior impurity
 * @param targetColumn getValueFor(originalIndex) yields the class index of a row
 * @param rd random source used for tie-breaking between equal gains
 * @return a split candidate, or null if all rows are missing or no split point was accepted
 */
public NumericSplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
final int minChildNodeSize = config.getMinChildSize();
// distribution of target for each attribute value
final int targetCounts = targetVals.length;
final double[] targetCountsLeftOfSplit = new double[targetCounts];
// everything starts on the right side; clone() keeps the priors' own array untouched
final double[] targetCountsRightOfSplit = targetPriors.getDistribution().clone();
assert targetCountsRightOfSplit.length == targetCounts;
final double totalSumWeight = targetPriors.getNrRecords();
final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// get columnMemberships
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
// missing value handling
boolean branchContainsMissingValues = containsMissingValues();
boolean missingsGoLeft = true;
final int lengthNonMissing = getLengthNonMissing();
final double[] missingTargetCounts = new double[targetCounts];
int lastValidSplitPosition = -1;
double missingWeight = 0;
// Backward pass (previous()): rows at index >= lengthNonMissing are missing-value rows and are
// moved out of the right-side counts; for the non-missing rows lastValidSplitPosition is tracked.
// NOTE(review): presumably the cursor starts at the last row — the positioning call is not visible
do {
final int indexInColumn = columnMemberships.getIndexInColumn();
if (indexInColumn >= lengthNonMissing) {
final double weight = columnMemberships.getRowWeight();
final int classIdx = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
targetCountsRightOfSplit[classIdx] -= weight;
missingTargetCounts[classIdx] += weight;
missingWeight += weight;
} else {
if (lastValidSplitPosition < 0) {
lastValidSplitPosition = indexInColumn;
// NOTE(review): the body of this branch is not visible; presumably the scan stops once a
// smaller value (by at least EPSILON) than the one at lastValidSplitPosition is reached
} else if ((getSorted(lastValidSplitPosition) - getSorted(indexInColumn)) >= EPSILON) {
} else {
lastValidSplitPosition = indexInColumn;
} while (columnMemberships.previous());
// it is possible that the column contains missing values but in the current branch there are no missing values
branchContainsMissingValues = missingWeight > 0.0;
double sumWeightsLeftOfSplit = 0.0;
double sumWeightsRightOfSplit = totalSumWeight - missingWeight;
// without XGBoost handling and with missing rows present, the prior impurity is recomputed on the non-missing rows
final double priorImpurity = useXGBoostMissingValueHandling || !branchContainsMissingValues ? targetPriors.getPriorImpurity() : impurityCriterion.getPartitionImpurity(TreeNominalColumnData.subtractMissingClassCounts(targetPriors.getDistribution(), missingTargetCounts), sumWeightsRightOfSplit);
// all values in branch are missing
if (sumWeightsRightOfSplit == 0) {
// it is impossible to determine a split
return null;
double bestSplit = Double.NEGATIVE_INFINITY;
// gain for best split point, unnormalized (not using info gain ratio)
double bestGain = Double.NEGATIVE_INFINITY;
// gain for best split, normalized by attribute entropy when
// info gain ratio is used.
double bestGainValueForSplit = Double.NEGATIVE_INFINITY;
// tempArray1 holds the two partition impurities, tempArray2 the two partition weights
final double[] tempArray1 = new double[2];
double[] tempArray2 = new double[2];
double lastSeenValue = Double.NEGATIVE_INFINITY;
boolean mustTestOnNextValueChange = false;
boolean testSplitOnStart = true;
boolean firstIteration = true;
int lastSeenTarget = -1;
int indexInCol = -1;
// We iterate over the instances in the sample/branch instead of the whole data set
// NOTE(review): the left operand of '&&' is truncated in the listing; presumably it advances columnMemberships
while ( && (indexInCol = columnMemberships.getIndexInColumn()) < lengthNonMissing) {
final double weight = columnMemberships.getRowWeight();
assert weight >= EPSILON : "Rows with zero row weight should never be seen!";
final double value = getSorted(indexInCol);
final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
final boolean hasValueChanged = (value - lastSeenValue) >= EPSILON;
// reaching lastValidSplitPosition is treated like a class change so that a test gets requested there
final boolean hasTargetChanged = lastSeenTarget != target || indexInCol == lastValidSplitPosition;
if (hasTargetChanged && !firstIteration) {
mustTestOnNextValueChange = true;
testSplitOnStart = false;
// a split between lastSeenValue and value is scored only if the value changed, a test is pending
// and both sides reach the minimum child size
if (!firstIteration && hasValueChanged && (mustTestOnNextValueChange || testSplitOnStart) && sumWeightsLeftOfSplit >= minChildNodeSize && sumWeightsRightOfSplit >= minChildNodeSize) {
double postSplitImpurity;
boolean tempMissingsGoLeft = false;
// missing value handling
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
// both placements of the missing-value rows are scored and the one with lower impurity is kept
final double[] targetCountsLeftPlusMissing = new double[targetCounts];
final double[] targetCountsRightPlusMissing = new double[targetCounts];
for (int i = 0; i < targetCounts; i++) {
targetCountsLeftPlusMissing[i] = targetCountsLeftOfSplit[i] + missingTargetCounts[i];
targetCountsRightPlusMissing[i] = targetCountsRightOfSplit[i] + missingTargetCounts[i];
// temp[0] / temp[1]: partition weights for "missing goes left" / "missing goes right"
final double[][] temp = new double[2][2];
final double[] postSplitImpurities = new double[2];
// send all missing values left
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftPlusMissing, sumWeightsLeftOfSplit + missingWeight);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
temp[0][0] = sumWeightsLeftOfSplit + missingWeight;
temp[0][1] = sumWeightsRightOfSplit;
postSplitImpurities[0] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[0], totalSumWeight);
// send all missing values right
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightPlusMissing, sumWeightsRightOfSplit + missingWeight);
temp[1][0] = sumWeightsLeftOfSplit;
temp[1][1] = sumWeightsRightOfSplit + missingWeight;
postSplitImpurities[1] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[1], totalSumWeight);
// take better split
// (on equal impurities the "missing goes right" alternative wins)
if (postSplitImpurities[0] < postSplitImpurities[1]) {
postSplitImpurity = postSplitImpurities[0];
tempArray2 = temp[0];
tempMissingsGoLeft = true;
// TODO random tie breaker
} else {
postSplitImpurity = postSplitImpurities[1];
tempArray2 = temp[1];
tempMissingsGoLeft = false;
} else {
// no missing-value placement to decide: score the plain left/right partition
// NOTE(review): totalSumWeight still includes missingWeight here, while the two partition
// weights do not — confirm this is intended when missing values are present
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
tempArray2[0] = sumWeightsLeftOfSplit;
tempArray2[1] = sumWeightsRightOfSplit;
postSplitImpurity = impurityCriterion.getPostSplitImpurity(tempArray1, tempArray2, totalSumWeight);
// only splits that reduce the impurity are considered
if (postSplitImpurity < priorImpurity) {
// Use absolute gain (IG) for split calculation even
// if the split criterion is information gain ratio (IGR).
// IGR wouldn't work as it favors extreme unfair splits,
// i.e. 1:9999 would have an attribute entropy
// (IGR denominator) of
// 9999/10000*log(9999/10000) + 1/10000*log(1/10000)
// which is ~0.00148
double gain = (priorImpurity - postSplitImpurity);
// (presumably nextInt(0, 1) has inclusive bounds, making this a coin flip — confirm against RandomData)
boolean randomTieBreaker = gain == bestGain ? rd.nextInt(0, 1) == 1 : false;
if (gain > bestGain || randomTieBreaker) {
// the reported gain comes from the criterion's getGain, the comparison uses the absolute gain
bestGainValueForSplit = impurityCriterion.getGain(priorImpurity, postSplitImpurity, tempArray2, totalSumWeight);
bestGain = gain;
// threshold: getCenter(previous value, current value) or the previous value itself
bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
// Go with the majority if there are no missing values during training this is because we should
// still provide a missing direction for the case that there are missing values during prediction
missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : sumWeightsLeftOfSplit > sumWeightsRightOfSplit;
mustTestOnNextValueChange = false;
// move the current row from the right side to the left side
targetCountsLeftOfSplit[target] += weight;
sumWeightsLeftOfSplit += weight;
targetCountsRightOfSplit[target] -= weight;
sumWeightsRightOfSplit -= weight;
lastSeenTarget = target;
lastSeenValue = value;
firstIteration = false;
// also covers the case that no split was accepted (value is still NEGATIVE_INFINITY)
if (bestGainValueForSplit < 0.0) {
// (see info gain ratio implementation)
return null;
if (useXGBoostMissingValueHandling) {
// with XGBoost handling the candidate carries a direction for missing values and an empty BitSet
// return new NumericMissingSplitCandidate(this, bestSplit, bestGainValueForSplit, missingsGoLeft);
return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
// otherwise the candidate carries getMissedRows(columnMemberships) and no missing-value direction
return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);