Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition in project knime-core by KNIME.
The class RandomForestClassificationTreeNodeWidget, method getConnectorLabelBelow.
/**
 * {@inheritDoc}
 * <p>
 * The label is the attribute name of the split that produced this node's children, read from the condition of the
 * first child. For a surrogate condition, the attribute of its first (primary) condition is used. Returns
 * {@code null} for leaf nodes and for conditions that do not carry an attribute name.
 */
@Override
public String getConnectorLabelBelow() {
    final TreeNodeClassification treeNode = (TreeNodeClassification) getUserObject();
    if (treeNode.getNrChildren() == 0) {
        // leaf: nothing is split below this node
        return null;
    }
    final TreeNodeCondition firstChildCondition = treeNode.getChild(0).getCondition();
    if (firstChildCondition instanceof TreeNodeColumnCondition) {
        return ((TreeNodeColumnCondition) firstChildCondition).getAttributeName();
    }
    if (!(firstChildCondition instanceof TreeNodeSurrogateCondition)) {
        return null;
    }
    // a surrogate condition leads with the condition of the actual (best) split
    final TreeNodeCondition primaryCondition =
        ((TreeNodeSurrogateCondition) firstChildCondition).getFirstCondition();
    return primaryCondition instanceof TreeNodeColumnCondition
        ? ((TreeNodeColumnCondition) primaryCondition).getAttributeName() : null;
}
Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition in project knime-core by KNIME.
The class LiteralConditionParser, method handleSimpleSetPredicate.
/**
 * Translates a PMML {@code SimpleSetPredicate} into a {@code TreeNodeNominalBinaryCondition}.
 * <p>
 * The predicate's field must be nominal; otherwise {@code CheckUtils.checkArgument} rejects it. The condition tests
 * membership in the parsed value set when the predicate's boolean operator is {@code IS_IN}, and non-membership
 * otherwise.
 *
 * @param simpleSetPred the PMML set predicate to translate
 * @param acceptsMissings whether the resulting condition accepts missing values
 * @return the nominal binary condition equivalent to {@code simpleSetPred}
 */
private TreeNodeColumnCondition handleSimpleSetPredicate(final SimpleSetPredicate simpleSetPred, final boolean acceptsMissings) {
    final String fieldName = simpleSetPred.getField();
    final boolean fieldIsNominal = m_metaDataMapper.isNominal(fieldName);
    CheckUtils.checkArgument(fieldIsNominal, "The field \"%s\" is not nominal but currently only nominal fields can be used for SimpleSetPredicates", fieldName);
    final NominalAttributeColumnHelper nominalHelper = m_metaDataMapper.getNominalColumnHelper(fieldName);
    final TreeNominalColumnMetaData columnMetaData = nominalHelper.getMetaData();
    final boolean testsMembership =
        simpleSetPred.getBooleanOperator().equals(SimpleSetPredicate.BooleanOperator.IS_IN);
    return new TreeNodeNominalBinaryCondition(columnMetaData, parseValuesMask(simpleSetPred, nominalHelper),
        testsMembership, acceptsMissings);
}
Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition in project knime-core by KNIME.
The class TreeModelPMMLTranslator, method addTreeNode.
/**
 * Writes {@code node} into {@code pmmlNode} and then recursively adds one new PMML child node per child of
 * {@code node}.
 * <p>
 * Each written node gets the next id taken from {@code m_nodeIndex}. A classification node contributes its majority
 * class as score, the sum of its target distribution as record count and one score distribution entry per entry of
 * the target distribution. A regression node contributes its mean as score. The node's condition is then translated
 * into the corresponding PMML predicate.
 *
 * @param pmmlNode the freshly created PMML node to fill with the content of {@code node}
 * @param node the tree node to translate
 * @throws IllegalStateException if the node's condition is {@code null} or of a type that cannot be translated
 */
