use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.
The class Surrogates, method calculateSurrogates.
/**
 * Finds the splits (in <b>candidates</b>) that best mirror the best split (<b>candidates[0]</b>). The splits are
 * compared to the so-called <i>majority split</i>, which sends all records to the child that most rows of the best
 * split are sent to. This <i>majority split</i> is also always the last surrogate, which guarantees that every
 * record is sent to a child even if all surrogate attributes are missing as well.
 *
 * @param dataMemberships the memberships of the rows in the current node
 * @param candidates the split candidates; the first candidate must be the best split, all further candidates are
 *            evaluated as potential surrogates
 * @return a SurrogateSplit holding the surrogate conditions for both children and the child markers of the best
 *         split
 * @throws IllegalArgumentException if the best split or any of the other candidates is not a binary split
 */
public static SurrogateSplit calculateSurrogates(final DataMemberships dataMemberships,
    final SplitCandidate[] candidates) {
    final SplitCandidate bestSplit = candidates[0];
    final TreeAttributeColumnData bestSplitCol = bestSplit.getColumnData();
    final TreeNodeCondition[] bestSplitChildConditions = bestSplit.getChildConditions();
    if (bestSplitChildConditions.length != 2) {
        throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
    }
    final BitSet bestSplitLeft = bestSplitCol.updateChildMemberships(bestSplitChildConditions[0], dataMemberships);
    final BitSet bestSplitRight = bestSplitCol.updateChildMemberships(bestSplitChildConditions[1], dataMemberships);
    final double numRowsInNode = dataMemberships.getRowCount();
    // probability for a row to go left / right according to the best split
    final double bestSplitProbLeft = bestSplitLeft.cardinality() / numRowsInNode;
    final double bestSplitProbRight = bestSplitRight.cardinality() / numRowsInNode;
    // the majority rule is always the last surrogate and defines a default direction if all other surrogates
    // fail; ties go to the left child. Comparing the counts (instead of the probabilities) avoids comparing NaNs
    // for an empty node.
    final boolean majorityGoesLeft = bestSplitLeft.cardinality() >= bestSplitRight.cardinality();
    // probability that the majority rule sends a row to a different child than the best split does
    // (see calculateAssociationMeasure() for more information)
    final double errorMajorityRule = majorityGoesLeft ? bestSplitProbRight : bestSplitProbLeft;
    // candidates with a positive association measure
    final ArrayList<SurrogateCandidate> surrogateCandidates = new ArrayList<>();
    for (int i = 1; i < candidates.length; i++) {
        final SplitCandidate surrogate = candidates[i];
        final TreeAttributeColumnData surrogateCol = surrogate.getColumnData();
        final TreeNodeCondition[] surrogateChildConditions = surrogate.getChildConditions();
        if (surrogateChildConditions.length != 2) {
            throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
        }
        final BitSet surrogateLeft =
            surrogateCol.updateChildMemberships(surrogateChildConditions[0], dataMemberships);
        final BitSet surrogateRight =
            surrogateCol.updateChildMemberships(surrogateChildConditions[1], dataMemberships);
        // rows sent in the same direction by the best split and the surrogate
        final BitSet bothLeft = (BitSet)bestSplitLeft.clone();
        bothLeft.and(surrogateLeft);
        final BitSet bothRight = (BitSet)bestSplitRight.clone();
        bothRight.and(surrogateRight);
        // the complement of a split (switching the children) has the same gain value as the original split
        final BitSet complementBothLeft = (BitSet)bestSplitLeft.clone();
        complementBothLeft.and(surrogateRight);
        final BitSet complementBothRight = (BitSet)bestSplitRight.clone();
        complementBothRight.and(surrogateLeft);
        // the probability that both splits send a row in the same direction has to be computed explicitly
        // because there might be missing values, which are not sent in either direction
        final double predictProb = bothLeft.cardinality() / numRowsInNode + bothRight.cardinality() / numRowsInNode;
        final double complementPredictProb = complementBothLeft.cardinality() / numRowsInNode
            + complementBothRight.cardinality() / numRowsInNode;
        final double associationMeasure = calculateAssociationMeasure(errorMajorityRule, predictProb);
        final double complementAssociationMeasure =
            calculateAssociationMeasure(errorMajorityRule, complementPredictProb);
        final boolean useComplement = complementAssociationMeasure > associationMeasure;
        final double betterAssociationMeasure = useComplement ? complementAssociationMeasure : associationMeasure;
        assert betterAssociationMeasure <= 1 : "Association measure can not be greater than 1.";
        if (betterAssociationMeasure > 0) {
            final BitSet[] surrogateChildMarkers = new BitSet[]{surrogateLeft, surrogateRight};
            surrogateCandidates.add(
                new SurrogateCandidate(surrogate, useComplement, betterAssociationMeasure, surrogateChildMarkers));
        }
    }
    final BitSet[] childMarkers = new BitSet[]{bestSplitLeft, bestSplitRight};
    // if there are no surrogates, create conditions with the default rule as the only surrogate
    if (surrogateCandidates.isEmpty()) {
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
        final AbstractTreeNodeSurrogateCondition[] defaultOnlyConditions = new AbstractTreeNodeSurrogateCondition[]{
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition)bestSplitChildConditions[0],
                majorityGoesLeft),
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition)bestSplitChildConditions[1],
                !majorityGoesLeft)};
        return new SurrogateSplit(defaultOnlyConditions, childMarkers);
    }
    // natural ordering of SurrogateCandidate (its compareTo is defined outside this method)
    surrogateCandidates.sort(null);
    final int condSize = surrogateCandidates.size() + 1;
    final TreeNodeColumnCondition[] conditionsLeftChild = new TreeNodeColumnCondition[condSize];
    final TreeNodeColumnCondition[] conditionsRightChild = new TreeNodeColumnCondition[condSize];
    // the best split's conditions come first, followed by the surrogates in sorted order
    conditionsLeftChild[0] = (TreeNodeColumnCondition)bestSplitChildConditions[0];
    conditionsRightChild[0] = (TreeNodeColumnCondition)bestSplitChildConditions[1];
    for (int i = 0; i < surrogateCandidates.size(); i++) {
        final SurrogateCandidate surrogateCandidate = surrogateCandidates.get(i);
        final TreeNodeCondition[] surrogateConditions = surrogateCandidate.getSplitCandidate().getChildConditions();
        if (surrogateCandidate.m_useComplement) {
            // complemented surrogate: its children are swapped
            conditionsLeftChild[i + 1] = (TreeNodeColumnCondition)surrogateConditions[1];
            conditionsRightChild[i + 1] = (TreeNodeColumnCondition)surrogateConditions[0];
        } else {
            conditionsLeftChild[i + 1] = (TreeNodeColumnCondition)surrogateConditions[0];
            conditionsRightChild[i + 1] = (TreeNodeColumnCondition)surrogateConditions[1];
        }
    }
    // only if the best split missed some rows do the child markers need to be completed
    if (!bestSplit.getMissedRows().isEmpty()) {
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
    }
    return new SurrogateSplit(new TreeNodeSurrogateCondition[]{
        new TreeNodeSurrogateCondition(conditionsLeftChild, majorityGoesLeft),
        new TreeNodeSurrogateCondition(conditionsRightChild, !majorityGoesLeft)}, childMarkers);
}
use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.
The class Surrogates, method createSurrogateSplitWithDefaultDirection.
/**
 * Creates a surrogate split that only contains the best split and the default (majority) direction. It does
 * <b>NOT</b> calculate any surrogate splits (and is therefore more efficient).
 *
 * @param dataMemberships the memberships of the rows in the current node
 * @param bestSplit the best split; its first two child conditions are used as left and right condition
 * @return SurrogateSplit with conditions for both children. The conditions only contain the condition for the best
 *         split and the default condition (true for the child the most records go to and false for the other one).
 *         The child markers are the ones completed by fillInMissingChildMarkersWithDefault.
 */
