Search in sources :

Example 6 with TreeAttributeColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData in project knime-core by knime.

The class TreeLearnerClassification, method findBestSplitsClassification.

/**
 * Computes the best split of every admissible column in the column sample and returns the resulting split
 * candidates sorted by gain in descending order.
 *
 * @param currentDepth depth of the node to split; no split is computed once the configured maximum level is reached
 * @param dataMemberships memberships of the rows that belong to the node to split
 * @param columnSample the columns that may be considered for the split of this node
 * @param treeNodeSignature signature of the node to split (not used by this method)
 * @param targetPriors target class distribution and prior impurity within the node
 * @param forbiddenColumnSet attribute indices (as returned by {@code getMetaData().getAttributeIndex()}) that must
 *            not be used for the split
 * @return the split candidates sorted by descending gain, or {@code null} if the node must not or cannot be split
 *         (maximum depth reached, fewer records than the minimum node size, node (nearly) pure, or no column
 *         produced a split candidate)
 */
private SplitCandidate[] findBestSplitsClassification(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED
        && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    // an impurity below EPSILON means there is nothing left for a split to improve
    if (targetPriors.getPriorImpurity() < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        // TODO discuss whether this option makes sense with surrogates
        final SplitCandidate rootSplit =
            rootColumn.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        // same contract as the regular path below: "no split" is signalled by null,
        // never by an array that contains a null entry
        return rootSplit == null ? null : new SplitCandidate[] { rootSplit };
    }
    final ArrayList<SplitCandidate> candidates = new ArrayList<SplitCandidate>(columnSample.getNumCols());
    for (TreeAttributeColumnData col : columnSample) {
        if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate currentColSplit =
            col.calcBestSplitClassification(dataMemberships, targetPriors, targetColumn, rd);
        if (currentColSplit != null) {
            candidates.add(currentColSplit);
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }
    // descending by gain: arguments swapped relative to natural order
    candidates.sort((o1, o2) -> Double.compare(o2.getGainValue(), o1.getGainValue()));
    return candidates.toArray(new SplitCandidate[candidates.size()]);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ArrayList(java.util.ArrayList) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) Comparator(java.util.Comparator)

Example 7 with TreeAttributeColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData in project knime-core by knime.

The class TreeNumericColumnDataTest, method testCalcBestSplitRegression.

/**
 * Tests {@link TreeNumericColumnData#calcBestSplitRegression} on a single numeric attribute: first on all rows,
 * then again on the right child of the first split.
 *
 * @throws InvalidSettingsException if the test configuration is rejected
 */
@Test
public void testCalcBestSplitRegression() throws InvalidSettingsException {
    String dataCSV = "1,2,3,4,5,6,7,8,9,10";
    String targetCSV = "1,5,4,4.3,6.5,6.5,4,3,3,4";
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);
    config.setNrModels(1);
    config.setDataSelectionWithReplacement(false);
    config.setUseDifferentAttributesAtEachNode(false);
    config.setDataFractionPerTree(1.0);
    config.setColumnSamplingMode(ColumnSamplingMode.None);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    RandomData rd = config.createRandomData();
    TreeTargetNumericColumnData target = TestDataGenerator.createNumericTargetColumn(targetCSV);
    TreeNumericColumnData attribute = dataGen.createNumericAttributeColumn(dataCSV, "test-col", 0);
    TreeData data = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);
    double[] weights = new double[10];
    Arrays.fill(weights, 1.0);
    DataMemberships rootMem = new RootDataMemberships(weights, data, new DefaultDataIndexManager(data));
    SplitCandidate firstSplit = attribute.calcBestSplitRegression(rootMem, target.getPriors(rootMem, config), target, rd);
    // calculated via OpenOffice calc
    assertEquals(10.885444, firstSplit.getGainValue(), 1e-5);
    TreeNodeCondition[] firstConditions = firstSplit.getChildConditions();
    assertEquals(2, firstConditions.length);
    for (int i = 0; i < firstConditions.length; i++) {
        assertThat(firstConditions[i], instanceOf(TreeNodeNumericCondition.class));
        TreeNodeNumericCondition numCond = (TreeNodeNumericCondition) firstConditions[i];
        assertEquals(1.5, numCond.getSplitValue(), 0);
    }
    // left child contains only one row therefore only look at right child
    BitSet expectedInChild = new BitSet(10);
    expectedInChild.set(1, 10);
    BitSet inChild = attribute.updateChildMemberships(firstConditions[1], rootMem);
    assertEquals(expectedInChild, inChild);
    DataMemberships childMem = rootMem.createChildMemberships(inChild);
    SplitCandidate secondSplit = attribute.calcBestSplitRegression(childMem, target.getPriors(childMem, config), target, rd);
    assertEquals(6.883555, secondSplit.getGainValue(), 1e-5);
    TreeNodeCondition[] secondConditions = secondSplit.getChildConditions();
    // without this check the loop below would pass vacuously on an empty condition array
    assertEquals(2, secondConditions.length);
    for (int i = 0; i < secondConditions.length; i++) {
        assertThat(secondConditions[i], instanceOf(TreeNodeNumericCondition.class));
        TreeNodeNumericCondition numCond = (TreeNodeNumericCondition) secondConditions[i];
        assertEquals(6.5, numCond.getSplitValue(), 0);
    }
}
Also used: TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) Test(org.junit.Test)

