Search in sources:

Example 11 with SplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.SplitCandidate in project knime-core by knime.

In class Surrogates, method createSurrogateSplitWithDefaultDirection:

/**
 * Creates a surrogate split that only contains the best split and the default (majority) direction. It does
 * <b>NOT</b> calculate any surrogate splits (and is therefore more efficient).
 *
 * @param dataMemberships the memberships of the records in the node that is being split
 * @param bestSplit the split whose two child conditions are wrapped into surrogate conditions
 * @return SurrogateSplit with conditions for both children. The conditions only contain the condition for the best
 *         split and the default condition (true for the child the most records go to and false for the other one).
 */
public static SurrogateSplit createSurrogateSplitWithDefaultDirection(final DataMemberships dataMemberships,
    final SplitCandidate bestSplit) {
    final TreeAttributeColumnData col = bestSplit.getColumnData();
    final TreeNodeCondition[] conditions = bestSplit.getChildConditions();
    // get child markers for the best split
    final BitSet left = col.updateChildMemberships(conditions[0], dataMemberships);
    final BitSet right = col.updateChildMemberships(conditions[1], dataMemberships);
    // decide which child the majority of the records goes to (ties go left)
    final boolean majorityGoesLeft = left.cardinality() >= right.cardinality();
    // create surrogate conditions; only the majority child gets the default direction
    final TreeNodeSurrogateOnlyDefDirCondition condLeft =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition)conditions[0], majorityGoesLeft);
    final TreeNodeSurrogateOnlyDefDirCondition condRight =
        new TreeNodeSurrogateOnlyDefDirCondition((TreeNodeColumnCondition)conditions[1], !majorityGoesLeft);
    final BitSet[] childMarkers = new BitSet[]{left, right};
    // presumably assigns records not covered by either condition (e.g. missing values) to the default child -- TODO
    // confirm against the helper's implementation
    fillInMissingChildMarkersWithDefault(bestSplit, childMarkers, majorityGoesLeft);
    // return exactly the array the helper worked on, so no update it made to the array is lost
    return new SurrogateSplit(new AbstractTreeNodeSurrogateCondition[]{condLeft, condRight}, childMarkers);
}
Also used : TreeNodeSurrogateOnlyDefDirCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateOnlyDefDirCondition) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BitSet(java.util.BitSet) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 12 with SplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.SplitCandidate in project knime-core by knime.

In class TreeLearnerClassification, method findBestSplitClassification:

/**
 * Determines the single best split for the current node, or {@code null} if the node should not be split.
 * <p>
 * No split is attempted when the depth limit is reached, when the node holds fewer records than the configured
 * minimum node size, or when the node is already (almost) pure. If a hard-coded root column is configured, the
 * root node is always split on that column. Otherwise every non-forbidden column of the sample is evaluated and
 * the candidate with the highest gain wins. On an exact tie with the current best gain, a coin flip decides. The
 * running best gain starts at 0.0, so a zero-gain candidate can also be picked by that coin flip.
 *
 * @param currentDepth depth of the node (0 for the root)
 * @param dataMemberships memberships of the records in the node
 * @param columnSample columns to evaluate
 * @param treeNodeSignature signature of the node (not used by this method)
 * @param targetPriors target distribution of the node
 * @param forbiddenColumnSet attribute indices of columns that must be skipped
 * @return the best split candidate, or {@code null}
 */