public static SurrogateSplit createSurrogateSplitWithDefaultDirection(final DataMemberships dataMemberships,
    final SplitCandidate bestSplit) {
    final TreeAttributeColumnData col = bestSplit.getColumnData();
    final TreeNodeCondition[] conditions = bestSplit.getChildConditions();
    // child markers of the best split
    final BitSet left = col.updateChildMemberships(conditions[0], dataMemberships);
    final BitSet right = col.updateChildMemberships(conditions[1], dataMemberships);
    // decide which child the majority of the records goes to (ties go left)
    final boolean majorityGoesLeft = left.cardinality() >= right.cardinality();
    final TreeNodeSurrogateOnlyDefDirCondition condLeft =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition)conditions[0], majorityGoesLeft);
    final TreeNodeSurrogateOnlyDefDirCondition condRight =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition)conditions[1], !majorityGoesLeft);
    final BitSet[] childMarkers = new BitSet[]{left, right};
    fillInMissingChildMarkersWithDefault(bestSplit, childMarkers, majorityGoesLeft);
    // return the very array the helper worked on, so that nothing it stored in it is lost
    return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[]{condLeft, condRight}, childMarkers);
}
use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.
The class TreeLearnerClassification, method findBestSplitClassification.
/**
 * Searches the sampled, non-forbidden columns for the split with the highest gain.
 *
 * @param currentDepth depth of the node that is to be split (0 for the root)
 * @param dataMemberships memberships of the rows in the node
 * @param columnSample the columns that may be considered for this node
 * @param treeNodeSignature signature of the node (not used here)
 * @param targetPriors class distribution of the node
 * @param forbiddenColumnSet attribute indices that must be skipped
 * @return the winning split, or {@code null} if a stopping criterion holds or no column yields a split
 */
