Use of org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships in project knime-core by knime.
The class TreeBitVectorColumnData, method calcBestSplitClassification.
/**
* {@inheritDoc}
*/
@Override
public SplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
final int minChildSize = getConfiguration().getMinChildSize();
// distribution of target for On ('1') and Off ('0') bits
final double[] onTargetWeights = new double[targetVals.length];
final double[] offTargetWeights = new double[targetVals.length];
double onWeights = 0.0;
double offWeights = 0.0;
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
while (columnMemberships.next()) {
final double weight = columnMemberships.getRowWeight();
if (weight < EPSILON) {
// ignore record: not in current branch or not in sample
assert false : "This code should never be reached!";
} else {
final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
if (m_columnBitSet.get(columnMemberships.getIndexInColumn())) {
onWeights += weight;
onTargetWeights[target] += weight;
} else {
offWeights += weight;
offTargetWeights[target] += weight;
}
}
}
if (onWeights < minChildSize || offWeights < minChildSize) {
return null;
}
final double weightSum = onWeights + offWeights;
final double onImpurity = impurityCriterion.getPartitionImpurity(onTargetWeights, onWeights);
final double offImpurity = impurityCriterion.getPartitionImpurity(offTargetWeights, offWeights);
final double[] partitionWeights = new double[] { onWeights, offWeights };
final double postSplitImpurity = impurityCriterion.getPostSplitImpurity(new double[] { onImpurity, offImpurity }, partitionWeights, weightSum);
final double gainValue = impurityCriterion.getGain(targetPriors.getPriorImpurity(), postSplitImpurity, partitionWeights, weightSum);
return new BitSplitCandidate(this, gainValue);
}
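The snippet above uses the cursor-style traversal that all the usages on this page share: next() advances the ColumnMemberships cursor to the following member row, and the getters expose the row weight together with three coordinate systems (index in the original table, index within the column's internal layout, index within the DataMemberships). A minimal sketch of that traversal, assuming a dataMemberships reference and a columnIndex as in the snippet (the variable names cm and totalWeight are illustrative, not part of the KNIME API):

final ColumnMemberships cm = dataMemberships.getColumnMemberships(columnIndex);
double totalWeight = 0.0;
while (cm.next()) {
    totalWeight += cm.getRowWeight();           // weight of the current row in this branch/sample
    // cm.getOriginalIndex()          -> row index in the full table (used to look up the target value)
    // cm.getIndexInColumn()          -> position in this column's internal (e.g. sorted) layout
    // cm.getIndexInDataMemberships() -> position within the current DataMemberships object
}
cm.reset();                                     // rewind the cursor for a later pass over the same members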
Use of org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships in project knime-core by knime.
The class TreeNumericColumnData, method updateChildMemberships.
@Override
public BitSet updateChildMemberships(final TreeNodeCondition childCondition, final DataMemberships parentMemberships) {
final TreeNodeNumericCondition numCondition = (TreeNodeNumericCondition) childCondition;
final NumericOperator numOperator = numCondition.getNumericOperator();
final double splitValue = numCondition.getSplitValue();
final ColumnMemberships columnMemberships = parentMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
columnMemberships.reset();
final BitSet inChild = new BitSet(columnMemberships.size());
int startIndex = 0;
if (!columnMemberships.nextIndexFrom(startIndex)) {
throw new IllegalStateException("The current columnMemberships object contains no element that satisfies the splitcondition");
}
final int lengthNonMissing = getLengthNonMissing();
do {
final double value = getSorted(columnMemberships.getIndexInColumn());
boolean matches;
switch(numOperator) {
case LessThanOrEqual:
matches = value <= splitValue;
break;
case LargerThan:
matches = value > splitValue;
break;
case LessThanOrEqualOrMissing:
matches = Double.isNaN(value) ? true : value <= splitValue;
break;
case LargerThanOrMissing:
matches = Double.isNaN(value) ? true : value > splitValue;
break;
default:
throw new IllegalStateException("Unknown operator " + numOperator);
}
if (matches) {
inChild.set(columnMemberships.getIndexInDataMemberships());
}
} while (columnMemberships.next() && columnMemberships.getIndexInColumn() < lengthNonMissing);
// if the index is still in the non-missing range, the cursor was exhausted and there are no missing values left to handle
if (columnMemberships.getIndexInColumn() < lengthNonMissing) {
return inChild;
}
// handle missing values
if (numOperator.equals(NumericOperator.LessThanOrEqualOrMissing) || numOperator.equals(NumericOperator.LargerThanOrMissing) || numCondition.acceptsMissings()) {
do {
inChild.set(columnMemberships.getIndexInDataMemberships());
} while (columnMemberships.next());
}
return inChild;
}
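The BitSet returned here is indexed by position in the parent DataMemberships (getIndexInDataMemberships), so it can be used to derive the memberships of the child node, as the test below does via createChildMemberships. A hedged sketch of that hand-off, assuming a parent memberships object that supports createChildMemberships as in the test (numericColumn, childCondition and columnIndex are illustrative names):

// Illustrative only: derive the child memberships from the BitSet computed above.
final BitSet inChild = numericColumn.updateChildMemberships(childCondition, parentMemberships);
final DataMemberships childMemberships = parentMemberships.createChildMemberships(inChild);
final ColumnMemberships childColMem = childMemberships.getColumnMemberships(columnIndex);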
Use of org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships in project knime-core by knime.
The class RootDescendantDataMembershipsTest, method testGetColumnMemberships.
@Test
public void testGetColumnMemberships() {
TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
TestDataGenerator dataGen = new TestDataGenerator(config);
TreeData data = dataGen.createTennisData();
DefaultDataIndexManager indexManager = new DefaultDataIndexManager(data);
int nrRows = data.getNrRows();
RowSample rowSample = new DefaultRowSample(nrRows);
RootDataMemberships rootMemberships = new RootDataMemberships(rowSample, data, indexManager);
ColumnMemberships rootColMem = rootMemberships.getColumnMemberships(0);
assertThat(rootColMem, instanceOf(IntArrayColumnMemberships.class));
assertEquals(nrRows, rootColMem.size());
int[] expectedOriginalIndices = new int[] { 0, 1, 7, 8, 10, 2, 6, 11, 12, 3, 4, 5, 9, 13 };
for (int i = 0; rootColMem.next(); i++) {
// in this case originalIndex and indexInDataMemberships are the same
assertEquals(expectedOriginalIndices[i], rootColMem.getOriginalIndex());
assertEquals(expectedOriginalIndices[i], rootColMem.getIndexInDataMemberships());
assertEquals(i, rootColMem.getIndexInColumn());
}
BitSet lastHalf = new BitSet(nrRows);
lastHalf.set(nrRows / 2, nrRows);
DataMemberships lastHalfChild = rootMemberships.createChildMemberships(lastHalf);
ColumnMemberships childColMem = lastHalfChild.getColumnMemberships(0);
assertThat(childColMem, instanceOf(DescendantColumnMemberships.class));
assertEquals(nrRows / 2, childColMem.size());
expectedOriginalIndices = new int[] { 7, 8, 10, 11, 12, 9, 13 };
int[] expectedIndexInColumn = new int[] { 2, 3, 4, 7, 8, 12, 13 };
int[] expectedIndexInDataMemberships = new int[] { 7, 8, 10, 11, 12, 9, 13 };
for (int i = 0; childColMem.next(); i++) {
assertEquals(expectedOriginalIndices[i], childColMem.getOriginalIndex());
assertEquals(expectedIndexInColumn[i], childColMem.getIndexInColumn());
assertEquals(expectedIndexInDataMemberships[i], childColMem.getIndexInDataMemberships());
}
}
Use of org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships in project knime-core by knime.
The class TreeNumericColumnData, method calcBestSplitRegression.
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
final int minChildNodeSize = config.getMinChildSize();
// get columnMemberships
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
final int lengthNonMissing = getLengthNonMissing();
// missing value handling
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// are there missing values in this column (complete column)
boolean branchContainsMissingValues = containsMissingValues();
boolean missingsGoLeft = true;
double missingWeight = 0.0;
double missingY = 0.0;
// check if there are missing values in this rowsample
if (branchContainsMissingValues) {
columnMemberships.goToLast();
while (columnMemberships.getIndexInColumn() >= lengthNonMissing) {
missingWeight += columnMemberships.getRowWeight();
missingY += targetColumn.getValueFor(columnMemberships.getOriginalIndex());
if (!columnMemberships.previous()) {
break;
}
}
columnMemberships.reset();
branchContainsMissingValues = missingWeight > 0.0;
}
final double ySumTotal = targetPriors.getYSum() - missingY;
final double nrRecordsTotal = targetPriors.getNrRecords() - missingWeight;
final double criterionTotal = useXGBoostMissingValueHandling ? (ySumTotal + missingY) * (ySumTotal + missingY) / (nrRecordsTotal + missingWeight) : ySumTotal * ySumTotal / nrRecordsTotal;
double ySumLeft = 0.0;
double nrRecordsLeft = 0.0;
double ySumRight = ySumTotal;
double nrRecordsRight = nrRecordsTotal;
// all values in the current branch are missing
if (nrRecordsRight == 0) {
// it is impossible to determine a split
return null;
}
double bestSplit = Double.NEGATIVE_INFINITY;
double bestImprovement = 0.0;
double lastSeenY = Double.NaN;
double lastSeenValue = Double.NEGATIVE_INFINITY;
double lastSeenWeight = -1.0;
// compute the gain, keep the one that maximizes the split
while (columnMemberships.next()) {
final double weight = columnMemberships.getRowWeight();
if (weight < EPSILON) {
// ignore record: not in current branch or not in sample
continue;
} else if (Math.floor(weight) != weight) {
throw new UnsupportedOperationException("weighted records (missing values?) not supported, " + "weight is " + weight);
}
final double value = getSorted(columnMemberships.getIndexInColumn());
if (lastSeenWeight > 0.0) {
ySumLeft += lastSeenWeight * lastSeenY;
ySumRight -= lastSeenWeight * lastSeenY;
nrRecordsLeft += lastSeenWeight;
nrRecordsRight -= lastSeenWeight;
if (nrRecordsLeft >= minChildNodeSize && nrRecordsRight >= minChildNodeSize && lastSeenValue < value) {
boolean tempMissingsGoLeft = true;
double childrenSquaredSum;
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
final double[] tempChildrenSquaredSum = new double[2];
tempChildrenSquaredSum[0] = ((ySumLeft + missingY) * (ySumLeft + missingY) / (nrRecordsLeft + missingWeight)) + (ySumRight * ySumRight / nrRecordsRight);
tempChildrenSquaredSum[1] = (ySumLeft * ySumLeft / nrRecordsLeft) + ((ySumRight + missingY) * (ySumRight + missingY) / (nrRecordsRight + missingWeight));
if (tempChildrenSquaredSum[0] >= tempChildrenSquaredSum[1]) {
childrenSquaredSum = tempChildrenSquaredSum[0];
tempMissingsGoLeft = true;
} else {
childrenSquaredSum = tempChildrenSquaredSum[1];
tempMissingsGoLeft = false;
}
} else {
childrenSquaredSum = (ySumLeft * ySumLeft / nrRecordsLeft) + (ySumRight * ySumRight / nrRecordsRight);
}
double criterion = childrenSquaredSum - criterionTotal;
boolean randomTieBreaker = criterion == bestImprovement ? rd.nextInt(0, 1) == 1 : false;
if (criterion > bestImprovement || randomTieBreaker) {
bestImprovement = criterion;
bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
// if there are no missing values go with majority
missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : nrRecordsLeft >= nrRecordsRight;
}
}
}
lastSeenY = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
lastSeenValue = value;
lastSeenWeight = weight;
}
// + " but was " + lastSeenY * lastSeenWeight;
if (bestImprovement > 0.0) {
if (useXGBoostMissingValueHandling) {
// return new NumericMissingSplitCandidate(this, bestSplit, bestImprovement, missingsGoLeft);
return new NumericSplitCandidate(this, bestSplit, bestImprovement, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
}
return new NumericSplitCandidate(this, bestSplit, bestImprovement, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
} else {
return null;
}
}
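The loop above maximizes the reduction in squared error that a split achieves, which can be written as S_L^2/n_L + S_R^2/n_R - S_T^2/n_T, where S is the sum of target values and n the record count on each side; only these sums and counts have to be tracked while the cursor walks the sorted column. A small self-contained sketch of that arithmetic with assumed example values (no KNIME classes involved):

// Assumed example values: 4 records summing to 12.0 on the left, 6 records summing to 3.0 on the right.
double ySumLeft = 12.0, nrRecordsLeft = 4.0;
double ySumRight = 3.0, nrRecordsRight = 6.0;
double ySumTotal = ySumLeft + ySumRight;
double nrRecordsTotal = nrRecordsLeft + nrRecordsRight;
double criterionTotal = ySumTotal * ySumTotal / nrRecordsTotal;
double childrenSquaredSum = ySumLeft * ySumLeft / nrRecordsLeft + ySumRight * ySumRight / nrRecordsRight;
double improvement = childrenSquaredSum - criterionTotal;   // positive means the split reduces the squared error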
Use of org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships in project knime-core by knime.
The class TreeNumericColumnData, method calcBestSplitClassification.
@Override
public NumericSplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
final int minChildNodeSize = config.getMinChildSize();
// distribution of target for each attribute value
final int targetCounts = targetVals.length;
final double[] targetCountsLeftOfSplit = new double[targetCounts];
final double[] targetCountsRightOfSplit = targetPriors.getDistribution().clone();
assert targetCountsRightOfSplit.length == targetCounts;
final double totalSumWeight = targetPriors.getNrRecords();
final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// get columnMemberships
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
// missing value handling
boolean branchContainsMissingValues = containsMissingValues();
boolean missingsGoLeft = true;
final int lengthNonMissing = getLengthNonMissing();
final double[] missingTargetCounts = new double[targetCounts];
int lastValidSplitPosition = -1;
double missingWeight = 0;
columnMemberships.goToLast();
do {
final int indexInColumn = columnMemberships.getIndexInColumn();
if (indexInColumn >= lengthNonMissing) {
final double weight = columnMemberships.getRowWeight();
final int classIdx = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
targetCountsRightOfSplit[classIdx] -= weight;
missingTargetCounts[classIdx] += weight;
missingWeight += weight;
} else {
if (lastValidSplitPosition < 0) {
lastValidSplitPosition = indexInColumn;
} else if ((getSorted(lastValidSplitPosition) - getSorted(indexInColumn)) >= EPSILON) {
break;
} else {
lastValidSplitPosition = indexInColumn;
}
}
} while (columnMemberships.previous());
// it is possible that the column contains missing values but in the current branch there are no missing values
branchContainsMissingValues = missingWeight > 0.0;
columnMemberships.reset();
double sumWeightsLeftOfSplit = 0.0;
double sumWeightsRightOfSplit = totalSumWeight - missingWeight;
final double priorImpurity = useXGBoostMissingValueHandling || !branchContainsMissingValues ? targetPriors.getPriorImpurity() : impurityCriterion.getPartitionImpurity(TreeNominalColumnData.subtractMissingClassCounts(targetPriors.getDistribution(), missingTargetCounts), sumWeightsRightOfSplit);
// all values in branch are missing
if (sumWeightsRightOfSplit == 0) {
// it is impossible to determine a split
return null;
}
double bestSplit = Double.NEGATIVE_INFINITY;
// gain for best split point, unnormalized (not using info gain ratio)
double bestGain = Double.NEGATIVE_INFINITY;
// gain for best split, normalized by attribute entropy when
// info gain ratio is used.
double bestGainValueForSplit = Double.NEGATIVE_INFINITY;
final double[] tempArray1 = new double[2];
double[] tempArray2 = new double[2];
double lastSeenValue = Double.NEGATIVE_INFINITY;
boolean mustTestOnNextValueChange = false;
boolean testSplitOnStart = true;
boolean firstIteration = true;
int lastSeenTarget = -1;
int indexInCol = -1;
// We iterate over the instances in the sample/branch instead of the whole data set
while (columnMemberships.next() && (indexInCol = columnMemberships.getIndexInColumn()) < lengthNonMissing) {
final double weight = columnMemberships.getRowWeight();
assert weight >= EPSILON : "Rows with zero row weight should never be seen!";
final double value = getSorted(indexInCol);
final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
final boolean hasValueChanged = (value - lastSeenValue) >= EPSILON;
final boolean hasTargetChanged = lastSeenTarget != target || indexInCol == lastValidSplitPosition;
if (hasTargetChanged && !firstIteration) {
mustTestOnNextValueChange = true;
testSplitOnStart = false;
}
if (!firstIteration && hasValueChanged && (mustTestOnNextValueChange || testSplitOnStart) && sumWeightsLeftOfSplit >= minChildNodeSize && sumWeightsRightOfSplit >= minChildNodeSize) {
double postSplitImpurity;
boolean tempMissingsGoLeft = false;
// missing value handling
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
final double[] targetCountsLeftPlusMissing = new double[targetCounts];
final double[] targetCountsRightPlusMissing = new double[targetCounts];
for (int i = 0; i < targetCounts; i++) {
targetCountsLeftPlusMissing[i] = targetCountsLeftOfSplit[i] + missingTargetCounts[i];
targetCountsRightPlusMissing[i] = targetCountsRightOfSplit[i] + missingTargetCounts[i];
}
final double[][] temp = new double[2][2];
final double[] postSplitImpurities = new double[2];
// send all missing values left
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftPlusMissing, sumWeightsLeftOfSplit + missingWeight);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
temp[0][0] = sumWeightsLeftOfSplit + missingWeight;
temp[0][1] = sumWeightsRightOfSplit;
postSplitImpurities[0] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[0], totalSumWeight);
// send all missing values right
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightPlusMissing, sumWeightsRightOfSplit + missingWeight);
temp[1][0] = sumWeightsLeftOfSplit;
temp[1][1] = sumWeightsRightOfSplit + missingWeight;
postSplitImpurities[1] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[1], totalSumWeight);
// take better split
if (postSplitImpurities[0] < postSplitImpurities[1]) {
postSplitImpurity = postSplitImpurities[0];
tempArray2 = temp[0];
tempMissingsGoLeft = true;
// TODO random tie breaker
} else {
postSplitImpurity = postSplitImpurities[1];
tempArray2 = temp[1];
tempMissingsGoLeft = false;
}
} else {
tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
tempArray2[0] = sumWeightsLeftOfSplit;
tempArray2[1] = sumWeightsRightOfSplit;
postSplitImpurity = impurityCriterion.getPostSplitImpurity(tempArray1, tempArray2, totalSumWeight);
}
if (postSplitImpurity < priorImpurity) {
// Use absolute gain (IG) for split calculation even
// if the split criterion is information gain ratio (IGR).
// IGR wouldn't work as it favors extreme unfair splits,
// i.e. 1:9999 would have an attribute entropy
// (IGR denominator) of
// 9999/10000*log(9999/10000) + 1/10000*log(1/10000)
// which is ~0.00148
double gain = (priorImpurity - postSplitImpurity);
boolean randomTieBreaker = gain == bestGain ? rd.nextInt(0, 1) == 1 : false;
if (gain > bestGain || randomTieBreaker) {
bestGainValueForSplit = impurityCriterion.getGain(priorImpurity, postSplitImpurity, tempArray2, totalSumWeight);
bestGain = gain;
bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
// Go with the majority if there are no missing values during training; we still need to
// provide a missing direction in case missing values show up during prediction.
missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : sumWeightsLeftOfSplit > sumWeightsRightOfSplit;
}
}
mustTestOnNextValueChange = false;
}
targetCountsLeftOfSplit[target] += weight;
sumWeightsLeftOfSplit += weight;
targetCountsRightOfSplit[target] -= weight;
sumWeightsRightOfSplit -= weight;
lastSeenTarget = target;
lastSeenValue = value;
firstIteration = false;
}
columnMemberships.reset();
if (bestGainValueForSplit < 0.0) {
// the normalized gain can be negative (see the info gain ratio implementation)
return null;
}
if (useXGBoostMissingValueHandling) {
// return new NumericMissingSplitCandidate(this, bestSplit, bestGainValueForSplit, missingsGoLeft);
return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
}
return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
}
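Here each candidate split is scored through the IImpurity interface obtained from the priors: getPartitionImpurity computes the impurity of one side from its class-count distribution, getPostSplitImpurity combines the two sides weighted by their sizes, and getGain turns the difference to the prior impurity into the final gain (normalized when information gain ratio is configured). A hedged sketch of one such evaluation for a two-class split, assuming impurityCriterion and priorImpurity are obtained exactly as in the snippets above (the class-count values are made up):

// Assumed class distributions left and right of the candidate split.
final double[] leftCounts = new double[] { 8.0, 2.0 };
final double[] rightCounts = new double[] { 1.0, 9.0 };
final double leftWeight = 10.0, rightWeight = 10.0, totalWeight = 20.0;
final double leftImpurity = impurityCriterion.getPartitionImpurity(leftCounts, leftWeight);
final double rightImpurity = impurityCriterion.getPartitionImpurity(rightCounts, rightWeight);
final double[] partitionWeights = new double[] { leftWeight, rightWeight };
final double postSplitImpurity = impurityCriterion.getPostSplitImpurity(new double[] { leftImpurity, rightImpurity }, partitionWeights, totalWeight);
final double gain = impurityCriterion.getGain(priorImpurity, postSplitImpurity, partitionWeights, totalWeight);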