private SplitCandidate findBestSplitClassification(final int currentDepth, final DataMemberships dataMemberships,
    final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature,
    final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration cfg = getConfig();

    final int levelLimit = cfg.getMaxLevels();
    final boolean depthExhausted =
        levelLimit != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= levelLimit;
    if (depthExhausted) {
        return null;
    }
    final int smallestNode = cfg.getMinNodeSize();
    if (smallestNode != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < smallestNode) {
        return null;
    }
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }

    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData)treeData.getTargetColumn();
    if (currentDepth == 0 && cfg.getHardCodedRootColumn() != null) {
        // TODO discuss whether this option makes sense with surrogates
        return treeData.getColumn(cfg.getHardCodedRootColumn()).calcBestSplitClassification(dataMemberships,
            targetPriors, target, random);
    }

    SplitCandidate winner = null;
    double winnerGain = 0.0;
    for (final TreeAttributeColumnData column : columnSample) {
        if (forbiddenColumnSet.get(column.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate candidate =
            column.calcBestSplitClassification(dataMemberships, targetPriors, target, random);
        if (candidate == null) {
            continue;
        }
        final double gain = candidate.getGainValue();
        final boolean replaceWinner;
        if (gain > winnerGain) {
            replaceWinner = true;
        } else if (gain == winnerGain) {
            // nextInt(0, 1) includes both bounds, so this is a fair coin flip between the tied candidates
            replaceWinner = random.nextInt(0, 1) == 0;
        } else {
            replaceWinner = false;
        }
        if (replaceWinner) {
            winner = candidate;
            winnerGain = gain;
        }
    }
    return winner;
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)

Example 13 with SplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.SplitCandidate in project knime-core by knime.

In class TreeLearnerClassification, method findBestSplitsClassification:

/**
 * Returns the best split candidate of every usable column, sorted in descending order of their gain.
 * <p>
 * Returns {@code null} (never an empty array and never an array containing {@code null}) in any of these cases:
 * <ul>
 * <li>the node must not be split (depth limit, minimum node size, node already pure);</li>
 * <li>the configured hard-coded root column yields no split;</li>
 * <li>no column of the sample yields a split.</li>
 * </ul>
 *
 * @param currentDepth depth of the node (0 for the root)
 * @param dataMemberships memberships of the records in the node
 * @param columnSample columns to evaluate
 * @param treeNodeSignature signature of the node (not used by this method)
 * @param targetPriors target distribution of the node
 * @param forbiddenColumnSet attribute indices of columns that must be skipped
 * @return the split candidates sorted by decreasing gain, or {@code null}
 */
private SplitCandidate[] findBestSplitsClassification(final int currentDepth, final DataMemberships dataMemberships,
    final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature,
    final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    final double priorImpurity = targetPriors.getPriorImpurity();
    if (priorImpurity < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData)data.getTargetColumn();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        // signal "no split" the same way as the general path below instead of returning {null}
        return rootSplit == null ? null : new SplitCandidate[]{rootSplit};
    }
    final ArrayList<SplitCandidate> candidates = new ArrayList<>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            candidates.add(currentColSplit);
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }
    // descending by gain; List.sort is stable, so equal gains keep their column-sample order
    candidates.sort((a, b) -> Double.compare(b.getGainValue(), a.getGainValue()));
    return candidates.toArray(new SplitCandidate[candidates.size()]);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ArrayList(java.util.ArrayList) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) Comparator(java.util.Comparator)

Example 14 with SplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.SplitCandidate in project knime-core by knime.

In class TreeNumericColumnDataTest, method testCalcBestSplitClassification:

/**
 * Checks the best numeric split (Gini) on a small textbook example. It then checks both children: the left child
 * can be split once more, and the right child is pure.
 *
 * @throws Exception if the test setup fails
 */
@Test
public void testCalcBestSplitClassification() throws Exception {
    TreeEnsembleLearnerConfiguration config = createConfig();
    /* data from J. Fuernkranz, Uni Darmstadt:
         * http://www.ke.tu-darmstadt.de/lehre/archiv/ws0809/mldm/dt.pdf */
    final double[] data = asDataArray("60,70,75,85, 90, 95, 100,120,125,220");
    final String[] target = asStringArray("No,No,No,Yes,Yes,Yes,No, No, No, No");
    Pair<TreeOrdinaryNumericColumnData, TreeTargetNominalColumnData> exampleData = exampleData(config, data, target);
    RandomData rd = config.createRandomData();
    TreeNumericColumnData columnData = exampleData.getFirst();
    TreeTargetNominalColumnData targetData = exampleData.getSecond();
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());
    double[] rowWeights = new double[data.length];
    Arrays.fill(rowWeights, 1.0);
    TreeData treeData = createTreeDataClassification(exampleData);
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);

    // root split
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NumericSplitCandidate.class));
    assertTrue(splitCandidate.canColumnBeSplitFurther());
    // libre office calc
    assertEquals(/*0.42 - 0.300 */
    0.12, splitCandidate.getGainValue(), 0.00001);
    NumericSplitCandidate numSplitCandidate = (NumericSplitCandidate)splitCandidate;
    TreeNodeNumericCondition[] childConditions = numSplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertEquals((95.0 + 100.0) / 2.0, childConditions[0].getSplitValue(), 0.0);
    assertEquals((95.0 + 100.0) / 2.0, childConditions[1].getSplitValue(), 0.0);
    assertEquals(NumericOperator.LessThanOrEqual, childConditions[0].getNumericOperator());
    assertEquals(NumericOperator.LargerThan, childConditions[1].getNumericOperator());

    // left child (60..95 -> No,No,No,Yes,Yes,Yes) can be separated between 75 and 85
    BitSet inChild = columnData.updateChildMemberships(childConditions[0], dataMemberships);
    DataMemberships childMemberships = dataMemberships.createChildMemberships(inChild);
    ClassificationPriors childTargetPriors = targetData.getDistribution(childMemberships, config);
    SplitCandidate splitCandidateChild =
        columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNotNull(splitCandidateChild);
    assertThat(splitCandidateChild, instanceOf(NumericSplitCandidate.class));
    // manually via libre office calc
    assertEquals(0.5, splitCandidateChild.getGainValue(), 0.00001);
    TreeNodeNumericCondition[] childConditions2 = ((NumericSplitCandidate)splitCandidateChild).getChildConditions();
    assertEquals(2, childConditions2.length);
    assertEquals((75.0 + 85.0) / 2.0, childConditions2[0].getSplitValue(), 0.0);

    // right child (100..220 -> all No) is pure, so no further split is expected
    inChild = columnData.updateChildMemberships(childConditions[1], dataMemberships);
    childMemberships = dataMemberships.createChildMemberships(inChild);
    childTargetPriors = targetData.getDistribution(childMemberships, config);
    splitCandidateChild = columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNull(splitCandidateChild);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) Test(org.junit.Test)