private void addTreeNode(final Node pmmlNode, final AbstractTreeNode node) {
    final int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    if (node instanceof TreeNodeClassification) {
        final TreeNodeClassification clazNode = (TreeNodeClassification) node;
        pmmlNode.setScore(clazNode.getMajorityClassName());
        final float[] targetDistribution = clazNode.getTargetDistribution();
        final NominalValueRepresentation[] targetVals = clazNode.getTargetMetaData().getValues();
        // iterate over primitive floats; a Float loop variable would box every entry
        double sum = 0.0;
        for (final float v : targetDistribution) {
            sum += v;
        }
        pmmlNode.setRecordCount(sum);
        // adding score distribution (class counts)
        for (int i = 0; i < targetDistribution.length; i++) {
            final String className = targetVals[i].getNominalValue();
            final double freq = targetDistribution[i];
            final ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
            pmmlScoreDist.setValue(className);
            pmmlScoreDist.setRecordCount(freq);
        }
    } else if (node instanceof TreeNodeRegression) {
        final TreeNodeRegression regNode = (TreeNodeRegression) node;
        pmmlNode.setScore(Double.toString(regNode.getMean()));
    }
    final TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        handleColumnCondition(colCondition, pmmlNode);
    } else if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogateCond = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogateCond.toPMMLPredicate());
    } else {
        // guard against null so a missing condition is reported as such instead of as a NullPointerException
        final String conditionType = condition == null ? "null" : condition.getClass().getSimpleName();
        throw new IllegalStateException("Unsupported condition (not implemented): " + conditionType);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition in project knime-core by KNIME.
The class Surrogates, method calculateSurrogates.
/**
 * Finds the splits (in <b>candidates</b>) that best mirror the best split (<b>candidates[0]</b>). The splits are
 * compared to the so-called <i>majority split</i>, which sends all records to the child that most rows of the best
 * split are sent to. This <i>majority split</i> is always the last surrogate, which guarantees that every record is
 * sent to a child even if all surrogate attributes are missing as well.
 * <p>
 * Only candidates with a positive association measure are kept as surrogates. They are ordered by the natural
 * ordering of {@code SurrogateCandidate}.
 *
 * @param dataMemberships the memberships of the rows in the current node
 * @param candidates the split candidates; the first candidate must be the best split, and every candidate must be a
 *            binary split whose child conditions are {@link TreeNodeColumnCondition}s
 * @return a {@link SurrogateSplit} with one surrogate condition per child (left, right) and the child markers of
 *         the best split
 * @throws IllegalArgumentException if {@code candidates} is {@code null} or empty, or if a candidate is not a binary
 *             split
 */
public static SurrogateSplit calculateSurrogates(final DataMemberships dataMemberships, final SplitCandidate[] candidates) {
    if (candidates == null || candidates.length == 0) {
        throw new IllegalArgumentException("At least one split candidate (the best split) is required.");
    }
    final SplitCandidate bestSplit = candidates[0];
    final TreeAttributeColumnData bestSplitCol = bestSplit.getColumnData();
    final TreeNodeCondition[] bestSplitChildConditions = bestSplit.getChildConditions();
    if (bestSplitChildConditions.length != 2) {
        throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
    }
    final BitSet bestSplitLeft = bestSplitCol.updateChildMemberships(bestSplitChildConditions[0], dataMemberships);
    final BitSet bestSplitRight = bestSplitCol.updateChildMemberships(bestSplitChildConditions[1], dataMemberships);
    final double numRowsInNode = dataMemberships.getRowCount();
    // probability for a row to go left according to the best split
    final double bestSplitProbLeft = bestSplitLeft.cardinality() / numRowsInNode;
    // probability for a row to go right according to the best split
    final double bestSplitProbRight = bestSplitRight.cardinality() / numRowsInNode;
    // the majority rule is always the last surrogate and defines a default direction if all other
    // surrogates fail; ties (and NaN probabilities for an empty node) go left
    final boolean majorityGoesLeft = !(bestSplitProbRight > bestSplitProbLeft);
    // see calculateAssociationMeasure() for more information
    final double errorMajorityRule = majorityGoesLeft ? bestSplitProbRight : bestSplitProbLeft;
    // stores association measure for candidates
    final ArrayList<SurrogateCandidate> surrogateCandidates = new ArrayList<>();
    for (int i = 1; i < candidates.length; i++) {
        final SplitCandidate surrogate = candidates[i];
        final TreeAttributeColumnData surrogateCol = surrogate.getColumnData();
        final TreeNodeCondition[] surrogateChildConditions = surrogate.getChildConditions();
        if (surrogateChildConditions.length != 2) {
            throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
        }
        final BitSet surrogateLeft = surrogateCol.updateChildMemberships(surrogateChildConditions[0], dataMemberships);
        final BitSet surrogateRight = surrogateCol.updateChildMemberships(surrogateChildConditions[1], dataMemberships);
        final BitSet bothLeft = (BitSet) bestSplitLeft.clone();
        bothLeft.and(surrogateLeft);
        final BitSet bothRight = (BitSet) bestSplitRight.clone();
        bothRight.and(surrogateRight);
        // the complement of a split (switching the children) has the same gain value as the original split
        final BitSet complementBothLeft = (BitSet) bestSplitLeft.clone();
        complementBothLeft.and(surrogateRight);
        final BitSet complementBothRight = (BitSet) bestSplitRight.clone();
        complementBothRight.and(surrogateLeft);
        // calculating the probability that the surrogate candidate and the best split both send a case in the same
        // direction is necessary because there might be missing values, which are not sent in either direction
        final double probBothLeft = (bothLeft.cardinality() / numRowsInNode);
        final double probBothRight = (bothRight.cardinality() / numRowsInNode);
        // the relative probability that the surrogate predicts the best split correctly
        final double predictProb = probBothLeft + probBothRight;
        final double probComplementBothLeft = (complementBothLeft.cardinality() / numRowsInNode);
        final double probComplementBothRight = (complementBothRight.cardinality() / numRowsInNode);
        final double complementPredictProb = probComplementBothLeft + probComplementBothRight;
        final double associationMeasure = calculateAssociationMeasure(errorMajorityRule, predictProb);
        final double complementAssociationMeasure =
            calculateAssociationMeasure(errorMajorityRule, complementPredictProb);
        final boolean useComplement = complementAssociationMeasure > associationMeasure;
        final double betterAssociationMeasure = useComplement ? complementAssociationMeasure : associationMeasure;
        assert betterAssociationMeasure <= 1 : "Association measure can not be greater than 1.";
        if (betterAssociationMeasure > 0) {
            final BitSet[] childMarkers = new BitSet[] { surrogateLeft, surrogateRight };
            surrogateCandidates.add(new SurrogateCandidate(surrogate, useComplement, betterAssociationMeasure, childMarkers));
        }
    }
    final BitSet[] childMarkers = new BitSet[] { bestSplitLeft, bestSplitRight };
    // if there are no surrogates, create condition with default rule as only surrogate
    if (surrogateCandidates.isEmpty()) {
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
        return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[] {
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) bestSplitChildConditions[0], majorityGoesLeft),
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) bestSplitChildConditions[1], !majorityGoesLeft) },
            childMarkers);
    }
    // natural ordering of SurrogateCandidate (presumably strongest association first - TODO confirm)
    surrogateCandidates.sort(null);
    // slot 0 holds the best split's condition, followed by the surrogates in sorted order
    final int condSize = surrogateCandidates.size() + 1;
    final TreeNodeColumnCondition[] conditionsLeftChild = new TreeNodeColumnCondition[condSize];
    final TreeNodeColumnCondition[] conditionsRightChild = new TreeNodeColumnCondition[condSize];
    conditionsLeftChild[0] = (TreeNodeColumnCondition) bestSplitChildConditions[0];
    conditionsRightChild[0] = (TreeNodeColumnCondition) bestSplitChildConditions[1];
    for (int i = 0; i < surrogateCandidates.size(); i++) {
        final SurrogateCandidate surrogateCandidate = surrogateCandidates.get(i);
        final TreeNodeCondition[] surrogateConditions = surrogateCandidate.getSplitCandidate().getChildConditions();
        if (surrogateCandidate.m_useComplement) {
            // the complement swaps the surrogate's children
            conditionsLeftChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[1];
            conditionsRightChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[0];
        } else {
            conditionsLeftChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[0];
            conditionsRightChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[1];
        }
    }
    // check if there are any rows missing in the best split
    if (!bestSplit.getMissedRows().isEmpty()) {
        // fill in any missing child markers
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
    }
    return new SurrogateSplit(new TreeNodeSurrogateCondition[] {
        new TreeNodeSurrogateCondition(conditionsLeftChild, majorityGoesLeft),
        new TreeNodeSurrogateCondition(conditionsRightChild, !majorityGoesLeft) }, childMarkers);
}
Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition in project knime-core by KNIME.
The class Surrogates, method createSurrogateSplitWithDefaultDirection.
/**
 * Creates a surrogate split that only contains the best split and the default (majority) direction. It does
 * <b>NOT</b> calculate any surrogate splits and is therefore more efficient.
 * <p>
 * The majority direction is the child with more member rows; on a tie it is the left child. The child markers are
 * passed to {@code fillInMissingChildMarkersWithDefault} before they are returned.
 *
 * @param dataMemberships the memberships of the rows in the current node
 * @param bestSplit the best split; must be a binary split whose child conditions are
 *            {@link TreeNodeColumnCondition}s
 * @return SurrogateSplit with conditions for both children. The conditions only contain the condition for the best
 *         split and the default condition (true for the child most records go to, false for the other one).
 * @throws IllegalArgumentException if {@code bestSplit} is not a binary split
 */
public static SurrogateSplit createSurrogateSplitWithDefaultDirection(final DataMemberships dataMemberships, final SplitCandidate bestSplit) {
    final TreeAttributeColumnData col = bestSplit.getColumnData();
    final TreeNodeCondition[] conditions = bestSplit.getChildConditions();
    if (conditions.length != 2) {
        throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
    }
    // get child marker for best split
    final BitSet left = col.updateChildMemberships(conditions[0], dataMemberships);
    final BitSet right = col.updateChildMemberships(conditions[1], dataMemberships);
    // decide which child the majority of the records goes to (ties go left)
    final boolean majorityGoesLeft = left.cardinality() >= right.cardinality();
    // create surrogate conditions
    final TreeNodeSurrogateOnlyDefDirCondition condLeft =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) conditions[0], majorityGoesLeft);
    final TreeNodeSurrogateOnlyDefDirCondition condRight =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) conditions[1], !majorityGoesLeft);
    final BitSet[] childMarkers = new BitSet[] { left, right };
    fillInMissingChildMarkersWithDefault(bestSplit, childMarkers, majorityGoesLeft);
    // return the very array handed to the fill-in step so that none of its updates can be lost
    return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[] { condLeft, condRight }, childMarkers);
}
Aggregations