Use of `org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation` in the knime-core project by KNIME.
Class `NominalAttributeColumnHelper`, method `createMetaData`:
/**
 * {@inheritDoc}
 * <p>
 * The column's domain must list its possible values. Each value is converted into a
 * {@link NominalValueRepresentation}, and the resulting array is stored together with the column
 * name in the returned meta data. If the domain lists no values, the check in
 * {@code CheckUtils.checkArgument} fails with a message naming the column spec.
 */
@Override
protected TreeNominalColumnMetaData createMetaData(final DataColumnSpec nominalColSpec) {
    final DataColumnDomain colDomain = nominalColSpec.getDomain();
    final boolean domainHasValues = colDomain.hasValues();
    CheckUtils.checkArgument(domainHasValues,
        "The data dictionary doesn't contain domain information for column \"%s\".", nominalColSpec);
    final NominalValueRepresentation[] valueReps =
        NominalColumnHelperUtil.extractNomValReps(colDomain.getValues());
    final String columnName = nominalColSpec.getName();
    return new TreeNominalColumnMetaData(columnName, valueReps);
}
Use of `org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation` in the knime-core project by KNIME.
Class `NominalTargetColumnHelper`, method `createMetaData`:
/**
 * {@inheritDoc}
 * <p>
 * Builds the target meta data from the possible values listed in the target column's domain.
 * Each value becomes a {@link NominalValueRepresentation}. If the domain lists no values, the
 * check in {@code CheckUtils.checkArgument} fails with a message naming the target column spec.
 */
@Override
protected TreeTargetNominalColumnMetaData createMetaData(final DataColumnSpec nominalColSpec) {
    final DataColumnDomain targetDomain = nominalColSpec.getDomain();
    final boolean domainHasValues = targetDomain.hasValues();
    CheckUtils.checkArgument(domainHasValues,
        "The target field \"%s\" in the data dictionary has no possible values assigned.", nominalColSpec);
    final NominalValueRepresentation[] classReps =
        NominalColumnHelperUtil.extractNomValReps(targetDomain.getValues());
    final String targetName = nominalColSpec.getName();
    return new TreeTargetNominalColumnMetaData(targetName, classReps);
}
Use of `org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation` in the knime-core project by KNIME.
Class `TreeNumericColumnData`, method `calcBestSplitClassification`:
/**
 * Searches the best threshold split of this numeric column for a nominal (classification) target.
 * The search is restricted to the rows of the current branch described by {@code dataMemberships}.
 * <p>
 * The column is read through {@code getSorted(index)} in ascending value order. Rows at column
 * index {@code >= getLengthNonMissing()} are handled as missing values.
 * <p>
 * A candidate threshold between two consecutive distinct values is evaluated in either of these
 * cases:
 * <ul>
 * <li>the target class changed since the last evaluated candidate;</li>
 * <li>all rows seen so far have a single target class;</li>
 * <li>it is the boundary just below the largest value group.</li>
 * </ul>
 * In every case both children must reach the configured minimum child size.
 * <p>
 * With XGBoost missing value handling, the missing rows are tried on both sides of each candidate
 * and the better side is remembered. Without it, the missing rows are excluded from the impurity
 * computation and reported via {@code getMissedRows}.
 *
 * @param dataMemberships memberships of the rows in the current branch
 * @param targetPriors class distribution, record count and impurity criterion of the branch
 * @param targetColumn the nominal target column
 * @param rd source of randomness, used to break ties between equally good splits
 * @return the best split candidate, or {@code null} if all branch rows are missing or no candidate
 *         improved on the prior impurity
 */
