Search in sources:

Example 16 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class Surrogates, method calculateSurrogates.

/**
 * Finds the splits in <b>candidates</b> that best mirror the best split (<b>candidates[0]</b>) and combines them
 * into a {@link SurrogateSplit}.
 * <p>
 * Candidates are compared to the so called <i>majority split</i>, which sends all records to the child that most
 * rows of the best split are sent to (the left child on a tie). For every candidate both the split itself and its
 * complement (children swapped, same gain) are evaluated and the one with the higher association measure is used;
 * only candidates with a positive association measure become surrogates. The <i>majority split</i> is always the
 * last surrogate to guarantee that every record is sent to a child even if all surrogate attributes are also
 * missing.
 *
 * @param dataMemberships the memberships of the rows in the current node
 * @param candidates the split candidates; the first candidate must be the best split
 * @return a {@link SurrogateSplit} holding the surrogate conditions for the left and the right child together with
 *         the child markers of the best split (possibly completed by {@code fillInMissingChildMarkers})
 * @throws IllegalArgumentException if the best split or any other candidate is not a binary split
 */
public static SurrogateSplit calculateSurrogates(final DataMemberships dataMemberships, final SplitCandidate[] candidates) {
    final SplitCandidate bestSplit = candidates[0];
    final TreeAttributeColumnData bestSplitCol = bestSplit.getColumnData();
    final TreeNodeCondition[] bestSplitChildConditions = bestSplit.getChildConditions();
    if (bestSplitChildConditions.length != 2) {
        throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
    }
    final BitSet bestSplitLeft = bestSplitCol.updateChildMemberships(bestSplitChildConditions[0], dataMemberships);
    final BitSet bestSplitRight = bestSplitCol.updateChildMemberships(bestSplitChildConditions[1], dataMemberships);
    final double numRowsInNode = dataMemberships.getRowCount();
    // probability for a row to go left / right according to the best split; rows with a missing value are sent
    // in neither direction, so the two probabilities need not sum to 1
    final double bestSplitProbLeft = bestSplitLeft.cardinality() / numRowsInNode;
    final double bestSplitProbRight = bestSplitRight.cardinality() / numRowsInNode;
    // the majority rule is always the last surrogate and defines a default direction if all other
    // surrogates fail; on a tie the majority goes left
    final boolean majorityGoesLeft = bestSplitLeft.cardinality() >= bestSplitRight.cardinality();
    // share of rows the best split sends against the majority direction;
    // see calculateAssociationMeasure() for more information
    final double errorMajorityRule = majorityGoesLeft ? bestSplitProbRight : bestSplitProbLeft;
    // candidates that have a positive association measure with the best split
    final ArrayList<SurrogateCandidate> surrogateCandidates = new ArrayList<>();
    for (int i = 1; i < candidates.length; i++) {
        final SplitCandidate surrogate = candidates[i];
        final TreeAttributeColumnData surrogateCol = surrogate.getColumnData();
        final TreeNodeCondition[] surrogateChildConditions = surrogate.getChildConditions();
        if (surrogateChildConditions.length != 2) {
            throw new IllegalArgumentException("Surrogates can only be calculated for binary splits.");
        }
        final BitSet surrogateLeft = surrogateCol.updateChildMemberships(surrogateChildConditions[0], dataMemberships);
        final BitSet surrogateRight = surrogateCol.updateChildMemberships(surrogateChildConditions[1], dataMemberships);
        final BitSet bothLeft = (BitSet) bestSplitLeft.clone();
        bothLeft.and(surrogateLeft);
        final BitSet bothRight = (BitSet) bestSplitRight.clone();
        bothRight.and(surrogateRight);
        // the complement of a split (switching the children) has the same gain value as the original split
        final BitSet complementBothLeft = (BitSet) bestSplitLeft.clone();
        complementBothLeft.and(surrogateRight);
        final BitSet complementBothRight = (BitSet) bestSplitRight.clone();
        complementBothRight.and(surrogateLeft);
        // calculating the probability that the surrogate candidate and the best split send a case both in the same
        // direction is necessary because there might be missing values which are not sent in either direction
        final double probBothLeft = bothLeft.cardinality() / numRowsInNode;
        final double probBothRight = bothRight.cardinality() / numRowsInNode;
        // the relative probability that the surrogate predicts the best split correctly
        final double predictProb = probBothLeft + probBothRight;
        final double probComplementBothLeft = complementBothLeft.cardinality() / numRowsInNode;
        final double probComplementBothRight = complementBothRight.cardinality() / numRowsInNode;
        final double complementPredictProb = probComplementBothLeft + probComplementBothRight;
        final double associationMeasure = calculateAssociationMeasure(errorMajorityRule, predictProb);
        final double complementAssociationMeasure =
            calculateAssociationMeasure(errorMajorityRule, complementPredictProb);
        final boolean useComplement = complementAssociationMeasure > associationMeasure;
        final double betterAssociationMeasure = useComplement ? complementAssociationMeasure : associationMeasure;
        assert betterAssociationMeasure <= 1 : "Association measure can not be greater than 1.";
        if (betterAssociationMeasure > 0) {
            final BitSet[] childMarkers = new BitSet[] { surrogateLeft, surrogateRight };
            surrogateCandidates.add(new SurrogateCandidate(surrogate, useComplement, betterAssociationMeasure, childMarkers));
        }
    }
    final BitSet[] childMarkers = new BitSet[] { bestSplitLeft, bestSplitRight };
    // if there are no surrogates, create condition with default rule as only surrogate
    if (surrogateCandidates.isEmpty()) {
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
        return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[] {
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) bestSplitChildConditions[0], majorityGoesLeft),
            new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) bestSplitChildConditions[1], !majorityGoesLeft) },
            childMarkers);
    }
    // natural ordering of SurrogateCandidate
    surrogateCandidates.sort(null);
    // slot 0 holds the best split's condition, followed by one slot per surrogate
    final int condSize = surrogateCandidates.size() + 1;
    final TreeNodeColumnCondition[] conditionsLeftChild = new TreeNodeColumnCondition[condSize];
    final TreeNodeColumnCondition[] conditionsRightChild = new TreeNodeColumnCondition[condSize];
    conditionsLeftChild[0] = (TreeNodeColumnCondition) bestSplitChildConditions[0];
    conditionsRightChild[0] = (TreeNodeColumnCondition) bestSplitChildConditions[1];
    for (int i = 0; i < surrogateCandidates.size(); i++) {
        final SurrogateCandidate surrogateCandidate = surrogateCandidates.get(i);
        final TreeNodeCondition[] surrogateConditions = surrogateCandidate.getSplitCandidate().getChildConditions();
        // a complemented surrogate predicts the best split with its children swapped
        if (surrogateCandidate.m_useComplement) {
            conditionsLeftChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[1];
            conditionsRightChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[0];
        } else {
            conditionsLeftChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[0];
            conditionsRightChild[i + 1] = (TreeNodeColumnCondition) surrogateConditions[1];
        }
    }
    // check if there are any rows missing in the best split
    if (!bestSplit.getMissedRows().isEmpty()) {
        // fill in any missing child markers
        fillInMissingChildMarkers(bestSplit, childMarkers, surrogateCandidates, majorityGoesLeft);
    }
    return new SurrogateSplit(new TreeNodeSurrogateCondition[] {
        new TreeNodeSurrogateCondition(conditionsLeftChild, majorityGoesLeft),
        new TreeNodeSurrogateCondition(conditionsRightChild, !majorityGoesLeft) }, childMarkers);
}
Also used : TreeNodeSurrogateCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateCondition) AbstractTreeNodeSurrogateCondition(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNodeSurrogateCondition) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BitSet(java.util.BitSet) ArrayList(java.util.ArrayList) TreeNodeSurrogateOnlyDefDirCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateOnlyDefDirCondition) TreeNodeColumnCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 17 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class Surrogates, method createSurrogateSplitWithDefaultDirection.

