Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core (KNIME).
Class TreeBitVectorColumnData, method updateChildMemberships.
/**
 * {@inheritDoc}
 *
 * <p>A row of the parent node belongs to the child if its bit in this column equals the value of the
 * {@link TreeNodeBitCondition}. The returned {@link BitSet} is indexed by the row's position in
 * {@code parentMemberships} (as reported by {@code ColumnMemberships#getIndexInDataMemberships()}).
 *
 * @param childCondition the child condition; expected to be a {@link TreeNodeBitCondition} on this column
 * @param parentMemberships the rows of the parent node
 * @return the rows of the parent that satisfy the condition; empty if the parent has no rows in this column
 */
@Override
public BitSet updateChildMemberships(final TreeNodeCondition childCondition, final DataMemberships parentMemberships) {
    final TreeNodeBitCondition bitCondition = (TreeNodeBitCondition) childCondition;
    assert getMetaData().getAttributeName().equals(bitCondition.getColumnMetaData().getAttributeName());
    final boolean value = bitCondition.getValue();
    final ColumnMemberships columnMemberships =
        parentMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    final BitSet inChild = new BitSet(columnMemberships.size());
    columnMemberships.reset();
    // Check next() before reading the current index: if there are no rows, the index is not meaningful
    // and must not be used to query m_columnBitSet.
    while (columnMemberships.next()) {
        if (m_columnBitSet.get(columnMemberships.getIndexInColumn()) == value) {
            inChild.set(columnMemberships.getIndexInDataMemberships());
        }
    }
    return inChild;
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core (KNIME).
Class TreeNumericColumnData, method updateChildMemberships.
/**
 * Computes which rows of the parent node are sent to the child described by {@code childCondition}.
 *
 * <p>Rows whose column index is below {@link #getLengthNonMissing()} are tested against the numeric split. All
 * remaining rows are treated as missing. They go to this child if the operator is one of the
 * {@code ...OrMissing} operators or the condition accepts missing values.
 *
 * @param childCondition the child condition; expected to be a {@link TreeNodeNumericCondition}
 * @param parentMemberships the rows of the parent node
 * @return the rows of the parent (indexed by their position in {@code parentMemberships}) that go to the child
 * @throws IllegalStateException if the parent has no rows in this column or the operator is unknown
 */
@Override
public BitSet updateChildMemberships(final TreeNodeCondition childCondition, final DataMemberships parentMemberships) {
    final TreeNodeNumericCondition numCondition = (TreeNodeNumericCondition) childCondition;
    final NumericOperator numOperator = numCondition.getNumericOperator();
    final double splitValue = numCondition.getSplitValue();
    final ColumnMemberships columnMemberships =
        parentMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    columnMemberships.reset();
    final BitSet inChild = new BitSet(columnMemberships.size());
    if (!columnMemberships.nextIndexFrom(0)) {
        throw new IllegalStateException("The current columnMemberships object contains no element that satisfies the splitcondition");
    }
    final int lengthNonMissing = getLengthNonMissing();
    // Non-missing rows: evaluate the numeric split for each of them.
    while (columnMemberships.getIndexInColumn() < lengthNonMissing) {
        final double value = getSorted(columnMemberships.getIndexInColumn());
        boolean matches;
        switch (numOperator) {
            case LessThanOrEqual:
                matches = value <= splitValue;
                break;
            case LargerThan:
                matches = value > splitValue;
                break;
            case LessThanOrEqualOrMissing:
                matches = Double.isNaN(value) || value <= splitValue;
                break;
            case LargerThanOrMissing:
                matches = Double.isNaN(value) || value > splitValue;
                break;
            default:
                throw new IllegalStateException("Unknown operator " + numOperator);
        }
        if (matches) {
            inChild.set(columnMemberships.getIndexInDataMemberships());
        }
        if (!columnMemberships.next()) {
            // reached end of columnMemberships without encountering a missing value
            return inChild;
        }
    }
    // The cursor now points at the FIRST row with a missing value (this also covers the case where the very first
    // row of the node is missing); every following row is missing as well.
    final boolean missingsGoToChild = numOperator == NumericOperator.LessThanOrEqualOrMissing
        || numOperator == NumericOperator.LargerThanOrMissing || numCondition.acceptsMissings();
    if (missingsGoToChild) {
        do {
            inChild.set(columnMemberships.getIndexInDataMemberships());
        } while (columnMemberships.next());
    }
    return inChild;
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core (KNIME).
Class Surrogates, method calculateSurrogates.
/**
 * Finds the splits (in <b>candidates</b>) that best mirror the best split (<b>candidates[0]</b>). Every candidate is
 * compared to the so-called <i>majority split</i>, which sends all records to the child that receives most rows
 * under the best split. This <i>majority split</i> is also always the last surrogate. That guarantees every record
 * is sent to a child, even if all surrogate attributes are missing as well.
 *
 * <p>A candidate is kept only if its association measure (or that of its complement, i.e. the candidate with
 * swapped children) is positive. Kept candidates are sorted by their natural ordering.
 *
 * @param dataMemberships the rows of the node that is split
 * @param candidates the split candidates; the first candidate must be the best split
 * @return a SurrogateSplit containing the surrogate conditions for both children and the child markers derived
 *         from the best split
 * @throws IllegalArgumentException if the best split or any candidate is not a binary split
 */
public static SurrogateSplit calculateSurrogates(final DataMemberships dataMemberships, final SplitCandidate[] candidates) {
    final SplitCandidate bestSplit = candidates[0];
    final TreeAttributeColumnData bestSplitCol = bestSplit.getColumnData();
    final TreeNodeCondition[] bestSplitChildConditions = bestSplit.getChildConditions();
    if (bestSplitChildConditions.length != 2) {
        throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
    }
    final BitSet bestSplitLeft = bestSplitCol.updateChildMemberships(bestSplitChildConditions[0], dataMemberships);
    final BitSet bestSplitRight = bestSplitCol.updateChildMemberships(bestSplitChildConditions[1], dataMemberships);
    final double numRowsInNode = dataMemberships.getRowCount();
    // probability for a row of this node to go left according to the best split
    final double bestSplitProbLeft = bestSplitLeft.cardinality() / numRowsInNode;
    // probability for a row of this node to go right according to the best split
    final double bestSplitProbRight = bestSplitRight.cardinality() / numRowsInNode;
    // The majority rule is always the last surrogate and defines a default direction if all other surrogates fail.
    // Ties go left (same rule as in createSurrogateSplitWithDefaultDirection).
    final boolean majorityGoesLeft = bestSplitLeft.cardinality() >= bestSplitRight.cardinality();
    // error of the majority rule = share of rows the best split sends to the minority child
    // (see calculateAssociationMeasure() for more information)
    final double errorMajorityRule = majorityGoesLeft ? bestSplitProbRight : bestSplitProbLeft;
    // stores the association measure for the candidates
    final ArrayList<SurrogateCandidate> surrogateCandidates = new ArrayList<SurrogateCandidate>();
    for (int i = 1; i < candidates.length; i++) {
        final SplitCandidate surrogate = candidates[i];
        final TreeAttributeColumnData surrogateCol = surrogate.getColumnData();
        final TreeNodeCondition[] surrogateChildConditions = surrogate.getChildConditions();
        if (surrogateChildConditions.length != 2) {
            throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
        }
        final BitSet surrogateLeft = surrogateCol.updateChildMemberships(surrogateChildConditions[0], dataMemberships);
        final BitSet surrogateRight = surrogateCol.updateChildMemberships(surrogateChildConditions[1], dataMemberships);
        final BitSet bothLeft = (BitSet) bestSplitLeft.clone();
        bothLeft.and(surrogateLeft);
        final BitSet bothRight = (BitSet) bestSplitRight.clone();
        bothRight.and(surrogateRight);
        // the complement of a split (switching the children) has the same gain value as the original split
        final BitSet complementBothLeft = (BitSet) bestSplitLeft.clone();
        complementBothLeft.and(surrogateRight);
        final BitSet complementBothRight = (BitSet) bestSplitRight.clone();
        complementBothRight.and(surrogateLeft);
        // The agreement of surrogate and best split is counted explicitly per direction. Rows with missing values
        // may be sent in neither direction, so agreement cannot be derived from disagreement.
        final double probBothLeft = bothLeft.cardinality() / numRowsInNode;
        final double probBothRight = bothRight.cardinality() / numRowsInNode;
        // the relative probability that the surrogate predicts the best split correctly
        final double predictProb = probBothLeft + probBothRight;
        final double probComplementBothLeft = complementBothLeft.cardinality() / numRowsInNode;
        final double probComplementBothRight = complementBothRight.cardinality() / numRowsInNode;
        final double complementPredictProb = probComplementBothLeft + probComplementBothRight;
        final double associationMeasure = calculateAssociationMeasure(errorMajorityRule, predictProb);
        final double complementAssociationMeasure =
            calculateAssociationMeasure(errorMajorityRule, complementPredictProb);
        final boolean useComplement = complementAssociationMeasure > associationMeasure;
        final double betterAssociationMeasure = useComplement ? complementAssociationMeasure : associationMeasure;
        assert betterAssociationMeasure <= 1 : "Association measure can not be greater than 1.";
        if (betterAssociationMeasure > 0) {
            // NOTE(review): the markers are stored in surrogate-condition order even when useComplement is set
            // (the conditions are swapped further below); presumably the consumers of SurrogateCandidate account
            // for m_useComplement - confirm.
            final BitSet[] surrogateChildMarkers = new BitSet[] { surrogateLeft, surrogateRight };
            surrogateCandidates.add(
                new SurrogateCandidate(surrogate, useComplement, betterAssociationMeasure, surrogateChildMarkers));
        }
    }
    final BitSet[] childMarkers = new BitSet[] { bestSplitLeft, bestSplitRight };
    // if there are no surrogates, create conditions with the default rule as the only surrogate
    if (surrogateCandidates.isEmpty()) {
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
        return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[] {
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) bestSplitChildConditions[0], majorityGoesLeft),
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) bestSplitChildConditions[1], !majorityGoesLeft) },
            childMarkers);
    }
    surrogateCandidates.sort(null);
    final int condSize = surrogateCandidates.size() + 1;
    final TreeNodeColumnCondition[] conditionsLeftChild = new TreeNodeColumnCondition[condSize];
    final TreeNodeColumnCondition[] conditionsRightChild = new TreeNodeColumnCondition[condSize];
    conditionsLeftChild[0] = (TreeNodeColumnCondition) bestSplitChildConditions[0];
    conditionsRightChild[0] = (TreeNodeColumnCondition) bestSplitChildConditions[1];
    for (int i = 0; i < surrogateCandidates.size(); i++) {
        final SurrogateCandidate surrogateCandidate = surrogateCandidates.get(i);
        final TreeNodeCondition[] surrogateConditions = surrogateCandidate.getSplitCandidate().getChildConditions();
        // a complemented surrogate sends its "left" rows to the right child and vice versa
        final int leftIdx = surrogateCandidate.m_useComplement ? 1 : 0;
        conditionsLeftChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[leftIdx];
        conditionsRightChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[1 - leftIdx];
    }
    // check if there are any rows missing in the best split
    if (!bestSplit.getMissedRows().isEmpty()) {
        // fill in any missing child markers
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
    }
    return new SurrogateSplit(new TreeNodeSurrogateCondition[] {
        new TreeNodeSurrogateCondition(conditionsLeftChild, majorityGoesLeft),
        new TreeNodeSurrogateCondition(conditionsRightChild, !majorityGoesLeft) }, childMarkers);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core (KNIME).
