Search in sources :

Example 6 with TreeNodeNominalCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition in project knime-core by knime.

the class TreeNodeNominalConditionTest method testTestCondition.

/**
 * Exercises
 * {@link TreeNodeNominalCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * for each of the four values of a nominal test column and for the missing value marker, first with a
 * condition that does not accept missing values and then with one that does.
 *
 * @throws Exception not expected; any exception fails the test
 */
@Test
public void testTestCondition() throws Exception {
    final TestDataGenerator generator = new TestDataGenerator(new TreeEnsembleLearnerConfiguration(false));
    final TreeNominalColumnData column = generator.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);
    final String attributeName = column.getMetaData().getAttributeName();
    // The record is created once on top of this map; every case below only swaps the map's single entry.
    final Map<String, Object> values = Maps.newHashMap();

    // Condition on index 3 (reported as "D" by the assertion messages), missing values not accepted.
    TreeNodeNominalCondition condition = new TreeNodeNominalCondition(column.getMetaData(), 3, false);
    values.put(attributeName, 0);
    final PredictorRecord row = new PredictorRecord(values);
    assertFalse("The value A was falsely accepted", condition.testCondition(row));

    values.clear();
    values.put(attributeName, 1);
    assertFalse("The value B was falsely accepted", condition.testCondition(row));

    values.clear();
    values.put(attributeName, 2);
    assertFalse("The value C was falsely accepted", condition.testCondition(row));

    values.clear();
    values.put(attributeName, 3);
    assertTrue("The value D was falsely rejected", condition.testCondition(row));

    values.clear();
    values.put(attributeName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted", condition.testCondition(row));

    // Condition on index 0 (reported as "A" by the assertion messages), missing values accepted.
    condition = new TreeNodeNominalCondition(column.getMetaData(), 0, true);

    values.clear();
    values.put(attributeName, 0);
    assertTrue("The value A was falsely rejected", condition.testCondition(row));

    values.clear();
    values.put(attributeName, 1);
    assertFalse("The value B was falsely accepted", condition.testCondition(row));

    values.clear();
    values.put(attributeName, 2);
    assertFalse("The value C was falsely accepted", condition.testCondition(row));

    values.clear();
    values.put(attributeName, 3);
    assertFalse("The value D was falsely accepted", condition.testCondition(row));

    values.clear();
    values.put(attributeName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected", condition.testCondition(row));
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) Test(org.junit.Test)

Example 7 with TreeNodeNominalCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition in project knime-core by knime.

the class LiteralConditionParser method handleSimplePredicate.

/**
 * Translates a simple predicate into a tree node column condition.
 * <p>
 * If the meta data mapper reports the predicate's field as nominal, the predicate value is looked up via the
 * nominal column helper and its assigned integer is used to build a {@link TreeNodeNominalCondition}.
 * Otherwise the predicate value is parsed as a double and combined with the parsed operator into a
 * {@link TreeNodeNumericCondition}.
 *
 * @param simplePred predicate supplying the field name, the value and (for non-nominal fields) the operator
 * @param acceptsMissings passed through unchanged to the created condition
 * @return the nominal or numeric condition built from the predicate
 * @throws NumberFormatException if the field is not nominal and the predicate value is not a parsable double
 */
private TreeNodeColumnCondition handleSimplePredicate(final SimplePredicate simplePred, final boolean acceptsMissings) {
    final String fieldName = simplePred.getField();
    if (!m_metaDataMapper.isNominal(fieldName)) {
        // Numeric field: resolve the meta data first, then parse the threshold and the operator.
        final TreeNumericColumnMetaData numericMetaData =
            m_metaDataMapper.getNumericColumnHelper(fieldName).getMetaData();
        final double threshold = Double.parseDouble(simplePred.getValue());
        return new TreeNodeNumericCondition(numericMetaData, threshold,
            parseNumericOperator(simplePred.getOperator()), acceptsMissings);
    }
    // Nominal field: the condition is keyed by the integer the helper assigns to the predicate's value.
    final NominalAttributeColumnHelper nominalHelper = m_metaDataMapper.getNominalColumnHelper(fieldName);
    return new TreeNodeNominalCondition(
        nominalHelper.getMetaData(),
        nominalHelper.getRepresentation(simplePred.getValue()).getAssignedInteger(),
        acceptsMissings);
}
Also used : TreeNumericColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnMetaData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition)