/**
 * Creates a surrogate split that only contains the best split and the default (majority) direction. It does
 * <b>NOT</b> calculate any surrogate splits (and is therefore more efficient).
 * <p>
 * The majority direction is the child that receives the most records according to the best split; on a tie the
 * left child is chosen. Only the first two child conditions of {@code bestSplit} are used.
 *
 * @param dataMemberships the memberships of the rows in the current node
 * @param bestSplit the best split
 * @return SurrogateSplit with conditions for both children. The conditions only contain the condition for the best
 *         split and the default condition (true for the child the most records go to and false for the other one).
 *         The child markers are the array passed to {@code fillInMissingChildMarkersWithDefault}.
 */
public static SurrogateSplit createSurrogateSplitWithDefaultDirection(final DataMemberships dataMemberships, final SplitCandidate bestSplit) {
    final TreeAttributeColumnData col = bestSplit.getColumnData();
    final TreeNodeCondition[] conditions = bestSplit.getChildConditions();
    // get child marker for best split
    final BitSet left = col.updateChildMemberships(conditions[0], dataMemberships);
    final BitSet right = col.updateChildMemberships(conditions[1], dataMemberships);
    // decide which child the majority of the records goes to (ties go left)
    final boolean majorityGoesLeft = left.cardinality() >= right.cardinality();
    // create surrogate conditions
    final TreeNodeSurrogateOnlyDefDirCondition condLeft =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) conditions[0], majorityGoesLeft);
    final TreeNodeSurrogateOnlyDefDirCondition condRight =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition) conditions[1], !majorityGoesLeft);
    final BitSet[] childMarkers = new BitSet[] { left, right };
    // NOTE(review): presumably assigns rows the best split could not place to the default direction
    fillInMissingChildMarkersWithDefault(bestSplit, childMarkers, majorityGoesLeft);
    // return the array that was just filled in (not a fresh copy) so that any element the helper
    // replaced is not lost
    return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[] { condLeft, condRight }, childMarkers);
}
Also used : TreeNodeSurrogateOnlyDefDirCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateOnlyDefDirCondition) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BitSet(java.util.BitSet) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 18 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeLearnerClassification, method findBestSplitClassification.