Class Surrogates, method createSurrogateSplitWithDefaultDirection.
/**
 * Creates a surrogate split that only contains the best split and the default (majority) direction. It does
 * <b>NOT</b> calculate any surrogate splits (and is therefore more efficient).
 *
 * @param dataMemberships the rows of the node that is split
 * @param bestSplit the best split; must be a binary split
 * @return SurrogateSplit with conditions for both children. The conditions only contain the condition for the best
 *         split and the default condition (true for the child most records go to, false for the other one; ties go
 *         left). The child markers are those of the best split after missing rows were filled in with the default
 *         direction.
 * @throws IllegalArgumentException if {@code bestSplit} is not a binary split
 */
public static SurrogateSplit createSurrogateSplitWithDefaultDirection(final DataMemberships dataMemberships, final SplitCandidate bestSplit) {
    final TreeAttributeColumnData col = bestSplit.getColumnData();
    final TreeNodeCondition[] conditions = bestSplit.getChildConditions();
    if (conditions.length != 2) {
        throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
    }
    // get child markers for the best split
    final BitSet left = col.updateChildMemberships(conditions[0], dataMemberships);
    final BitSet right = col.updateChildMemberships(conditions[1], dataMemberships);
    // decide which child the majority of the records goes to (ties go left)
    final boolean majorityGoesLeft = left.cardinality() >= right.cardinality();
    // create surrogate conditions
    final TreeNodeSurrogateOnlyDefDirCondition condLeft =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) conditions[0], majorityGoesLeft);
    final TreeNodeSurrogateOnlyDefDirCondition condRight =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) conditions[1], !majorityGoesLeft);
    final BitSet[] childMarkers = new BitSet[] { left, right };
    fillInMissingChildMarkersWithDefault(bestSplit, childMarkers, majorityGoesLeft);
    // Return the very array that was just filled; building a new one from left/right would discard any
    // element the fill step replaced instead of mutating in place.
    return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[] { condLeft, condRight }, childMarkers);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core (KNIME).