Example 8 with TreeNodeNominalCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitRegressionMultiway.

/**
 * Verifies
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * with binary nominal splits switched off: the best split on the fixture returned by
 * {@code tennisDataRegression} must be a multiway split with the expected gain and one child per value.
 *
 * @throws Exception not expected; any exception fails the test
 */
@Test
public void testCalcBestSplitRegressionMultiway() throws Exception {
    final TreeEnsembleLearnerConfiguration learnerConfig = createConfig(true);
    learnerConfig.setUseBinaryNominalSplits(false);

    final Pair<TreeNominalColumnData, TreeTargetNumericColumnData> tennis = tennisDataRegression(learnerConfig);
    final TreeNominalColumnData nominalColumn = tennis.getFirst();
    final TreeTargetNumericColumnData target = tennis.getSecond();
    final TreeData data = createTreeDataRegression(tennis);

    // Every row takes part in the split search with a weight of 1.
    final double[] weights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(weights, 1.0);
    final IDataIndexManager indices = new DefaultDataIndexManager(data);
    final DataMemberships memberships = new RootDataMemberships(weights, data, indices);
    final RegressionPriors regressionPriors = target.getPriors(weights, learnerConfig);

    // No random data is handed in (null) for this call.
    final SplitCandidate best = nominalColumn.calcBestSplitRegression(memberships, regressionPriors, target, null);
    assertNotNull(best);
    assertThat(best, instanceOf(NominalMultiwaySplitCandidate.class));
    assertFalse(best.canColumnBeSplitFurther());
    assertEquals(36.9643, best.getGainValue(), 0.0001);

    final TreeNodeNominalCondition[] children = ((NominalMultiwaySplitCandidate) best).getChildConditions();
    assertEquals(3, children.length);
    // The children are expected in exactly this value order.
    final String[] expectedValues = {"S", "O", "R"};
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals(expectedValues[child], children[child].getValue());
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) Test(org.junit.Test)

Example 9 with TreeNodeNominalCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling.

/**
 * Checks the XGBoost missing value handling of multiway classification splits in two scenarios: a training
 * column without missing values and the same column with one additional missing value ("?"). In both
 * scenarios only the child for value "b" is expected to accept missing values.
 *
 * @throws Exception not expected; any exception fails the test
 */