@Override
public NumericSplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
    final TreeEnsembleLearnerConfiguration config = getConfiguration();
    final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
    final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
    final int minChildNodeSize = config.getMinChildSize();
    // distribution of target for each attribute value
    final int targetCounts = targetVals.length;
    // Class weights on the left of the candidate split start empty. The right side starts with the
    // branch's full prior distribution; missing rows are taken out of it in the pre-pass below.
    final double[] targetCountsLeftOfSplit = new double[targetCounts];
    final double[] targetCountsRightOfSplit = targetPriors.getDistribution().clone();
    assert targetCountsRightOfSplit.length == targetCounts;
    final double totalSumWeight = targetPriors.getNrRecords();
    final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
    final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
    // get columnMemberships
    final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    // missing value handling
    boolean branchContainsMissingValues = containsMissingValues();
    boolean missingsGoLeft = true;
    final int lengthNonMissing = getLengthNonMissing();
    final double[] missingTargetCounts = new double[targetCounts];
    int lastValidSplitPosition = -1;
    double missingWeight = 0;
    // Backward pre-pass, starting at the last branch row:
    // - missing rows (index >= lengthNonMissing) have their class weight moved from the right-hand
    //   counts into missingTargetCounts;
    // - for non-missing rows, lastValidSplitPosition tracks the smallest index of the trailing run
    //   of (within EPSILON) equal, largest values. The pass stops as soon as a smaller value is met.
    columnMemberships.goToLast();
    do {
        final int indexInColumn = columnMemberships.getIndexInColumn();
        if (indexInColumn >= lengthNonMissing) {
            final double weight = columnMemberships.getRowWeight();
            final int classIdx = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            targetCountsRightOfSplit[classIdx] -= weight;
            missingTargetCounts[classIdx] += weight;
            missingWeight += weight;
        } else {
            if (lastValidSplitPosition < 0) {
                lastValidSplitPosition = indexInColumn;
            } else if ((getSorted(lastValidSplitPosition) - getSorted(indexInColumn)) >= EPSILON) {
                break;
            } else {
                lastValidSplitPosition = indexInColumn;
            }
        }
    } while (columnMemberships.previous());
    // it is possible that the column contains missing values but in the current branch there are no missing values
    branchContainsMissingValues = missingWeight > 0.0;
    columnMemberships.reset();
    double sumWeightsLeftOfSplit = 0.0;
    double sumWeightsRightOfSplit = totalSumWeight - missingWeight;
    // Prior impurity: the branch prior under XGBoost handling or when the branch has no missing
    // rows. Otherwise it is the impurity of the branch distribution with the missing rows removed.
    final double priorImpurity = useXGBoostMissingValueHandling || !branchContainsMissingValues ? targetPriors.getPriorImpurity() : impurityCriterion.getPartitionImpurity(TreeNominalColumnData.subtractMissingClassCounts(targetPriors.getDistribution(), missingTargetCounts), sumWeightsRightOfSplit);
    // all values in branch are missing
    // NOTE(review): exact floating point comparison; getNrRecords() minus the summed missing row
    // weights may leave a tiny residue instead of 0 - confirm this cannot happen in practice.
    if (sumWeightsRightOfSplit == 0) {
        // it is impossible to determine a split
        return null;
    }
    double bestSplit = Double.NEGATIVE_INFINITY;
    // gain for best split point, unnormalized (not using info gain ratio)
    double bestGain = Double.NEGATIVE_INFINITY;
    // gain for best split, normalized by attribute entropy when
    // info gain ratio is used.
    double bestGainValueForSplit = Double.NEGATIVE_INFINITY;
    // tempArray1: impurities of the two children; tempArray2: weights of the two children
    final double[] tempArray1 = new double[2];
    double[] tempArray2 = new double[2];
    double lastSeenValue = Double.NEGATIVE_INFINITY;
    // set when the target class changed; the candidate is then tested at the next value change
    boolean mustTestOnNextValueChange = false;
    // stays true while all rows seen so far share one target class; every value change is tested then
    boolean testSplitOnStart = true;
    boolean firstIteration = true;
    int lastSeenTarget = -1;
    int indexInCol = -1;
    // We iterate over the instances in the sample/branch instead of the whole data set
    // (the loop stops at the first missing-value row)
    while (columnMemberships.next() && (indexInCol = columnMemberships.getIndexInColumn()) < lengthNonMissing) {
        final double weight = columnMemberships.getRowWeight();
        assert weight >= EPSILON : "Rows with zero row weight should never be seen!";
        final double value = getSorted(indexInCol);
        final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
        final boolean hasValueChanged = (value - lastSeenValue) >= EPSILON;
        // Reaching lastValidSplitPosition counts as a target change, so the boundary just below
        // the largest value group is always tested.
        final boolean hasTargetChanged = lastSeenTarget != target || indexInCol == lastValidSplitPosition;
        if (hasTargetChanged && !firstIteration) {
            mustTestOnNextValueChange = true;
            testSplitOnStart = false;
        }
        // Candidate threshold lies between lastSeenValue (left side) and value (right side).
        if (!firstIteration && hasValueChanged && (mustTestOnNextValueChange || testSplitOnStart) && sumWeightsLeftOfSplit >= minChildNodeSize && sumWeightsRightOfSplit >= minChildNodeSize) {
            double postSplitImpurity;
            boolean tempMissingsGoLeft = false;
            // missing value handling
            if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
                final double[] targetCountsLeftPlusMissing = new double[targetCounts];
                final double[] targetCountsRightPlusMissing = new double[targetCounts];
                for (int i = 0; i < targetCounts; i++) {
                    targetCountsLeftPlusMissing[i] = targetCountsLeftOfSplit[i] + missingTargetCounts[i];
                    targetCountsRightPlusMissing[i] = targetCountsRightOfSplit[i] + missingTargetCounts[i];
                }
                // temp[0]/temp[1]: child weights for "missings left"/"missings right"
                final double[][] temp = new double[2][2];
                final double[] postSplitImpurities = new double[2];
                // send all missing values left
                tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftPlusMissing, sumWeightsLeftOfSplit + missingWeight);
                tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
                temp[0][0] = sumWeightsLeftOfSplit + missingWeight;
                temp[0][1] = sumWeightsRightOfSplit;
                postSplitImpurities[0] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[0], totalSumWeight);
                // send all missing values right
                tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
                tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightPlusMissing, sumWeightsRightOfSplit + missingWeight);
                temp[1][0] = sumWeightsLeftOfSplit;
                temp[1][1] = sumWeightsRightOfSplit + missingWeight;
                postSplitImpurities[1] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[1], totalSumWeight);
                // take better split
                if (postSplitImpurities[0] < postSplitImpurities[1]) {
                    postSplitImpurity = postSplitImpurities[0];
                    tempArray2 = temp[0];
                    tempMissingsGoLeft = true;
                    // TODO random tie breaker
                } else {
                    postSplitImpurity = postSplitImpurities[1];
                    tempArray2 = temp[1];
                    tempMissingsGoLeft = false;
                }
            } else {
                tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
                tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
                tempArray2[0] = sumWeightsLeftOfSplit;
                tempArray2[1] = sumWeightsRightOfSplit;
                postSplitImpurity = impurityCriterion.getPostSplitImpurity(tempArray1, tempArray2, totalSumWeight);
            }
            if (postSplitImpurity < priorImpurity) {
                // Use absolute gain (IG) for split calculation even
                // if the split criterion is information gain ratio (IGR).
                // IGR wouldn't work as it favors extreme unfair splits,
                // i.e. 1:9999 would have an attribute entropy
                // (IGR denominator) of
                // 9999/10000*log(9999/10000) + 1/10000*log(1/10000)
                // which is ~0.00148
                double gain = (priorImpurity - postSplitImpurity);
                // on an exact tie, replace the current best with a coin flip
                // (assumes RandomData.nextInt has inclusive bounds, i.e. yields 0 or 1 - confirm)
                boolean randomTieBreaker = gain == bestGain ? rd.nextInt(0, 1) == 1 : false;
                if (gain > bestGain || randomTieBreaker) {
                    bestGainValueForSplit = impurityCriterion.getGain(priorImpurity, postSplitImpurity, tempArray2, totalSumWeight);
                    bestGain = gain;
                    bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
                    // Go with the majority if there are no missing values during training this is because we should
                    // still provide a missing direction for the case that there are missing values during prediction
                    missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : sumWeightsLeftOfSplit > sumWeightsRightOfSplit;
                }
            }
            mustTestOnNextValueChange = false;
        }
        // move the current row from the right side to the left side of the next candidate
        targetCountsLeftOfSplit[target] += weight;
        sumWeightsLeftOfSplit += weight;
        targetCountsRightOfSplit[target] -= weight;
        sumWeightsRightOfSplit -= weight;
        lastSeenTarget = target;
        lastSeenValue = value;
        firstIteration = false;
    }
    columnMemberships.reset();
    // Still NEGATIVE_INFINITY if no candidate improved on priorImpurity; negative values are also rejected.
    if (bestGainValueForSplit < 0.0) {
        // (see info gain ratio implementation)
        return null;
    }
    if (useXGBoostMissingValueHandling) {
        // return new NumericMissingSplitCandidate(this, bestSplit, bestGainValueForSplit, missingsGoLeft);
        // missing rows follow the remembered direction; no rows are reported as missed
        return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
    }
    // without XGBoost handling, the branch's missing rows are reported via getMissedRows
    return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
}
Use of `org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation` in the knime-core project by KNIME.
Class `RandomForestClassificationTreeNodeWidget`, method `createTablePanel`:
/**
 * Builds the class distribution table shown inside a classification tree node.
 * <p>
 * The table contains:
 * <ul>
 * <li>a header row ("Category", "%", "n");</li>
 * <li>one row per target class with its relative frequency and its count;</li>
 * <li>a grey background behind the majority class row;</li>
 * <li>a "Total" row with the node's total count and its coverage, i.e. the node total divided by
 * the root node total (or by the node's own total when no root node is available).</li>
 * </ul>
 *
 * @param scale scaling factor passed to the labels and used for the cell insets
 * @return the assembled table panel
 */
