use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.
the class TreeNumericColumnData method calcBestSplitClassification.
@Override
public NumericSplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
final int minChildNodeSize = config.getMinChildSize();
// distribution of target for each attribute value
final int targetCounts = targetVals.length;
final double[] targetCountsLeftOfSplit = new double[targetCounts];
final double[] targetCountsRightOfSplit = targetPriors.getDistribution().clone();
assert targetCountsRightOfSplit.length == targetCounts;
final double totalSumWeight = targetPriors.getNrRecords();
final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// get columnMemberships
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
// missing value handling
boolean branchContainsMissingValues = containsMissingValues();
boolean missingsGoLeft = true;
final int lengthNonMissing = getLengthNonMissing();
final double[] missingTargetCounts = new double[targetCounts];
int lastValidSplitPosition = -1;
double missingWeight = 0;
columnMemberships.goToLast();
do {
final int indexInColumn = columnMemberships.getIndexInColumn();
if (indexInColumn >= lengthNonMissing) {
final double weight = columnMemberships.getRowWeight();
final int classIdx = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
targetCountsRightOfSplit[classIdx] -= weight;
missingTargetCounts[classIdx] += weight;
missingWeight += weight;
} else {
if (lastValidSplitPosition < 0) {
lastValidSplitPosition = indexInColumn;
} else if ((getSorted(lastValidSplitPosition) - getSorted(indexInColumn)) >= EPSILON) {
break;
} else {
lastValidSplitPosition = indexInColumn;
}
}
} while (columnMemberships.previous());
// it is possible that the column contains missing values but in the current branch there are no missing values
branchContainsMissingValues = missingWeight > 0.0;
columnMemberships.reset();
double sumWeightsLeftOfSplit = 0.0;
double sumWeightsRightOfSplit = totalSumWeight - missingWeight;
final double priorImpurity = useXGBoostMissingValueHandling || !branchContainsMissingValues ? targetPriors.getPriorImpurity() : impurityCriterion.getPartitionImpurity(TreeNominalColumnData.subtractMissingClassCounts(targetPriors.getDistribution(), missingTargetCounts), sumWeightsRightOfSplit);
// all values in branch are missing
if (sumWeightsRightOfSplit == 0) {
// it is impossible to determine a split
return null;
}
double bestSplit = Double.NEGATIVE_INFINITY;
// gain for best split point, unnormalized (not using info gain ratio)
double bestGain = Double.NEGATIVE_INFINITY;
// gain for best split, normalized by attribute entropy when
// info gain ratio is used.
double bestGainValueForSplit = Double.NEGATIVE_INFINITY;
final double[] tempArray1 = new double[2];
double[] tempArray2 = new double[2];
double lastSeenValue = Double.NEGATIVE_INFINITY;
boolean mustTestOnNextValueChange = false;
boolean testSplitOnStart = true;
boolean firstIteration = true;
int lastSeenTarget = -1;
int indexInCol = -1;
// We iterate over the instances in the sample/branch instead of the whole data set
while (columnMemberships.next() && (indexInCol = columnMemberships.getIndexInColumn()) < lengthNonMissing) {
final double weight = columnMemberships.getRowWeight();
assert weight >= EPSILON : "Rows with zero row weight should never be seen!";
final double value = getSorted(indexInCol);
final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
final boolean hasValueChanged = (value - lastSeenValue) >= EPSILON;
final boolean hasTargetChanged = lastSeenTarget != target || indexInCol == lastValidSplitPosition;
if (hasTargetChanged && !firstIteration) {
mustTestOnNextValueChange = true;
testSplitOnStart = false;
}
if (!firstIteration && hasValueChanged && (mustTestOnNextValueChange || testSplitOnStart) && sumWeightsLeftOfSplit >= minChildNodeSize && sumWeightsRightOfSplit >= minChildNodeSize) {
double postSplitImpurity;
boolean tempMissingsGoLeft = false;
// missing value handling
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
final double[] targetCountsLeftPlusMissing = new double[targetCounts];
final double[] targetCountsRightPlusMissing = new double[targetCounts];
for (int i = 0; i < targetCounts; i++) {
targetCountsLeftPlusMissing[i] = targetCountsLeftOfSplit[i] + missingTargetCounts[i];
targetCountsRightPlusMissing[i] = targetCountsRightOfSplit[i] + missingTargetCounts[i];
}
final double[][] temp = new double[2][2];
final double[] postSplitImpurities = new double[2];
// send all missing values left
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftPlusMissing, sumWeightsLeftOfSplit + missingWeight);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
temp[0][0] = sumWeightsLeftOfSplit + missingWeight;
temp[0][1] = sumWeightsRightOfSplit;
postSplitImpurities[0] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[0], totalSumWeight);
// send all missing values right
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightPlusMissing, sumWeightsRightOfSplit + missingWeight);
temp[1][0] = sumWeightsLeftOfSplit;
temp[1][1] = sumWeightsRightOfSplit + missingWeight;
postSplitImpurities[1] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[1], totalSumWeight);
// take better split
if (postSplitImpurities[0] < postSplitImpurities[1]) {
postSplitImpurity = postSplitImpurities[0];
tempArray2 = temp[0];
tempMissingsGoLeft = true;
// TODO random tie breaker
} else {
postSplitImpurity = postSplitImpurities[1];
tempArray2 = temp[1];
tempMissingsGoLeft = false;
}
} else {
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
tempArray2[0] = sumWeightsLeftOfSplit;
tempArray2[1] = sumWeightsRightOfSplit;
postSplitImpurity = impurityCriterion.getPostSplitImpurity(tempArray1, tempArray2, totalSumWeight);
}
if (postSplitImpurity < priorImpurity) {
// Use absolute gain (IG) for split calculation even
// if the split criterion is information gain ratio (IGR).
// IGR wouldn't work as it favors extremely unfair splits,
// e.g. a 1:9999 split has an attribute entropy
// (the IGR denominator) of
// -(9999/10000*log2(9999/10000) + 1/10000*log2(1/10000)),
// which is ~0.00148
double gain = (priorImpurity - postSplitImpurity);
boolean randomTieBreaker = gain == bestGain ? rd.nextInt(0, 1) == 1 : false;
if (gain > bestGain || randomTieBreaker) {
bestGainValueForSplit = impurityCriterion.getGain(priorImpurity, postSplitImpurity, tempArray2, totalSumWeight);
bestGain = gain;
bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
// If the branch contains no missing values during training, send missings with the majority;
// a missing direction is still needed in case missing values occur during prediction.
missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : sumWeightsLeftOfSplit > sumWeightsRightOfSplit;
}
}
mustTestOnNextValueChange = false;
}
targetCountsLeftOfSplit[target] += weight;
sumWeightsLeftOfSplit += weight;
targetCountsRightOfSplit[target] -= weight;
sumWeightsRightOfSplit -= weight;
lastSeenTarget = target;
lastSeenValue = value;
firstIteration = false;
}
columnMemberships.reset();
if (bestGainValueForSplit < 0.0) {
// (see info gain ratio implementation)
return null;
}
if (useXGBoostMissingValueHandling) {
// return new NumericMissingSplitCandidate(this, bestSplit, bestGainValueForSplit, missingsGoLeft);
return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
}
return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
}
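The split search above delegates all impurity math to the IImpurity criterion through getPartitionImpurity, getPostSplitImpurity and getGain. As a rough illustration of what such a criterion computes, here is a minimal Gini-style sketch. The method names mirror the calls above, but the signatures are simplified (the real getGain also receives the partition weights, which the information gain ratio implementation needs for its normalization), so this is an assumption-laden sketch, not the KNIME implementation.
// Minimal Gini-style impurity sketch; simplified signatures, not the KNIME IImpurity implementation.
final class GiniImpuritySketch {
    // Gini impurity of one partition, given per-class weights and the partition's total weight.
    static double getPartitionImpurity(final double[] targetCounts, final double partitionWeight) {
        if (partitionWeight <= 0.0) {
            return 0.0;
        }
        double sumSquaredFractions = 0.0;
        for (final double count : targetCounts) {
            final double fraction = count / partitionWeight;
            sumSquaredFractions += fraction * fraction;
        }
        return 1.0 - sumSquaredFractions;
    }
    // Weight-averaged impurity of the partitions created by a split.
    static double getPostSplitImpurity(final double[] partitionImpurities, final double[] partitionWeights, final double totalWeight) {
        double weightedSum = 0.0;
        for (int i = 0; i < partitionImpurities.length; i++) {
            weightedSum += partitionWeights[i] * partitionImpurities[i];
        }
        return weightedSum / totalWeight;
    }
    // Gain of a split: how much the impurity drops relative to the unsplit node.
    static double getGain(final double priorImpurity, final double postSplitImpurity) {
        return priorImpurity - postSplitImpurity;
    }
}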
use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.
the class Surrogates method learnSurrogates.
/**
* This function searches for splits in the remaining columns of <b>colSample</b>. It does so by taking the
* directions (left or right) induced by <b>bestSplit</b> as the new target.
*
* @param dataMemberships provides information which rows are in the current branch
* @param bestSplit the best split for the current node
* @param oldData the TreeData object that contains all attributes and the target
* @param colSample provides information which columns are to be considered as surrogates
* @param config the configuration
* @param rd the random number generator, used e.g. to break ties between equally good splits
* @return a SurrogateSplit that contains the conditions for both children
*/
public static SurrogateSplit learnSurrogates(final DataMemberships dataMemberships, final SplitCandidate bestSplit, final TreeData oldData, final ColumnSample colSample, final TreeEnsembleLearnerConfiguration config, final RandomData rd) {
TreeAttributeColumnData bestSplitCol = bestSplit.getColumnData();
TreeNodeCondition[] bestSplitChildConditions = bestSplit.getChildConditions();
// calculate new Target
BitSet bestSplitLeft = bestSplitCol.updateChildMemberships(bestSplitChildConditions[0], dataMemberships);
BitSet bestSplitRight = bestSplitCol.updateChildMemberships(bestSplitChildConditions[1], dataMemberships);
// create DataMemberships that only contains the instances that are not missed by bestSplit
BitSet surrogateBitSet = (BitSet) bestSplitLeft.clone();
surrogateBitSet.or(bestSplitRight);
DataMemberships surrogateCalcDataMemberships = dataMemberships.createChildMemberships(surrogateBitSet);
TreeTargetNominalColumnData newTarget = createNewTargetColumn(bestSplitLeft, bestSplitRight, oldData.getNrRows(), surrogateCalcDataMemberships);
// find best splits on new target
ArrayList<SplitCandidate> candidates = new ArrayList<SplitCandidate>();
ClassificationPriors newTargetPriors = newTarget.getDistribution(surrogateCalcDataMemberships, config);
for (TreeAttributeColumnData col : colSample) {
if (col != bestSplitCol) {
SplitCandidate candidate = col.calcBestSplitClassification(surrogateCalcDataMemberships, newTargetPriors, newTarget, rd);
if (candidate != null) {
candidates.add(candidate);
}
}
}
SplitCandidate[] candidatesWithBestAtHead = new SplitCandidate[candidates.size() + 1];
candidatesWithBestAtHead[0] = bestSplit;
for (int i = 1; i < candidatesWithBestAtHead.length; i++) {
candidatesWithBestAtHead[i] = candidates.get(i - 1);
}
return calculateSurrogates(dataMemberships, candidatesWithBestAtHead);
}
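learnSurrogates re-learns split candidates against a synthetic binary target that records, for every row reached by bestSplit, whether that split sends the row left or right; surrogate columns whose splits best predict this direction can then stand in for the primary split when the primary attribute is missing at prediction time. The real createNewTargetColumn builds a TreeTargetNominalColumnData; the following is only a hypothetical, simplified sketch of that encoding step, returning a plain int array and assuming the two BitSets are indexed over the rows of the current branch.
// Hypothetical sketch of the direction encoding behind createNewTargetColumn(...):
// rows sent left by the best split become class 0, rows sent right become class 1,
// rows missed by the best split (missing values) keep the marker -1.
static int[] encodeSurrogateTarget(final java.util.BitSet bestSplitLeft, final java.util.BitSet bestSplitRight, final int nrRows) {
    final int[] surrogateTarget = new int[nrRows];
    java.util.Arrays.fill(surrogateTarget, -1);
    for (int row = bestSplitLeft.nextSetBit(0); row >= 0; row = bestSplitLeft.nextSetBit(row + 1)) {
        surrogateTarget[row] = 0; // direction "left" becomes the first target class
    }
    for (int row = bestSplitRight.nextSetBit(0); row >= 0; row = bestSplitRight.nextSetBit(row + 1)) {
        surrogateTarget[row] = 1; // direction "right" becomes the second target class
    }
    return surrogateTarget;
}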
use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.
the class TreeLearnerClassification method buildTreeNode.
private TreeNodeClassification buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
final TreeData data = getData();
final TreeEnsembleLearnerConfiguration config = getConfig();
exec.checkCanceled();
final boolean useSurrogates = getConfig().getMissingValueHandling() == MissingValueHandling.Surrogate;
TreeNodeCondition[] childConditions;
boolean markAttributeAsForbidden = false;
final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
TreeNodeClassification[] childNodes;
int attributeIndex = -1;
if (useSurrogates) {
SplitCandidate[] candidates = findBestSplitsClassification(currentDepth, dataMemberships, columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
if (candidates == null) {
return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
}
SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, candidates[0], data, columnSample, config, getRandomData());
childConditions = surrogateSplit.getChildConditions();
BitSet[] childMarkers = surrogateSplit.getChildMarkers();
childNodes = new TreeNodeClassification[2];
for (int i = 0; i < 2; i++) {
DataMemberships childMemberships = dataMemberships.createChildMemberships(childMarkers[i]);
ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
TreeNodeSignature childSignature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) i);
ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
childNodes[i].setTreeNodeCondition(childConditions[i]);
}
} else {
// handle non surrogate case
SplitCandidate bestSplit = findBestSplitClassification(currentDepth, dataMemberships, columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
if (bestSplit == null) {
return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
}
TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
attributeIndex = splitColumn.getMetaData().getAttributeIndex();
markAttributeAsForbidden = !bestSplit.canColumnBeSplitFurther();
forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden);
childConditions = bestSplit.getChildConditions();
childNodes = new TreeNodeClassification[childConditions.length];
if (childConditions.length > Short.MAX_VALUE) {
throw new RuntimeException("Too many children when splitting " + "attribute " + bestSplit.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + childConditions.length);
}
// Build child nodes
for (int i = 0; i < childConditions.length; i++) {
DataMemberships childMemberships = null;
TreeNodeCondition cond = childConditions[i];
childMemberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(cond, dataMemberships));
ClassificationPriors childTargetPriors = targetColumn.getDistribution(childMemberships, config);
TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) i);
ColumnSample childColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
childNodes[i] = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumnSample, childSignature, childTargetPriors, forbiddenColumnSet);
childNodes[i].setTreeNodeCondition(cond);
}
}
if (markAttributeAsForbidden) {
forbiddenColumnSet.set(attributeIndex, false);
}
return new TreeNodeClassification(treeNodeSignature, targetPriors, childNodes, getConfig());
}
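In the non-surrogate branch above, an attribute whose best split reports canColumnBeSplitFurther() == false is marked in forbiddenColumnSet before the children are built and cleared again afterwards, so only the subtree below this node is prevented from reusing it. Here is a small sketch of that mark/build/unmark bookkeeping in isolation; buildSubtree is a placeholder for the recursive buildTreeNode calls and is an assumption of this sketch.
// Sketch of the forbidden-column bookkeeping used in buildTreeNode (buildSubtree is a placeholder).
static void buildWithForbiddenColumn(final java.util.BitSet forbiddenColumnSet, final int attributeIndex, final boolean markAttributeAsForbidden, final Runnable buildSubtree) {
    forbiddenColumnSet.set(attributeIndex, markAttributeAsForbidden); // no-op when the flag is false
    buildSubtree.run(); // build the children while the attribute is excluded
    if (markAttributeAsForbidden) {
        forbiddenColumnSet.set(attributeIndex, false); // re-enable the attribute for sibling subtrees
    }
}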
use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.
the class TreeNominalColumnData method calcBestSplitClassificationBinaryTwoClass.
private NominalBinarySplitCandidate calcBestSplitClassificationBinaryTwoClass(final ColumnMemberships columnMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final IImpurity impCriterion, final NominalValueRepresentation[] nomVals, final NominalValueRepresentation[] targetVals, final RandomData rd) {
if (targetColumn.getMetaData().getValues().length != 2) {
throw new IllegalArgumentException("This method can only be used for two class problems.");
}
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final int minChildSize = config.getMinChildSize();
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
int start = 0;
final int firstClass = targetColumn.getMetaData().getValues()[0].getAssignedInteger();
double totalWeight = 0.0;
double totalFirstClassWeight = 0.0;
final ArrayList<NomValProbabilityPair> nomValProbabilities = new ArrayList<NomValProbabilityPair>();
if (!columnMemberships.next()) {
throw new IllegalStateException("The columnMemberships has not been reset or is empty.");
}
final int lengthNonMissing = containsMissingValues() ? nomVals.length - 1 : nomVals.length;
// final int attToConsider = useXGBoostMissingValueHandling ? nomVals.length : lengthNonMissing;
boolean branchContainsMissingValues = containsMissingValues();
// calculate probabilities for first class in each nominal value
for (int att = 0; att < /*attToConsider*/ lengthNonMissing; att++) {
int end = start + m_nominalValueCounts[att];
double attFirstClassWeight = 0;
double attWeight = 0;
boolean reachedEnd = false;
for (int index = columnMemberships.getIndexInColumn(); index < end; index = columnMemberships.getIndexInColumn()) {
double weight = columnMemberships.getRowWeight();
assert weight > EPSILON : "Instances in columnMemberships must have weights larger than EPSILON.";
final int instanceClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
if (instanceClass == firstClass) {
attFirstClassWeight += weight;
totalFirstClassWeight += weight;
}
attWeight += weight;
totalWeight += weight;
if (!columnMemberships.next()) {
// reached end of columnMemberships
reachedEnd = true;
if (att == nomVals.length - 1) {
// if the column contains no missing values, the last possible nominal value is
// not the missing value and therefore branchContainsMissingValues needs to be false
branchContainsMissingValues = branchContainsMissingValues && true;
}
break;
}
}
if (attWeight > 0) {
final double firstClassProbability = attFirstClassWeight / attWeight;
final NominalValueRepresentation nomVal = getMetaData().getValues()[att];
nomValProbabilities.add(new NomValProbabilityPair(nomVal, firstClassProbability, attWeight, attFirstClassWeight));
}
start = end;
if (reachedEnd) {
break;
}
}
// account for missing values and their weight
double missingWeight = 0.0;
double missingWeightFirstClass = 0.0;
// if the branch contains missing values the iterator now points at them;
// otherwise the current indexInColumn won't be larger than start
if (columnMemberships.getIndexInColumn() >= start) {
do {
final double recordWeight = columnMemberships.getRowWeight();
missingWeight += recordWeight;
final int recordClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
if (recordClass == firstClass) {
missingWeightFirstClass += recordWeight;
}
} while (columnMemberships.next());
}
if (missingWeight > EPSILON) {
branchContainsMissingValues = true;
}
nomValProbabilities.sort(null);
int highestBitPosition = getMetaData().getValues().length - 1;
if (containsMissingValues()) {
highestBitPosition--;
}
final double[] targetCountsSplitPartition = new double[2];
final double[] targetCountsSplitRemaining = new double[2];
final double[] binaryImpurityValues = new double[2];
final double[] binaryPartitionWeights = new double[2];
BigInteger partitionMask = BigInteger.ZERO;
double bestPartitionGain = Double.NEGATIVE_INFINITY;
BigInteger bestPartitionMask = null;
boolean isBestSplitValid = false;
double sumWeightsPartitionTotal = 0.0;
double sumWeightsPartitionFirstClass = 0.0;
boolean missingsGoLeft = false;
final double priorImpurity = useXGBoostMissingValueHandling ? targetPriors.getPriorImpurity() : impCriterion.getPartitionImpurity(subtractMissingClassCounts(targetPriors.getDistribution(), createMissingClassCountsTwoClass(missingWeight, missingWeightFirstClass)), totalWeight);
// we don't need to iterate over the full list because we always need some value on the other side
for (int i = 0; i < nomValProbabilities.size() - 1; i++) {
NomValProbabilityPair nomVal = nomValProbabilities.get(i);
sumWeightsPartitionTotal += nomVal.m_sumWeights;
sumWeightsPartitionFirstClass += nomVal.m_firstClassSumWeights;
partitionMask = partitionMask.or(nomVal.m_bitMask);
// check if split represented by currentSplitList is in the right branch
// by convention a split goes towards the right branch if the highest possible bit is set to 1
final boolean isRightBranch = partitionMask.testBit(highestBitPosition);
double gain;
boolean isValidSplit;
boolean tempMissingsGoLeft = true;
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
// send missing values both ways and take the better direction
// send missings left
targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass + missingWeightFirstClass;
targetCountsSplitPartition[1] = sumWeightsPartitionTotal + missingWeight - targetCountsSplitPartition[0];
binaryPartitionWeights[1] = sumWeightsPartitionTotal + missingWeight;
// totalFirstClassWeight and totalWeight only include non missing values
targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass;
targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal - targetCountsSplitRemaining[0];
binaryPartitionWeights[0] = totalWeight - sumWeightsPartitionTotal;
boolean isValidSplitLeft = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
double gainLeft = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
// send missings right
targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass;
targetCountsSplitPartition[1] = sumWeightsPartitionTotal - sumWeightsPartitionFirstClass;
binaryPartitionWeights[1] = sumWeightsPartitionTotal;
targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass + missingWeightFirstClass;
targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal + missingWeight - targetCountsSplitRemaining[0];
binaryPartitionWeights[0] = totalWeight + missingWeight - sumWeightsPartitionTotal;
boolean isValidSplitRight = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
double gainRight = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
// decide which is better (better gain)
if (gainLeft >= gainRight) {
gain = gainLeft;
isValidSplit = isValidSplitLeft;
tempMissingsGoLeft = true;
} else {
gain = gainRight;
isValidSplit = isValidSplitRight;
tempMissingsGoLeft = false;
}
} else {
// assign weights to branches
targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass;
targetCountsSplitPartition[1] = sumWeightsPartitionTotal - sumWeightsPartitionFirstClass;
binaryPartitionWeights[1] = sumWeightsPartitionTotal;
targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass;
targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal - targetCountsSplitRemaining[0];
binaryPartitionWeights[0] = totalWeight - sumWeightsPartitionTotal;
isValidSplit = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight);
gain = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight);
}
// use random tie breaker if gains are equal
boolean randomTieBreaker = gain == bestPartitionGain ? rd.nextInt(0, 1) == 1 : false;
// store if better than before or first valid split
if (gain > bestPartitionGain || (!isBestSplitValid && isValidSplit) || randomTieBreaker) {
if (isValidSplit || !isBestSplitValid) {
bestPartitionGain = gain;
bestPartitionMask = isRightBranch ? partitionMask : BigInteger.ZERO.setBit(highestBitPosition + 1).subtract(BigInteger.ONE).xor(partitionMask);
isBestSplitValid = isValidSplit;
// missingsGoLeft is only used later on if XGBoost Missing Value Handling is used
if (branchContainsMissingValues) {
// missingsGoLeft = isRightBranch;
missingsGoLeft = tempMissingsGoLeft;
} else {
// no missing values in this branch
// send missing values with the majority
missingsGoLeft = isRightBranch ? sumWeightsPartitionTotal < 0.5 * totalWeight : sumWeightsPartitionTotal >= 0.5 * totalWeight;
}
}
}
}
if (isBestSplitValid && bestPartitionGain > 0.0) {
if (useXGBoostMissingValueHandling) {
return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, NO_MISSED_ROWS, missingsGoLeft ? NominalBinarySplitCandidate.MISSINGS_GO_LEFT : NominalBinarySplitCandidate.MISSINGS_GO_RIGHT);
}
return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, getMissedRows(columnMemberships), NominalBinarySplitCandidate.NO_MISSINGS);
}
return null;
}
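The candidate partitions above are encoded as BigInteger bit masks, one bit per nominal value, and each mask is stored in a canonical orientation: by convention the mask with the highest value bit set describes the right branch (see the testBit/xor logic in the loop). Below is a hedged sketch of just that normalization step, with a hypothetical helper name, mirroring the bestPartitionMask computation.
// Hypothetical helper mirroring the bestPartitionMask computation above: if the highest value bit
// is not set, complement the mask over all value bits so the stored mask always describes the right branch.
static java.math.BigInteger canonicalRightBranchMask(final java.math.BigInteger partitionMask, final int highestBitPosition) {
    if (partitionMask.testBit(highestBitPosition)) {
        return partitionMask; // already in the right-branch orientation
    }
    final java.math.BigInteger allValueBits = java.math.BigInteger.ZERO.setBit(highestBitPosition + 1).subtract(java.math.BigInteger.ONE);
    return allValueBits.xor(partitionMask); // flip every value bit to obtain the complementary partition
}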