@Test
public void testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration learnerConfig = createConfig(false);
    learnerConfig.setUseBinaryNominalSplits(false);
    learnerConfig.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    // One random data instance is shared by both split calculations below.
    final RandomData random = learnerConfig.createRandomData();
    final String[] expectedValues = {"a", "b", "c"};

    // Scenario 1: test the case that there are no missing values in the training data
    final String cleanCSV = "a, a, a, b, b, b, b, c, c";
    final String cleanTargetCSV = "A, B, B, C, C, C, B, A, B";
    final TreeNominalColumnData cleanColumn = generator.createNominalAttributeColumn(cleanCSV, "noMissings", 0);
    final TreeTargetNominalColumnData cleanTarget = TestDataGenerator.createNominalTargetColumn(cleanTargetCSV);
    final DataMemberships cleanMemberships = createMockDataMemberships(cleanTarget.getNrRows());
    final SplitCandidate cleanSplit = cleanColumn.calcBestSplitClassification(cleanMemberships,
        cleanTarget.getDistribution(cleanMemberships, learnerConfig), cleanTarget, random);
    assertNotNull("There is a possible split.", cleanSplit);
    assertEquals("Incorrect gain.", 0.216, cleanSplit.getGainValue(), 1e-3);
    assertThat(cleanSplit, instanceOf(NominalMultiwaySplitCandidate.class));
    final NominalMultiwaySplitCandidate cleanMultiway = (NominalMultiwaySplitCandidate) cleanSplit;
    assertTrue("No missing values in the column.", cleanMultiway.getMissedRows().isEmpty());
    final TreeNodeNominalCondition[] cleanChildren = cleanMultiway.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, cleanChildren.length);
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals("Wrong value in child condition.", expectedValues[child], cleanChildren[child].getValue());
    }
    assertFalse("Missing values should be sent to the majority child (i.e. b)", cleanChildren[0].acceptsMissings());
    assertTrue("Missing values should be sent to the majority child (i.e. b)", cleanChildren[1].acceptsMissings());
    assertFalse("Missing values should be sent to the majority child (i.e. b)", cleanChildren[2].acceptsMissings());

    // Scenario 2: test the case that there are missing values in the training data
    final String gappyCSV = "a, a, a, b, b, b, b, c, c, ?";
    final String gappyTargetCSV = "A, B, B, C, C, C, B, A, B, C";
    final TreeNominalColumnData gappyColumn = generator.createNominalAttributeColumn(gappyCSV, "missings", 0);
    final TreeTargetNominalColumnData gappyTarget = TestDataGenerator.createNominalTargetColumn(gappyTargetCSV);
    final DataMemberships gappyMemberships = createMockDataMemberships(gappyTarget.getNrRows());
    final SplitCandidate gappySplit = gappyColumn.calcBestSplitClassification(gappyMemberships,
        gappyTarget.getDistribution(gappyMemberships, learnerConfig), gappyTarget, random);
    assertNotNull("There is a possible split.", gappySplit);
    assertEquals("Incorrect gain.", 0.2467, gappySplit.getGainValue(), 1e-3);
    assertThat(gappySplit, instanceOf(NominalMultiwaySplitCandidate.class));
    final NominalMultiwaySplitCandidate gappyMultiway = (NominalMultiwaySplitCandidate) gappySplit;
    assertTrue("Split should handle missing values.", gappyMultiway.getMissedRows().isEmpty());
    final TreeNodeNominalCondition[] gappyChildren = gappyMultiway.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, gappyChildren.length);
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals("Wrong value in child condition.", expectedValues[child], gappyChildren[child].getValue());
    }
    assertFalse("Missing values should be sent to b", gappyChildren[0].acceptsMissings());
    assertTrue("Missing values should be sent to b", gappyChildren[1].acceptsMissings());
    assertFalse("Missing values should be sent to b", gappyChildren[2].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) Test(org.junit.Test)

Example 10 with TreeNodeNominalCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition in project knime-core by knime.

the class TreeNominalColumnDataTest method testUpdateChildMemberships.

/**
 * Verifies
 * {@link TreeNominalColumnData#updateChildMemberships(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition, DataMemberships)}
 * for binary and multiway nominal conditions on a twelve-row column: rows 0-3 hold "A", rows 4-6 "B",
 * rows 7-9 "C" and rows 10-11 are missing ("?").
 *
 * @throws Exception not expected; any exception fails the test
 */