private JPanel createTablePanel(final float scale) {
    final TreeNodeClassification node = (TreeNodeClassification) getUserObject();
    final float[] targetDistribution = node.getTargetDistribution();
    final double totalClassCount = sumClassCounts(targetDistribution);
    final JPanel p = new JPanel(new GridBagLayout());
    final GridBagConstraints c = new GridBagConstraints();
    final int gridwidth = 3;
    c.fill = GridBagConstraints.HORIZONTAL;
    c.anchor = GridBagConstraints.NORTHWEST;
    final int bw = Math.round(1 * scale);
    c.insets = new Insets(bw, bw, bw, bw);
    c.gridx = 0;
    c.gridy = 0;
    c.weightx = 1;
    c.weighty = 1;
    c.gridwidth = 1;
    // header row
    p.add(scaledLabel("Category", scale), c);
    c.gridx++;
    p.add(scaledLabel("% ", scale, SwingConstants.RIGHT), c);
    c.gridx++;
    p.add(scaledLabel("n ", scale, SwingConstants.RIGHT), c);
    c.gridy++;
    c.gridx = 0;
    c.gridwidth = GridBagConstraints.REMAINDER;
    p.add(new MyJSeparator(), c);
    c.gridwidth = 1;
    // one row per target class
    final int majorityClassIndex = node.getMajorityClassIndex();
    final NominalValueRepresentation[] classNomVals = node.getTargetMetaData().getValues();
    for (int i = 0; i < targetDistribution.length; i++) {
        final JLabel classLabel = scaledLabel(classNomVals[i].getNominalValue(), scale);
        c.gridy++;
        c.gridx = 0;
        p.add(classLabel, c);
        c.gridx++;
        final double classFreq = targetDistribution[i] / totalClassCount;
        p.add(scaledLabel(convertPercentage(classFreq), scale, SwingConstants.RIGHT), c);
        c.gridx++;
        p.add(scaledLabel(convertCount((double) targetDistribution[i]), scale, SwingConstants.RIGHT), c);
        if (i == majorityClassIndex) {
            // grey panel spanning the whole row, added after the labels of that row
            c.gridx = 0;
            final JComponent comp = new JPanel();
            comp.setMinimumSize(classLabel.getPreferredSize());
            comp.setPreferredSize(classLabel.getPreferredSize());
            comp.setBackground(new Color(225, 225, 225));
            c.gridwidth = gridwidth;
            p.add(comp, c);
            c.gridwidth = 1;
        }
    }
    c.gridy++;
    c.gridx = 0;
    c.gridwidth = gridwidth;
    p.add(new MyJSeparator(), c);
    c.gridwidth = 1;
    // totals row
    c.gridy++;
    c.gridx = 0;
    p.add(scaledLabel("Total", scale), c);
    c.gridx++;
    final TreeNodeClassification root = (TreeNodeClassification) getGraphView().getRootNode();
    // coverage is relative to the root node's total; fall back to this node's own total
    final double denominator = root != null ? sumClassCounts(root.getTargetDistribution()) : totalClassCount;
    final double coverage = totalClassCount / denominator;
    p.add(scaledLabel(convertPercentage(coverage), scale, SwingConstants.RIGHT), c);
    c.gridx++;
    p.add(scaledLabel(convertCount(totalClassCount), scale, SwingConstants.RIGHT), c);
    return p;
}

