Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateCondition in the knime-core project by KNIME.
Class RandomForestClassificationTreeNodeWidget, method getConnectorLabelBelow:
/**
 * {@inheritDoc}
 * <p>
 * The label is the name of the column the children of this node are split on. It is read from the
 * condition of the first child; if that condition is a surrogate condition, the column of its first
 * condition is used instead. Returns {@code null} if the node has no children or if no column condition
 * can be found.
 */
@Override
public String getConnectorLabelBelow() {
    final TreeNodeClassification treeNode = (TreeNodeClassification) getUserObject();
    if (treeNode.getNrChildren() == 0) {
        // no children -> no split -> nothing to label
        return null;
    }
    final TreeNodeCondition splitCondition = treeNode.getChild(0).getCondition();
    if (splitCondition instanceof TreeNodeColumnCondition) {
        return ((TreeNodeColumnCondition) splitCondition).getAttributeName();
    }
    if (!(splitCondition instanceof TreeNodeSurrogateCondition)) {
        return null;
    }
    // for a surrogate chain the label shows the column of the first condition in the chain
    final TreeNodeCondition firstCondition = ((TreeNodeSurrogateCondition) splitCondition).getFirstCondition();
    return firstCondition instanceof TreeNodeColumnCondition
        ? ((TreeNodeColumnCondition) firstCondition).getAttributeName()
        : null;
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateCondition in the knime-core project by KNIME.
Class Surrogates, method calculateSurrogates:
/**
 * Finds the splits among {@code candidates} that best mirror the best split ({@code candidates[0]}) and builds
 * the surrogate conditions for the two children of the best split.
 * <p>
 * Every other candidate is compared against the so called <i>majority split</i>, which sends all rows to the
 * child that receives the most rows under the best split (ties go left). For each candidate, both the split
 * and its complement (children swapped) are evaluated, and the one with the higher association measure (see
 * {@code calculateAssociationMeasure}) is used. Only candidates with a positive association measure are kept
 * as surrogates. The majority direction is always the last fallback, so that every row is sent to a child
 * even if the best split's attribute and all surrogate attributes are missing.
 *
 * @param dataMemberships the rows that reach the current node; used to compute the child memberships of every
 *            candidate and the number of rows in the node
 * @param candidates the split candidates; the first candidate must be the best split and every candidate must
 *            be a binary split with {@link TreeNodeColumnCondition} child conditions
 * @return a {@link SurrogateSplit} holding the surrogate conditions for the left and the right child of the
 *         best split, together with the child markers (rows assigned to each child) of the best split
 * @throws IllegalArgumentException if {@code candidates} is empty or any candidate is not a binary split
 */
public static SurrogateSplit calculateSurrogates(final DataMemberships dataMemberships,
    final SplitCandidate[] candidates) {
    if (candidates.length == 0) {
        throw new IllegalArgumentException("At least one split candidate (the best split) is required.");
    }
    final SplitCandidate bestSplit = candidates[0];
    final TreeAttributeColumnData bestSplitCol = bestSplit.getColumnData();
    final TreeNodeCondition[] bestSplitChildConditions = bestSplit.getChildConditions();
    if (bestSplitChildConditions.length != 2) {
        throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
    }
    final BitSet bestSplitLeft = bestSplitCol.updateChildMemberships(bestSplitChildConditions[0], dataMemberships);
    final BitSet bestSplitRight = bestSplitCol.updateChildMemberships(bestSplitChildConditions[1], dataMemberships);
    final double numRowsInNode = dataMemberships.getRowCount();
    final int bestSplitLeftCount = bestSplitLeft.cardinality();
    final int bestSplitRightCount = bestSplitRight.cardinality();
    // probability for a row to go left / right according to the best split
    final double bestSplitProbLeft = bestSplitLeftCount / numRowsInNode;
    final double bestSplitProbRight = bestSplitRightCount / numRowsInNode;
    // the majority rule is always the last surrogate and defines a default direction if all other
    // surrogates fail; ties go left. The integer counts are compared (not the probabilities) so the
    // direction is well-defined even for an empty node, where both probabilities would be NaN.
    final boolean majorityGoesLeft = bestSplitLeftCount >= bestSplitRightCount;
    // error of the majority rule = probability of the minority direction; see calculateAssociationMeasure()
    final double errorMajorityRule = majorityGoesLeft ? bestSplitProbRight : bestSplitProbLeft;
    // stores the association measure of every useful candidate
    final ArrayList<SurrogateCandidate> surrogateCandidates = new ArrayList<>();
    for (int i = 1; i < candidates.length; i++) {
        final SplitCandidate surrogate = candidates[i];
        final TreeAttributeColumnData surrogateCol = surrogate.getColumnData();
        final TreeNodeCondition[] surrogateChildConditions = surrogate.getChildConditions();
        if (surrogateChildConditions.length != 2) {
            throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
        }
        final BitSet surrogateLeft = surrogateCol.updateChildMemberships(surrogateChildConditions[0], dataMemberships);
        final BitSet surrogateRight =
            surrogateCol.updateChildMemberships(surrogateChildConditions[1], dataMemberships);
        final BitSet bothLeft = (BitSet) bestSplitLeft.clone();
        bothLeft.and(surrogateLeft);
        final BitSet bothRight = (BitSet) bestSplitRight.clone();
        bothRight.and(surrogateRight);
        // the complement of a split (switching the children) has the same gain value as the original split
        final BitSet complementBothLeft = (BitSet) bestSplitLeft.clone();
        complementBothLeft.and(surrogateRight);
        final BitSet complementBothRight = (BitSet) bestSplitRight.clone();
        complementBothRight.and(surrogateLeft);
        // the probability that the surrogate candidate and the best split send a row in the same direction
        // has to be computed explicitly because rows with missing values are not sent in either direction
        final double predictProb = (bothLeft.cardinality() + bothRight.cardinality()) / numRowsInNode;
        final double complementPredictProb =
            (complementBothLeft.cardinality() + complementBothRight.cardinality()) / numRowsInNode;
        final double associationMeasure = calculateAssociationMeasure(errorMajorityRule, predictProb);
        final double complementAssociationMeasure =
            calculateAssociationMeasure(errorMajorityRule, complementPredictProb);
        final boolean useComplement = complementAssociationMeasure > associationMeasure;
        final double betterAssociationMeasure = useComplement ? complementAssociationMeasure : associationMeasure;
        assert betterAssociationMeasure <= 1 : "Association measure can not be greater than 1.";
        if (betterAssociationMeasure > 0) {
            final BitSet[] surrogateChildMarkers = new BitSet[]{surrogateLeft, surrogateRight};
            surrogateCandidates
                .add(new SurrogateCandidate(surrogate, useComplement, betterAssociationMeasure, surrogateChildMarkers));
        }
    }
    final BitSet[] childMarkers = new BitSet[]{bestSplitLeft, bestSplitRight};
    // if there are no surrogates, create conditions with the default (majority) rule as the only fallback
    if (surrogateCandidates.isEmpty()) {
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
        return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[]{
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) bestSplitChildConditions[0],
                majorityGoesLeft),
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) bestSplitChildConditions[1],
                !majorityGoesLeft)}, childMarkers);
    }
    // order the surrogates by SurrogateCandidate's natural ordering
    // (presumably best association measure first - TODO confirm)
    surrogateCandidates.sort(null);
    final int condSize = surrogateCandidates.size() + 1;
    final TreeNodeColumnCondition[] conditionsLeftChild = new TreeNodeColumnCondition[condSize];
    final TreeNodeColumnCondition[] conditionsRightChild = new TreeNodeColumnCondition[condSize];
    // the best split's own conditions always come first in the chain
    conditionsLeftChild[0] = (TreeNodeColumnCondition) bestSplitChildConditions[0];
    conditionsRightChild[0] = (TreeNodeColumnCondition) bestSplitChildConditions[1];
    for (int i = 0; i < surrogateCandidates.size(); i++) {
        final SurrogateCandidate surrogateCandidate = surrogateCandidates.get(i);
        final TreeNodeCondition[] surrogateConditions = surrogateCandidate.getSplitCandidate().getChildConditions();
        // a complemented surrogate sends its "right" rows to the best split's left child and vice versa
        final int leftIndex = surrogateCandidate.m_useComplement ? 1 : 0;
        conditionsLeftChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[leftIndex];
        conditionsRightChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[1 - leftIndex];
    }
    // rows that the best split could not assign (missing values) are distributed via the surrogates
    if (!bestSplit.getMissedRows().isEmpty()) {
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
    }
    return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[]{
        new TreeNodeSurrogateCondition(conditionsLeftChild, majorityGoesLeft),
        new TreeNodeSurrogateCondition(conditionsRightChild, !majorityGoesLeft)}, childMarkers);
}
Usage of org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateCondition in the knime-core project by KNIME.
Class LiteralConditionParser, method parseSurrogateCompound:
/**
 * Converts a PMML surrogate compound predicate into a surrogate condition.
 * <p>
 * PMML does not enforce an order among the predicates of a surrogate compound, so surrogates are stored as a
 * chain of nested compound predicates. This chain is flattened into an ordered list of column conditions.
 * Together with the default direction returned by {@code unpackSurrogateChainIntoList}, that list yields either
 * a full surrogate condition (more than one column condition) or a condition that only carries the default
 * direction (exactly one column condition).
 *
 * @param compound the surrogate compound predicate to convert
 * @return the surrogate condition represented by {@code compound}
 */
private AbstractTreeNodeSurrogateCondition parseSurrogateCompound(final CompoundPredicate compound) {
    final List<TreeNodeColumnCondition> columnConditions = new ArrayList<>();
    final boolean defaultDirection = unpackSurrogateChainIntoList(compound, columnConditions);
    CheckUtils.checkArgument(!columnConditions.isEmpty(),
        "The surrogate conditon '%s' contains no column conditions.", compound);
    if (columnConditions.size() > 1) {
        final TreeNodeColumnCondition[] chain =
            columnConditions.toArray(new TreeNodeColumnCondition[columnConditions.size()]);
        return new TreeNodeSurrogateCondition(chain, defaultDirection);
    }
    // a single column condition: no real surrogates, only the default direction as fallback
    return new TreeNodeSurrogateOnlyDefDirCondition(columnConditions.get(0), defaultDirection);
}
Aggregations