@Test
public void testUpdateChildMemberships() throws Exception {
    // in this case it doesn't matter if we use regression or classification (as well as binary and multiway splits)
    final TreeEnsembleLearnerConfiguration learnerConfig = createConfig(true);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final String columnCSV = "A, A, A, A, B, B, B, C, C, C, ?, ?";
    final TreeNominalColumnData column = generator.createNominalAttributeColumn(columnCSV, "test-col", 0);

    // Identity row indices with unit weights for all twelve rows.
    final int[] rowIndices = new int[12];
    final double[] rowWeights = new double[rowIndices.length];
    for (int row = 0; row < rowIndices.length; row++) {
        rowIndices[row] = row;
        rowWeights[row] = 1.0;
    }
    final DataMemberships memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);

    // NOTE(review): judging by the expected bit sets, the BigInteger is a bit mask over the value indices,
    // the first boolean switches between "value in mask" and "value not in mask" and the second boolean
    // controls whether missing rows are accepted - confirm against TreeNodeNominalBinaryCondition.
    final BigInteger maskTwo = BigInteger.valueOf(2);
    final BigInteger maskFive = BigInteger.valueOf(5);
    final BitSet expectedRows = new BitSet(12);
    BitSet actualRows;

    // mask 2, (true, false): only the B rows
    TreeNodeNominalBinaryCondition binaryCondition =
        new TreeNodeNominalBinaryCondition(column.getMetaData(), maskTwo, true, false);
    expectedRows.set(4, 7);
    actualRows = column.updateChildMemberships(binaryCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // mask 2, (true, true): the B rows plus the missing rows
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), maskTwo, true, true);
    expectedRows.clear();
    expectedRows.set(4, 7);
    expectedRows.set(10, 12);
    actualRows = column.updateChildMemberships(binaryCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // mask 2, (false, false): the A and C rows
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), maskTwo, false, false);
    expectedRows.clear();
    expectedRows.set(0, 4);
    expectedRows.set(7, 10);
    actualRows = column.updateChildMemberships(binaryCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // mask 2, (false, true): the A and C rows plus the missing rows
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), maskTwo, false, true);
    expectedRows.clear();
    expectedRows.set(0, 4);
    expectedRows.set(7, 12);
    actualRows = column.updateChildMemberships(binaryCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // mask 5, (true, false): the A and C rows
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), maskFive, true, false);
    expectedRows.clear();
    expectedRows.set(0, 4);
    expectedRows.set(7, 10);
    actualRows = column.updateChildMemberships(binaryCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // mask 5, (true, true): the A and C rows plus the missing rows
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), maskFive, true, true);
    expectedRows.clear();
    expectedRows.set(0, 4);
    expectedRows.set(7, 12);
    actualRows = column.updateChildMemberships(binaryCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // multiway condition on index 0 without missing values: the A rows
    TreeNodeNominalCondition multiwayCondition = new TreeNodeNominalCondition(column.getMetaData(), 0, false);
    expectedRows.clear();
    expectedRows.set(0, 4);
    actualRows = column.updateChildMemberships(multiwayCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // multiway condition on index 0 with missing values: the A rows plus the missing rows
    multiwayCondition = new TreeNodeNominalCondition(column.getMetaData(), 0, true);
    expectedRows.clear();
    expectedRows.set(0, 4);
    expectedRows.set(10, 12);
    actualRows = column.updateChildMemberships(multiwayCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // multiway condition on index 2 without missing values: the C rows
    multiwayCondition = new TreeNodeNominalCondition(column.getMetaData(), 2, false);
    expectedRows.clear();
    expectedRows.set(7, 10);
    actualRows = column.updateChildMemberships(multiwayCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);

    // multiway condition on index 2 with missing values: the C rows plus the missing rows
    multiwayCondition = new TreeNodeNominalCondition(column.getMetaData(), 2, true);
    expectedRows.clear();
    expectedRows.set(7, 12);
    actualRows = column.updateChildMemberships(multiwayCondition, memberships);
    assertEquals("The produced BitSet is incorrect.", expectedRows, actualRows);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) BitSet(java.util.BitSet) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) Test(org.junit.Test)

Aggregations

Test (org.junit.Test)8 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)8 TreeNodeNominalCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition)7 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)5 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)5 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)4 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)4 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)4 TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator)3 BitSet (java.util.BitSet)2 RandomData (org.apache.commons.math.random.RandomData)2 PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord)2 TreeNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData)2 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)2 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)2 ArrayList (java.util.ArrayList)1 PMMLCompoundPredicate (org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate)1 PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate)1 PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate)1 NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation)1