/**
 * Sums the entries of a class count distribution, accumulating in double precision.
 *
 * @param distribution the per-class counts
 * @return the sum of all entries (0.0 for an empty array)
 */
private static double sumClassCounts(final float[] distribution) {
    double total = 0.0;
    for (final float count : distribution) {
        total += count;
    }
    return total;
}
Use of `org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation` in the knime-core project by KNIME.
Class `TreeNominalColumnData`, method `calcBestSplitClassificationBinaryTwoClass`:
/**
 * Searches the best binary partition of this nominal column's values for a two-class target.
 * The search is restricted to the rows of the current branch.
 * <p>
 * For every non-missing nominal value that occurs in the branch, the total weight and the
 * first-class weight are collected in a {@code NomValProbabilityPair}. The pairs are sorted by
 * their natural ordering (presumably by first-class probability - confirm in
 * {@code NomValProbabilityPair}). Each proper prefix of the sorted list is then evaluated as one
 * side of the split.
 * <p>
 * Rows located after the last non-missing value are accumulated as missing values. With XGBoost
 * missing value handling, those rows are tried on both sides of every candidate and the better side
 * is kept. If the branch has no missing rows, the missing direction follows the majority.
 *
 * @param columnMemberships memberships of the branch rows for this column; must be reset and
 *            non-empty; they are advanced by this method
 * @param targetPriors class priors of the branch
 * @param targetColumn the target column; must have exactly two values
 * @param impCriterion impurity criterion used to score candidates
 * @param nomVals nominal values of this column; when the column contains missing values, the last
 *            entry is not treated as a regular value
 * @param targetVals not used by this method
 * @param rd source of randomness for breaking ties between equally good partitions
 * @return the best valid split with positive gain, or {@code null} if there is none
 * @throws IllegalArgumentException if the target does not have exactly two values
 * @throws IllegalStateException if {@code columnMemberships} is empty or was not reset
 */