Example 15 with SplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.SplitCandidate in project knime-core by knime.

In class TreeNumericColumnDataTest, method testCalcBestSplitClassificationSplitAtEnd:

/**
 * Tests a split at the last possible split position, even though the target does not change there; see the example
 * data in the method body.
 *
 * @throws Exception if the test setup fails
 */
@Test
public void testCalcBestSplitClassificationSplitAtEnd() throws Exception {
    // Index:  1 2 3 4 5 6 7 8
    // Value:  1 1|2 2 2|3 3 3
    // Target: A A|A A A|A A B
    double[] data = asDataArray("1,1,2,2,2,3,3,3");
    String[] target = asStringArray("A,A,A,A,A,A,A,B");
    TreeEnsembleLearnerConfiguration config = createConfig();
    RandomData rd = config.createRandomData();
    Pair<TreeOrdinaryNumericColumnData, TreeTargetNominalColumnData> exampleData = exampleData(config, data, target);
    TreeNumericColumnData columnData = exampleData.getFirst();
    TreeTargetNominalColumnData targetData = exampleData.getSecond();
    double[] rowWeights = new double[data.length];
    Arrays.fill(rowWeights, 1.0);
    TreeData treeData = createTreeDataClassification(exampleData);
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);

    // root split between values 2 and 3
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NumericSplitCandidate.class));
    assertTrue(splitCandidate.canColumnBeSplitFurther());
    // manually calculated
    assertEquals(/*0.21875 - 0.166666667 */
    0.05208, splitCandidate.getGainValue(), 0.001);
    NumericSplitCandidate numSplitCandidate = (NumericSplitCandidate)splitCandidate;
    TreeNodeNumericCondition[] childConditions = numSplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertEquals((2.0 + 3.0) / 2.0, childConditions[0].getSplitValue(), 0.0);
    assertEquals(NumericOperator.LessThanOrEqual, childConditions[0].getNumericOperator());

    // left child (values 1 and 2, all A) is pure: no further split
    BitSet inChild = columnData.updateChildMemberships(childConditions[0], dataMemberships);
    DataMemberships childMemberships = dataMemberships.createChildMemberships(inChild);
    ClassificationPriors childTargetPriors = targetData.getDistribution(childMemberships, config);
    SplitCandidate splitCandidateChild =
        columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNull(splitCandidateChild);

    // right child (value 3 only, targets A,A,B) offers no split position: no further split
    inChild = columnData.updateChildMemberships(childConditions[1], dataMemberships);
    childMemberships = dataMemberships.createChildMemberships(inChild);
    childTargetPriors = targetData.getDistribution(childMemberships, config);
    // pass the same RandomData as every other call (previously null, which risks an NPE in tie-breaking code)
    splitCandidateChild = columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNull(splitCandidateChild);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) Test(org.junit.Test)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)26 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)21 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)19 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)18 RandomData (org.apache.commons.math.random.RandomData)16 Test (org.junit.Test)16 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)12 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)12 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)12 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)11 BitSet (java.util.BitSet)10 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)9 TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition)8 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)7 TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)7 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)6 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)6 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)6 ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)5 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)4