/**
 * Finds the single best split for the current node among the sampled columns that are not forbidden.
 * <p>
 * Returns {@code null} if the maximum depth is reached, the node has fewer records than the configured minimum node
 * size, the prior impurity is below {@code TreeColumnData.EPSILON}, or no column yields a split with a gain greater
 * than 0. At depth 0 with a hard-coded root column, that column's best split is returned directly. Ties with the
 * currently best candidate are broken by a coin flip drawn from {@code getRandomData()}.
 *
 * @param currentDepth depth of the current node
 * @param dataMemberships memberships of the rows in the current node
 * @param columnSample the columns to consider
 * @param treeNodeSignature signature of the current node (not referenced by this method)
 * @param targetPriors target priors of the current node
 * @param forbiddenColumnSet attribute indices of columns that must not be used for the split
 * @return the best split candidate or {@code null} (see above)
 */
private SplitCandidate findBestSplitClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    final double priorImpurity = targetPriors.getPriorImpurity();
    if (priorImpurity < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        return rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
    }
    SplitCandidate splitCandidate = null;
    double bestGainValue = 0.0;
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit == null) {
            continue;
        }
        final double currentGain = currentColSplit.getGainValue();
        // a tie only exists against an already selected candidate; without the null check a zero-gain split
        // could be chosen by coin flip although nothing improved on the initial gain of 0
        final boolean tiebreaker = splitCandidate != null && currentGain == bestGainValue && rd.nextInt(0, 1) == 0;
        if (currentGain > bestGainValue || tiebreaker) {
            splitCandidate = currentColSplit;
            bestGainValue = currentGain;
        }
    }
    return splitCandidate;
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)