Class TreeNumericColumnDataTest, method testUpdateChildMemberships.
/**
 * Tests {@link TreeNumericColumnData#updateChildMemberships(TreeNodeCondition, DataMemberships)} with numeric
 * conditions on a column without missing values and on a column whose last value is missing.
 *
 * @throws Exception
 */
@Test
public void testUpdateChildMemberships() throws Exception {
    final TestDataGenerator generator = new TestDataGenerator(createConfig());
    final int[] rowIndices = {0, 1, 2, 3, 4, 5, 6};
    final double[] rowWeights = new double[rowIndices.length];
    Arrays.fill(rowWeights, 1.0);
    final DataMemberships memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);

    // ---- column without missing values ----
    final TreeNumericColumnData plainCol =
        generator.createNumericAttributeColumn("-50, -3, -2, 2, 25, 100, 101", "noMissings-col", 0);

    // x <= -2 selects rows 0..2
    BitSet actual = plainCol.updateChildMemberships(
        new TreeNodeNumericCondition(plainCol.getMetaData(), -2, NumericOperator.LessThanOrEqual, false), memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows(0, 3), actual);

    // x > 10 selects rows 4..6
    actual = plainCol.updateChildMemberships(
        new TreeNodeNumericCondition(plainCol.getMetaData(), 10, NumericOperator.LargerThan, false), memberships);
    assertEquals("The produced BitSet is incorrect", expectedRows(4, 7), actual);

    // ---- column whose last row is missing (NaN) ----
    final TreeNumericColumnData missingCol =
        generator.createNumericAttributeColumn("-2, 0, 1, 43, 61, 66, NaN", "missings-col", 0);

    // x <= 12, missings accepted: rows 0..2 plus the missing row 6
    actual = missingCol.updateChildMemberships(
        new TreeNodeNumericCondition(missingCol.getMetaData(), 12, NumericOperator.LessThanOrEqual, true), memberships);
    BitSet wanted = expectedRows(0, 3);
    wanted.set(6);
    assertEquals("The produced BitSet is incorrect", wanted, actual);

    // x <= 12, missings rejected: rows 0..2 only
    actual = missingCol.updateChildMemberships(
        new TreeNodeNumericCondition(missingCol.getMetaData(), 12, NumericOperator.LessThanOrEqual, false), memberships);
    assertEquals("The produced BitSet is incorrect", expectedRows(0, 3), actual);

    // x > 43, missings accepted: rows 4, 5 and the missing row 6
    actual = missingCol.updateChildMemberships(
        new TreeNodeNumericCondition(missingCol.getMetaData(), 43, NumericOperator.LargerThan, true), memberships);
    assertEquals("The produced BitSet is incorrect", expectedRows(4, 7), actual);

    // x > 12, missings rejected: rows 3..5
    actual = missingCol.updateChildMemberships(
        new TreeNodeNumericCondition(missingCol.getMetaData(), 12, NumericOperator.LargerThan, false), memberships);
    assertEquals("The produced BitSet is incorrect", expectedRows(3, 6), actual);
}

/**
 * Builds a BitSet whose bits in {@code [fromInclusive, toExclusive)} are set.
 */
private static BitSet expectedRows(final int fromInclusive, final int toExclusive) {
    final BitSet bits = new BitSet();
    bits.set(fromInclusive, toExclusive);
    return bits;
}
Aggregations