Example 8 with TreeAttributeColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData in project knime-core by knime.

The class TreeTargetNumericColumnDataTest, method testGetPriors.

/**
 * Tests the {@link TreeTargetNumericColumnData#getPriors(DataMemberships, TreeEnsembleLearnerConfiguration)} and
 * {@link TreeTargetNumericColumnData#getPriors(double[], TreeEnsembleLearnerConfiguration)} methods.
 * With all row weights set to 1.0, both overloads must yield the same priors.
 */
@Test
public void testGetPriors() {
    String targetCSV = "1,4,3,5,6,7,8,12,22,1";
    // irrelevant but necessary to build TreeDataObject
    String someAttributeCSV = "A,B,A,B,A,A,B,A,A,B";
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeTargetNumericColumnData target = TestDataGenerator.createNumericTargetColumn(targetCSV);
    TreeNominalColumnData attribute = dataGen.createNominalAttributeColumn(someAttributeCSV, "test-col", 0);
    TreeData data = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);
    double[] weights = new double[10];
    Arrays.fill(weights, 1.0);
    DataMemberships rootMem = new RootDataMemberships(weights, data, new DefaultDataIndexManager(data));
    // expected values: sum = 69, mean = 69 / 10, sum of squared deviations from the mean = 352.9
    RegressionPriors datMemPriors = target.getPriors(rootMem, config);
    assertEquals(6.9, datMemPriors.getMean(), DELTA);
    assertEquals(69, datMemPriors.getYSum(), DELTA);
    assertEquals(352.9, datMemPriors.getSumSquaredDeviation(), DELTA);
    // the weight-array overload is documented above as tested, so exercise it as well
    RegressionPriors weightPriors = target.getPriors(weights, config);
    assertEquals(6.9, weightPriors.getMean(), DELTA);
    assertEquals(69, weightPriors.getYSum(), DELTA);
    assertEquals(352.9, weightPriors.getSumSquaredDeviation(), DELTA);
}
Also used: TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) Test(org.junit.Test)

Example 9 with TreeAttributeColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData in project knime-core by knime.

The class TreeTargetNominalColumnDataTest, method testGetDistribution.

/**
 * Tests the {@link TreeTargetNominalColumnData#getDistribution(DataMemberships, TreeEnsembleLearnerConfiguration)}
 * and {@link TreeTargetNominalColumnData#getDistribution(double[], TreeEnsembleLearnerConfiguration)} methods
 * for the Gini, information gain and information gain ratio split criteria.
 * @throws InvalidSettingsException if the test data cannot be created from the configuration
 */