private SplitCandidate findBestSplitClassification(final int currentDepth, final DataMemberships dataMemberships,
    final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature,
    final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();

    // stop: maximum depth reached
    final int levelLimit = cfg.getMaxLevels();
    if (levelLimit != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= levelLimit) {
        return null;
    }
    // stop: too few records in the node
    final int minRecords = cfg.getMinNodeSize();
    if (minRecords != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minRecords) {
        return null;
    }
    // stop: node is (numerically) pure
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData)treeData.getTargetColumn();
    if (currentDepth == 0 && cfg.getHardCodedRootColumn() != null) {
        // TODO discuss whether this option makes sense with surrogates
        final TreeAttributeColumnData forcedRoot = treeData.getColumn(cfg.getHardCodedRootColumn());
        return forcedRoot.calcBestSplitClassification(dataMemberships, targetPriors, target, random);
    }

    SplitCandidate winner = null;
    double winnerGain = 0.0;
    for (final TreeAttributeColumnData column : columnSample) {
        if (forbiddenColumnSet.get(column.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate candidate =
            column.calcBestSplitClassification(dataMemberships, targetPriors, target, random);
        if (candidate == null) {
            continue;
        }
        final double gain = candidate.getGainValue();
        // exact ties are decided by a coin flip against the current winner.
        // NOTE(review): with more than two tied candidates this favors later columns over earlier ones.
        final boolean winsTie = gain == winnerGain && random.nextInt(0, 1) == 0;
        if (candidate.getGainValue() > winnerGain || winsTie) {
            winner = candidate;
            winnerGain = gain;
        }
    }
    return winner;
}
use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.
The class TreeLearnerClassification, method findBestSplitsClassification.
/**
 * Returns the best split of every sampled, non-forbidden column, sorted (descending) by gain.
 *
 * @param currentDepth depth of the node that is to be split (0 for the root)
 * @param dataMemberships memberships of the rows in the node
 * @param columnSample the columns that may be considered for this node
 * @param treeNodeSignature signature of the node (not used here)
 * @param targetPriors class distribution of the node
 * @param forbiddenColumnSet attribute indices that must be skipped
 * @return the split candidates sorted by descending gain (ties keep column order), or {@code null} if a stopping
 *         criterion holds or no column yields a split. The returned array, if non-null, is non-empty and contains
 *         no {@code null} elements.
 */
private SplitCandidate[] findBestSplitsClassification(final int currentDepth,
    final DataMemberships dataMemberships, final ColumnSample columnSample,
    final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors,
    final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    final double priorImpurity = targetPriors.getPriorImpurity();
    if (priorImpurity < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData)data.getTargetColumn();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        // "no split" is signalled by null, never by an array holding a null element
        return rootSplit == null ? null : new SplitCandidate[]{rootSplit};
    }
    final ArrayList<SplitCandidate> candidates = new ArrayList<>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            candidates.add(currentColSplit);
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }
    // descending by gain; List.sort is stable, so ties keep their column order
    candidates.sort((c1, c2) -> Double.compare(c2.getGainValue(), c1.getGainValue()));
    return candidates.toArray(new SplitCandidate[candidates.size()]);
}
use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.
The class TreeLearnerClassification, method learnSingleTreeRecursive.
/**
 * Learns a single classification tree by recursively building nodes from the root.
 *
 * @param exec monitor passed on to the recursive node construction
 * @param rd random data generator (NOTE(review): not used by this method itself)
 * @return the learned tree model whose root carries the always-true condition
 * @throws CanceledExecutionException if node construction is canceled
 */
private TreeModelClassification learnSingleTreeRecursive(final ExecutionMonitor exec, final RandomData rd)
    throws CanceledExecutionException {
    final TreeData treeData = getData();
    final RowSample rowSample = getRowSampling();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();
    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData)treeData.getTargetColumn();

    // memberships and class distribution of the sampled rows at the root
    final DataMemberships rootMemberships = new RootDataMemberships(rowSample, treeData, getIndexManager());
    final ClassificationPriors rootPriors = target.getDistribution(rootMemberships, cfg);
    // attribute indices excluded from splitting; empty at the root
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    final TreeNodeSignature rootSignature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumns = getColSamplingStrategy().getColumnSampleForTreeNode(rootSignature);
    final TreeNodeClassification root =
        buildTreeNode(exec, 0, rootMemberships, rootColumns, rootSignature, rootPriors, forbiddenColumns);
    // building the tree is expected to leave the forbidden-column set empty again
    assert forbiddenColumns.cardinality() == 0;
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelClassification(root);
}
Aggregations