private NominalBinarySplitCandidate calcBestSplitClassificationBinaryTwoClass(final ColumnMemberships columnMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final IImpurity impCriterion, final NominalValueRepresentation[] nomVals, final NominalValueRepresentation[] targetVals, final RandomData rd) {
    if (targetColumn.getMetaData().getValues().length != 2) {
        throw new IllegalArgumentException("This method can only be used for two class problems.");
    }
    final TreeEnsembleLearnerConfiguration config = getConfiguration();
    final int minChildSize = config.getMinChildSize();
    final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
    int start = 0;
    final int firstClass = targetColumn.getMetaData().getValues()[0].getAssignedInteger();
    // totals over the non-missing rows of the branch
    double totalWeight = 0.0;
    double totalFirstClassWeight = 0.0;
    final ArrayList<NomValProbabilityPair> nomValProbabilities = new ArrayList<>();
    if (!columnMemberships.next()) {
        throw new IllegalStateException("The columnMemberships has not been reset or is empty.");
    }
    // when the column contains missing values, the last nominal value represents them
    final int lengthNonMissing = containsMissingValues() ? nomVals.length - 1 : nomVals.length;
    // calculate probabilities for first class in each nominal value;
    // rows of nominal value 'att' occupy the column index range [start, end)
    for (int att = 0; att < lengthNonMissing; att++) {
        final int end = start + m_nominalValueCounts[att];
        double attFirstClassWeight = 0;
        double attWeight = 0;
        boolean reachedEnd = false;
        for (int index = columnMemberships.getIndexInColumn(); index < end; index = columnMemberships.getIndexInColumn()) {
            final double weight = columnMemberships.getRowWeight();
            assert weight > EPSILON : "Instances in columnMemberships must have weights larger than EPSILON.";
            final int instanceClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            if (instanceClass == firstClass) {
                attFirstClassWeight += weight;
                totalFirstClassWeight += weight;
            }
            attWeight += weight;
            totalWeight += weight;
            if (!columnMemberships.next()) {
                // reached end of columnMemberships
                reachedEnd = true;
                break;
            }
        }
        // only nominal values that occur in this branch take part in the partition search
        if (attWeight > 0) {
            final double firstClassProbability = attFirstClassWeight / attWeight;
            final NominalValueRepresentation nomVal = getMetaData().getValues()[att];
            nomValProbabilities.add(new NomValProbabilityPair(nomVal, firstClassProbability, attWeight, attFirstClassWeight));
        }
        start = end;
        if (reachedEnd) {
            break;
        }
    }
    // account for missing values and their weight
    double missingWeight = 0.0;
    double missingWeightFirstClass = 0.0;
    // otherwise the current indexInColumn won't be larger than start
    if (columnMemberships.getIndexInColumn() >= start) {
        do {
            final double recordWeight = columnMemberships.getRowWeight();
            missingWeight += recordWeight;
            final int recordClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            if (recordClass == firstClass) {
                missingWeightFirstClass += recordWeight;
            }
        } while (columnMemberships.next());
    }
    // Decide from the rows of THIS branch, not from the column as a whole. The column may contain
    // missing values while the current branch has none. In that case the missing direction must
    // follow the majority, instead of the meaningless comparison of two identical gains.
    final boolean branchContainsMissingValues = missingWeight > EPSILON;
    nomValProbabilities.sort(null);
    // bit of the last non-missing nominal value; used to orient the partition mask
    int highestBitPosition = getMetaData().getValues().length - 1;
    if (containsMissingValues()) {
        highestBitPosition--;
    }
    final double[] targetCountsSplitPartition = new double[2];
    final double[] targetCountsSplitRemaining = new double[2];
    final double[] binaryImpurityValues = new double[2];
    final double[] binaryPartitionWeights = new double[2];
    BigInteger partitionMask = BigInteger.ZERO;
    double bestPartitionGain = Double.NEGATIVE_INFINITY;
    BigInteger bestPartitionMask = null;
    boolean isBestSplitValid = false;
    double sumWeightsPartitionTotal = 0.0;
    double sumWeightsPartitionFirstClass = 0.0;
    boolean missingsGoLeft = false;
    // XGBoost: prior impurity of the branch including missing rows.
    // Otherwise: impurity of the non-missing rows only.
    final double priorImpurity = useXGBoostMissingValueHandling ? targetPriors.getPriorImpurity() : impCriterion.getPartitionImpurity(subtractMissingClassCounts(targetPriors.getDistribution(), createMissingClassCountsTwoClass(missingWeight, missingWeightFirstClass)), totalWeight);
    // we don't need to iterate over the full list because we always need some value on the other side
    for (int i = 0; i < nomValProbabilities.size() - 1; i++) {
        final NomValProbabilityPair nomVal = nomValProbabilities.get(i);
        sumWeightsPartitionTotal += nomVal.m_sumWeights;
        sumWeightsPartitionFirstClass += nomVal.m_firstClassSumWeights;
        partitionMask = partitionMask.or(nomVal.m_bitMask);
        // check if split represented by currentSplitList is in the right branch
        // by convention a split goes towards the right branch if the highest possible bit is set to 1
        final boolean isRightBranch = partitionMask.testBit(highestBitPosition);
        double gain;
        boolean isValidSplit;
        boolean tempMissingsGoLeft = true;
        if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
            // send missing values both ways and take the better direction
            // send missings left
            targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass + missingWeightFirstClass;
            targetCountsSplitPartition[1] = sumWeightsPartitionTotal + missingWeight - targetCountsSplitPartition[0];
            binaryPartitionWeights[1] = sumWeightsPartitionTotal + missingWeight;
            // totalFirstClassWeight and totalWeight only include non missing values
            targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass;
            targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal - targetCountsSplitRemaining[0];
            binaryPartitionWeights[0] = totalWeight - sumWeightsPartitionTotal;
            final boolean isValidSplitLeft = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
            binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
            binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
            double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
            final double gainLeft = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
            // send missings right
            targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass;
            targetCountsSplitPartition[1] = sumWeightsPartitionTotal - sumWeightsPartitionFirstClass;
            binaryPartitionWeights[1] = sumWeightsPartitionTotal;
            targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass + missingWeightFirstClass;
            targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal + missingWeight - targetCountsSplitRemaining[0];
            binaryPartitionWeights[0] = totalWeight + missingWeight - sumWeightsPartitionTotal;
            final boolean isValidSplitRight = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
            binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
            binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
            postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
            final double gainRight = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
            // decide which is better (better gain)
            if (gainLeft >= gainRight) {
                gain = gainLeft;
                isValidSplit = isValidSplitLeft;
                tempMissingsGoLeft = true;
            } else {
                gain = gainRight;
                isValidSplit = isValidSplitRight;
                tempMissingsGoLeft = false;
            }
        } else {
            // assign weights to branches
            targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass;
            targetCountsSplitPartition[1] = sumWeightsPartitionTotal - sumWeightsPartitionFirstClass;
            binaryPartitionWeights[1] = sumWeightsPartitionTotal;
            targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass;
            targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal - targetCountsSplitRemaining[0];
            binaryPartitionWeights[0] = totalWeight - sumWeightsPartitionTotal;
            isValidSplit = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
            binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
            binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
            final double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight);
            gain = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight);
        }
        // use random tie breaker if gains are equal
        final boolean randomTieBreaker = gain == bestPartitionGain ? rd.nextInt(0, 1) == 1 : false;
        // store if better than before or first valid split
        if (gain > bestPartitionGain || (!isBestSplitValid && isValidSplit) || randomTieBreaker) {
            // never replace a valid best split by an invalid one
            if (isValidSplit || !isBestSplitValid) {
                bestPartitionGain = gain;
                // store the mask of the side that contains the highest bit
                bestPartitionMask = isRightBranch ? partitionMask : BigInteger.ZERO.setBit(highestBitPosition + 1).subtract(BigInteger.ONE).xor(partitionMask);
                isBestSplitValid = isValidSplit;
                // missingsGoLeft is only used later on if XGBoost Missing Value Handling is used
                if (branchContainsMissingValues) {
                    // NOTE(review): in the XGBoost evaluation above, the "left" variant adds the
                    // missing rows to the partition. The majority rule below, however, treats the
                    // partition as the right child when isRightBranch is true. Confirm how
                    // NominalBinarySplitCandidate interprets MISSINGS_GO_LEFT in that case.
                    missingsGoLeft = tempMissingsGoLeft;
                } else {
                    // no missing values in this branch
                    // send missing values with the majority
                    missingsGoLeft = isRightBranch ? sumWeightsPartitionTotal < 0.5 * totalWeight : sumWeightsPartitionTotal >= 0.5 * totalWeight;
                }
            }
        }
    }
    if (isBestSplitValid && bestPartitionGain > 0.0) {
        if (useXGBoostMissingValueHandling) {
            return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, NO_MISSED_ROWS, missingsGoLeft ? NominalBinarySplitCandidate.MISSINGS_GO_LEFT : NominalBinarySplitCandidate.MISSINGS_GO_RIGHT);
        }
        return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, getMissedRows(columnMemberships), NominalBinarySplitCandidate.NO_MISSINGS);
    }
    return null;
}
Aggregations