@Test
public void testGetDistribution() throws InvalidSettingsException {
    final String targetCSV = "A,A,A,B,B,B,A";
    final String attributeCSV = "1,2,3,4,5,6,7";
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final TreeTargetNominalColumnData target = TestDataGenerator.createNominalTargetColumn(targetCSV);
    final TreeNumericColumnData attribute = generator.createNumericAttributeColumn(attributeCSV, "test-col", 0);
    final TreeData treeData = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);
    final double[] rowWeights = new double[7];
    Arrays.fill(rowWeights, 1.0);
    final DataMemberships rootMemberships =
        new RootDataMemberships(rowWeights, treeData, new DefaultDataIndexManager(treeData));
    // four rows of class A, three rows of class B
    final double[] classCounts = new double[] { 4.0, 3.0 };
    // Gini impurity: 1 - (4/7)^2 - (3/7)^2
    final double gini = 0.4897959184;
    // entropy: -(4/7) log2(4/7) - (3/7) log2(3/7)
    final double entropy = 0.985228136;

    config.setSplitCriterion(SplitCriterion.Gini);
    verifyDistributionOfBothOverloads(target, rootMemberships, rowWeights, config, gini, classCounts);

    config.setSplitCriterion(SplitCriterion.InformationGain);
    verifyDistributionOfBothOverloads(target, rootMemberships, rowWeights, config, entropy, classCounts);

    // the gain ratio uses the same prior impurity as plain information gain
    config.setSplitCriterion(SplitCriterion.InformationGainRatio);
    verifyDistributionOfBothOverloads(target, rootMemberships, rowWeights, config, entropy, classCounts);
}

/**
 * Computes the target distribution once from the memberships and once from the weight array and checks both
 * against the expected prior impurity and class distribution.
 */
private void verifyDistributionOfBothOverloads(final TreeTargetNominalColumnData target,
    final DataMemberships memberships, final double[] rowWeights, final TreeEnsembleLearnerConfiguration config,
    final double expectedImpurity, final double[] expectedDistribution) {
    final ClassificationPriors fromMemberships = target.getDistribution(memberships, config);
    assertEquals(expectedImpurity, fromMemberships.getPriorImpurity(), DELTA);
    assertArrayEquals(expectedDistribution, fromMemberships.getDistribution(), DELTA);
    final ClassificationPriors fromWeights = target.getDistribution(rowWeights, config);
    assertEquals(expectedImpurity, fromWeights.getPriorImpurity(), DELTA);
    assertArrayEquals(expectedDistribution, fromWeights.getDistribution(), DELTA);
}
Also used: TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) Test(org.junit.Test)

Example 10 with TreeAttributeColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData in project knime-core by knime.

The class SubsetColumnSampleTest, method testIterator.

/**
 * Tests that iterating a {@link SubsetColumnSample} yields exactly the selected columns, in the order of the given
 * column indices.
 *
 * @throws Exception if the tree data cannot be created
 */
@Test
public void testIterator() throws Exception {
    final TreeData data = createTreeData();
    int[] colIndices = new int[] { 1, 3, 5 };
    SubsetColumnSample sample = new SubsetColumnSample(data, colIndices);
    int i = 0;
    for (final TreeAttributeColumnData col : sample) {
        assertEquals("Wrong column returned.", data.getColumns()[colIndices[i++]], col);
    }
    // an iterator returning too few columns would otherwise pass unnoticed
    assertEquals("Wrong number of columns returned.", colIndices.length, i);
}
Also used : TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) Test(org.junit.Test)

Aggregations

TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)15 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)10 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)9 BitSet (java.util.BitSet)6 RandomData (org.apache.commons.math.random.RandomData)6 Test (org.junit.Test)6 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)6 TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)6 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)5 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)4 ArrayList (java.util.ArrayList)3 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)3 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)3 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)3 ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample)3 Comparator (java.util.Comparator)2 ClassificationPriors (org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)2 TreeNodeSurrogateOnlyDefDirCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateOnlyDefDirCondition)2 HashMap (java.util.HashMap)1 PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord)1