Search in sources :

Example 6 with TreeNodeNominalBinaryCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition in project knime-core by knime.

the class TreeNodeNominalBinaryConditionTest method testTestCondition.

/**
 * Verifies
 * {@link TreeNodeNominalBinaryCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * for four differently configured conditions on a nominal column built from the values A, B, C and D.
 * <p>
 * The record is created only once on top of {@code map}. Every check afterwards replaces the single entry of that
 * map and queries the very same record again, so the test expects the record to reflect later changes of the map.
 *
 * @throws Exception not expected; any exception lets the test fail
 */
@Test
public void testTestCondition() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNominalColumnData col = dataGen.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);
    // Judging by the assertion messages the values are addressed by index: A=0, B=1, C=2, D=3.
    // Bit mask 1 (binary 0001) selects only A; the checks below expect A to pass and missing values to be rejected.
    TreeNodeNominalBinaryCondition cond = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(1), true, false);
    final Map<String, Object> map = Maps.newHashMap();
    final String colName = col.getMetaData().getAttributeName();
    map.put(colName, 0);
    final PredictorRecord record = new PredictorRecord(map);
    assertTrue("The value A was not accepted but should have been.", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertFalse("The value B was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertFalse("The value C was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("The value D was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("The condition falsely accepted missing values", cond.testCondition(record));

    // Bit mask 5 (binary 0101) selects A and C; this condition is expected to accept them as well as missing values.
    cond = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), true, true);
    map.clear();
    map.put(colName, 0);
    assertTrue("The value A was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertTrue("The value C was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertFalse("The value B was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    // index 3 is the value D (the message used to name B by mistake)
    assertFalse("The value D was falsely accepted.", cond.testCondition(record));

    // Same mask, third argument flipped: now the values outside of {A, C} are expected to pass, missing values too.
    cond = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), false, true);
    map.clear();
    map.put(colName, 0);
    assertFalse("The value A was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertFalse("The value C was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertTrue("The value B was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("The value D was falsely rejected.", cond.testCondition(record));

    // As before, but with the last argument set to false missing values are expected to be rejected.
    cond = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(5), false, false);
    map.clear();
    map.put(colName, 0);
    assertFalse("The value A was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertFalse("The value C was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertTrue("The value B was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("The value D was falsely rejected.", cond.testCondition(record));
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) Test(org.junit.Test)

Example 7 with TreeNodeNominalBinaryCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1.

/**
 * Exercises the XGBoost flavour of missing value handling: for every split both directions for the missing values
 * are tried and the one with the better gain wins.
 *
 * @throws Exception not expected; any exception lets the test fail
 */
@Test
public void testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = createConfig(false);
    cfg.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(cfg);
    // only the target column of the tennis data is used, the attribute column is created from the CSV below
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> tennis = twoClassTennisData(cfg);
    // the "?" entries presumably denote the missing values of the column
    final String csvWithMissings = "S,?,O,R,S,R,S,?,O,?";
    final TreeNominalColumnData attribute =
        generator.createNominalAttributeColumn(csvWithMissings, "column containing missing values", 0);
    final TreeTargetNominalColumnData targetColumn = tennis.getSecond();
    final double[] weights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(weights, 1.0);
    final int[] indexInColumn = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    // follows the ordering of the rows inside of the column data
    final int[] indexInOriginal = { 0, 4, 6, 2, 8, 3, 5, 1, 7, 9 };
    final DataMemberships memberships = new MockDataColMem(indexInOriginal, indexInColumn, weights);
    final String[] expectedValues = { "S", "R" };

    final SplitCandidate candidate = attribute.calcBestSplitClassification(memberships,
        targetColumn.getDistribution(weights, cfg), targetColumn, TestDataGenerator.createRandomData());

    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate binaryCandidate = (NominalBinarySplitCandidate) candidate;
    final TreeNodeNominalBinaryCondition[] conditions = binaryCandidate.getChildConditions();
    assertEquals("Wrong gain value.", 0.18, binaryCandidate.getGainValue(), 1e-8);
    // both children are described by the same value set, they only differ in the set logic
    assertArrayEquals("Values in nominal condition did not match", expectedValues, conditions[0].getValues());
    assertArrayEquals("Values in nominal condition did not match", expectedValues, conditions[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, conditions[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, conditions[1].getSetLogic());
    // the missing values are expected to go with the first child only
    assertTrue("Missing values are not sent to the correct child.", conditions[0].acceptsMissings());
    assertFalse("Missing values are not sent to the correct child.", conditions[1].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Example 8 with TreeNodeNominalBinaryCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationBinaryTwoClass.

/**
 * Checks
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * on the two class tennis data and expects a binary split on the value "R".
 *
 * @throws Exception not expected; any exception lets the test fail
 */
@Test
public void testCalcBestSplitClassificationBinaryTwoClass() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = createConfig(false);
    cfg.setMissingValueHandling(MissingValueHandling.Surrogate);
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> tennis = twoClassTennisData(cfg);
    final TreeNominalColumnData attribute = tennis.getFirst();
    final TreeTargetNominalColumnData target = tennis.getSecond();
    final TreeData treeData = twoClassTennisTreeData(cfg);
    final IDataIndexManager indices = new DefaultDataIndexManager(treeData);
    // the expected gain below only holds for the Gini criterion
    assertEquals(SplitCriterion.Gini, cfg.getSplitCriterion());
    final double[] weights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(weights, 1.0);
    final DataMemberships memberships = new RootDataMemberships(weights, treeData, indices);
    final ClassificationPriors classPriors = target.getDistribution(weights, cfg);

    final SplitCandidate candidate = attribute.calcBestSplitClassification(memberships, classPriors, target, null);

    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(candidate.canColumnBeSplitFurther());
    // reference value was calculated by hand in a spreadsheet (OpenOffice Calc)
    assertEquals(0.1371428, candidate.getGainValue(), 0.00001);
    final NominalBinarySplitCandidate binaryCandidate = (NominalBinarySplitCandidate) candidate;
    final TreeNodeNominalBinaryCondition[] conditions = binaryCandidate.getChildConditions();
    assertEquals(2, conditions.length);
    // both children carry the same value set and are told apart by the set logic
    final String[] expectedValues = { "R" };
    assertArrayEquals(expectedValues, conditions[0].getValues());
    assertArrayEquals(expectedValues, conditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, conditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, conditions[1].getSetLogic());
    // neither child is expected to accept missing values here
    assertFalse(conditions[0].acceptsMissings());
    assertFalse(conditions[1].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Example 9 with TreeNodeNominalBinaryCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationBinaryPCAXGBoostMissingValueHandling.

/**
 * Covers the XGBoost missing value handling for binary splits that are found with the pca method (more than two
 * classes), once for a column without and once for a column with a missing value.
 *
 * @throws Exception not expected; any exception lets the test fail
 */
@Test
public void testCalcBestSplitClassificationBinaryPCAXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = createConfig(false);
    cfg.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(cfg);
    final RandomData random = cfg.createRandomData();
    // the same value set is expected in both scenarios
    final String[] expectedValues = { "a", "c" };

    // scenario 1: the training data is free of missing values
    final String plainCSV = "a, a, a, b, b, b, b, c, c";
    final String plainTargetCSV = "A, B, B, C, C, C, B, A, B";
    final TreeNominalColumnData plainColumn = generator.createNominalAttributeColumn(plainCSV, "noMissings", 0);
    final TreeTargetNominalColumnData plainTarget = TestDataGenerator.createNominalTargetColumn(plainTargetCSV);
    final DataMemberships plainMemberships = createMockDataMemberships(plainTarget.getNrRows());
    final SplitCandidate plainCandidate = plainColumn.calcBestSplitClassification(plainMemberships,
        plainTarget.getDistribution(plainMemberships, cfg), plainTarget, random);
    assertNotNull("There is a possible split.", plainCandidate);
    assertEquals("Incorrect gain.", 0.2086, plainCandidate.getGainValue(), 1e-3);
    assertThat(plainCandidate, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate plainBinary = (NominalBinarySplitCandidate) plainCandidate;
    assertTrue("No missing values in the column.", plainBinary.getMissedRows().isEmpty());
    final TreeNodeNominalBinaryCondition[] plainConditions = plainBinary.getChildConditions();
    assertEquals("A binary split must have 2 child conditions.", 2, plainConditions.length);
    assertArrayEquals("Wrong values in child condition.", expectedValues, plainConditions[0].getValues());
    assertArrayEquals("Wrong values in child condition.", expectedValues, plainConditions[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, plainConditions[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, plainConditions[1].getSetLogic());
    assertFalse("Missing values should be sent to the majority child (i.e. right)",
        plainConditions[0].acceptsMissings());
    assertTrue("Missing values should be sent to the majority child (i.e. right)",
        plainConditions[1].acceptsMissings());

    // scenario 2: one additional row with a missing value (the "?" entry)
    final String gappyCSV = "a, a, a, b, b, b, b, c, c, ?";
    final String gappyTargetCSV = "A, B, B, C, C, C, B, A, B, C";
    final TreeNominalColumnData gappyColumn = generator.createNominalAttributeColumn(gappyCSV, "missings", 0);
    final TreeTargetNominalColumnData gappyTarget = TestDataGenerator.createNominalTargetColumn(gappyTargetCSV);
    final DataMemberships gappyMemberships = createMockDataMemberships(gappyTarget.getNrRows());
    final SplitCandidate gappyCandidate = gappyColumn.calcBestSplitClassification(gappyMemberships,
        gappyTarget.getDistribution(gappyMemberships, cfg), gappyTarget, random);
    assertNotNull("There is a possible split.", gappyCandidate);
    assertEquals("Incorrect gain.", 0.24, gappyCandidate.getGainValue(), 1e-3);
    assertThat(gappyCandidate, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate gappyBinary = (NominalBinarySplitCandidate) gappyCandidate;
    assertTrue("Split should handle missing values.", gappyBinary.getMissedRows().isEmpty());
    final TreeNodeNominalBinaryCondition[] gappyConditions = gappyBinary.getChildConditions();
    assertEquals("Wrong number of child conditions.", 2, gappyConditions.length);
    assertArrayEquals("Wrong values in child condition.", expectedValues, gappyConditions[0].getValues());
    assertArrayEquals("Wrong values in child condition.", expectedValues, gappyConditions[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, gappyConditions[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, gappyConditions[1].getSetLogic());
    assertTrue("Missing values should be sent to left child", gappyConditions[0].acceptsMissings());
    assertFalse("Missing values should be sent to left child", gappyConditions[1].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) Test(org.junit.Test)

Example 10 with TreeNodeNominalBinaryCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitRegressionBinary.

/**
 * Checks
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * when binary splits are used and expects a split on the value "R".
 *
 * @throws Exception not expected; any exception lets the test fail
 */
@Test
public void testCalcBestSplitRegressionBinary() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = new TreeEnsembleLearnerConfiguration(true);
    final Pair<TreeNominalColumnData, TreeTargetNumericColumnData> tennis = tennisDataRegression(cfg);
    final TreeNominalColumnData attribute = tennis.getFirst();
    final TreeTargetNumericColumnData target = tennis.getSecond();
    final TreeData data = createTreeDataRegression(tennis);
    // every row enters the split search with the same weight of 1
    final double[] weights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(weights, 1.0);
    final IDataIndexManager indices = new DefaultDataIndexManager(data);
    final DataMemberships memberships = new RootDataMemberships(weights, data, indices);
    final RegressionPriors regressionPriors = target.getPriors(weights, cfg);

    final SplitCandidate candidate = attribute.calcBestSplitRegression(memberships, regressionPriors, target, null);

    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(candidate.canColumnBeSplitFurther());
    assertEquals(32.9143, candidate.getGainValue(), 0.0001);
    final NominalBinarySplitCandidate binaryCandidate = (NominalBinarySplitCandidate) candidate;
    final TreeNodeNominalBinaryCondition[] conditions = binaryCandidate.getChildConditions();
    assertEquals(2, conditions.length);
    // both children carry the same value set and are told apart by the set logic
    final String[] expectedValues = { "R" };
    assertArrayEquals(expectedValues, conditions[0].getValues());
    assertArrayEquals(expectedValues, conditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, conditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, conditions[1].getSetLogic());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)11 Test (org.junit.Test)10 TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition)10 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)9 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)9 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)8 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)8 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)8 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)5 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)5 BitSet (java.util.BitSet)3 RandomData (org.apache.commons.math.random.RandomData)3 TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator)2 TreeNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData)2 PMMLCompoundPredicate (org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate)1 PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate)1 PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate)1 PMMLSimpleSetPredicate (org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate)1 PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord)1 TreeNominalColumnMetaData (org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnMetaData)1