Example 19 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeLearnerClassification, method findBestSplitsClassification.

/**
 * Returns the best split of every sampled, non-forbidden column, sorted descending by gain.
 * <p>
 * Returns {@code null} if the maximum depth is reached, the node has fewer records than the configured minimum node
 * size, the prior impurity is below {@code TreeColumnData.EPSILON}, or no column yields a split candidate. At depth
 * 0 with a hard-coded root column, only that column's best split is returned (as a single-element array).
 *
 * @param currentDepth depth of the current node
 * @param dataMemberships memberships of the rows in the current node
 * @param columnSample the columns to consider
 * @param treeNodeSignature signature of the current node (not referenced by this method)
 * @param targetPriors target priors of the current node
 * @param forbiddenColumnSet attribute indices of columns that must not be used for a split
 * @return the split candidates sorted by descending gain, or {@code null} if there is none; the returned array
 *         never contains {@code null} elements
 */
private SplitCandidate[] findBestSplitsClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    final double priorImpurity = targetPriors.getPriorImpurity();
    if (priorImpurity < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        // keep the "null if there is no candidate" contract instead of returning an array holding null
        return rootSplit == null ? null : new SplitCandidate[] { rootSplit };
    }
    final ArrayList<SplitCandidate> candidates = new ArrayList<>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            candidates.add(currentColSplit);
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }
    // descending by gain; List.sort is stable, so candidates with equal gain keep the column sample order
    candidates.sort(Comparator.comparingDouble(SplitCandidate::getGainValue).reversed());
    return candidates.toArray(new SplitCandidate[candidates.size()]);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ArrayList(java.util.ArrayList) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) Comparator(java.util.Comparator)

Example 20 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeLearnerClassification, method learnSingleTreeRecursive.

/**
 * Learns a single classification tree, starting from a root node produced by {@code buildTreeNode}.
 * <p>
 * The root is built from the memberships of the row sample, the target distribution of those rows, and the column
 * sample the column sampling strategy provides for {@code TreeNodeSignature.ROOT_SIGNATURE}. The root's condition
 * is set to {@code TreeNodeTrueCondition.INSTANCE}.
 *
 * @param exec the execution monitor handed to {@code buildTreeNode}
 * @param rd not referenced in this method body (NOTE(review): randomness presumably comes from getRandomData())
 * @return the learned tree model
 * @throws CanceledExecutionException if the execution was canceled while building the tree
 */
private TreeModelClassification learnSingleTreeRecursive(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeData treeData = getData();
    final RowSample rowSample = getRowSampling();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData) treeData.getTargetColumn();
    final DataMemberships rootMemberships = new RootDataMemberships(rowSample, treeData, getIndexManager());
    final ClassificationPriors rootPriors = target.getDistribution(rootMemberships, learnerConfig);
    // columns that must not be split on; the assertion below checks it is empty again once the tree is built
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());
    final TreeNodeSignature rootSignature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(rootSignature);
    final TreeNodeClassification root =
        buildTreeNode(exec, 0, rootMemberships, rootColumnSample, rootSignature, rootPriors, forbiddenColumns);
    assert forbiddenColumns.cardinality() == 0;
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelClassification(root);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeClassification(org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) BitSet(java.util.BitSet) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) ClassificationPriors(org.knime.base.node.mine.treeensemble2.data.ClassificationPriors) TreeModelClassification(org.knime.base.node.mine.treeensemble2.model.TreeModelClassification)

Aggregations

- TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 34
- DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships): 26
- RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 25
- BitSet (java.util.BitSet): 21
- Test (org.junit.Test): 21
- SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate): 17
- RandomData (org.apache.commons.math.random.RandomData): 15
- DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager): 14
- NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate): 13
- IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 12
- NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate): 12
- TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 10
- ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships): 10
- TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData): 9
- TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition): 9
- NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate): 7
- TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition): 7
- TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition): 6
- TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData): 5
- NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate): 5