Search in sources :

Example 1 with NominalMultiwaySplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationMultiWay.

/**
 * Checks that
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * produces a multiway split candidate once binary nominal splits are switched off in the configuration.
 *
 * @throws Exception
 */
@Test
public void testCalcBestSplitClassificationMultiWay() throws Exception {
    TreeEnsembleLearnerConfiguration learnerConfig = createConfig(false);
    learnerConfig.setUseBinaryNominalSplits(false);
    Pair<TreeNominalColumnData, TreeTargetNominalColumnData> columnAndTarget = tennisData(learnerConfig);
    TreeNominalColumnData nominalColumn = columnAndTarget.getFirst();
    TreeTargetNominalColumnData target = columnAndTarget.getSecond();
    TreeData data = createTreeData(columnAndTarget);
    // the expected gain below is only valid for the Gini criterion
    assertEquals(SplitCriterion.Gini, learnerConfig.getSplitCriterion());

    // every row takes part with unit weight
    double[] unitWeights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(unitWeights, 1.0);
    IDataIndexManager indexManager = new DefaultDataIndexManager(data);
    DataMemberships rootMemberships = new RootDataMemberships(unitWeights, data, indexManager);
    ClassificationPriors priors = target.getDistribution(unitWeights, learnerConfig);

    SplitCandidate best = nominalColumn.calcBestSplitClassification(rootMemberships, priors, target, null);
    assertNotNull(best);
    assertThat(best, instanceOf(NominalMultiwaySplitCandidate.class));
    assertFalse(best.canColumnBeSplitFurther());
    // reference gain was computed manually in LibreOffice Calc
    assertEquals(0.0744897959, best.getGainValue(), 0.00001);

    NominalMultiwaySplitCandidate multiway = (NominalMultiwaySplitCandidate) best;
    TreeNodeNominalCondition[] children = multiway.getChildConditions();
    String[] expectedValues = {"S", "O", "R"};
    assertEquals(3, children.length);
    // one child per nominal value, in this exact order
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals(expectedValues[child], children[child].getValue());
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) Test(org.junit.Test)

Example 2 with NominalMultiwaySplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitRegressionMultiwayXGBoostMissingValueHandling.

/**
 * Exercises the XGBoost missing value handling for a regression target with multiway nominal splits: first on a
 * column without missing values, then on a column that contains one missing value.
 *
 * @throws Exception
 */
@Test
public void testCalcBestSplitRegressionMultiwayXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(true);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    config.setUseBinaryNominalSplits(false);
    final TestDataGenerator generator = new TestDataGenerator(config);
    // child values expected in both scenarios, in this order
    final String[] expectedValues = {"A", "B", "C"};

    // scenario 1: no missing values during training
    // (a missing value direction still has to be provided for prediction)
    TreeNominalColumnData column =
        generator.createNominalAttributeColumn("A, A, A, B, B, B, B, C, C", "noMissings", 0);
    TreeTargetNumericColumnData target =
        TestDataGenerator.createNumericTargetColumn("1, 2, 2, 7, 6, 5, 2, 3, 1");
    double[] rowWeights = new double[9];
    Arrays.fill(rowWeights, 1.0);
    int[] rowIndices = new int[9];
    for (int row = 0; row < rowIndices.length; row++) {
        rowIndices[row] = row;
    }
    final RandomData random = config.createRandomData();
    DataMemberships memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);
    SplitCandidate candidate =
        column.calcBestSplitRegression(memberships, target.getPriors(rowWeights, config), target, random);
    assertNotNull("SplitCandidate may not be null", candidate);
    assertThat(candidate, instanceOf(NominalMultiwaySplitCandidate.class));
    assertEquals("Wrong gain.", 22.888888, candidate.getGainValue(), 1e-5);
    assertTrue("No missing values in dataCol therefore the missedRows BitSet must be empty.",
        candidate.getMissedRows().isEmpty());
    NominalMultiwaySplitCandidate multiway = (NominalMultiwaySplitCandidate) candidate;
    TreeNodeNominalCondition[] children = multiway.getChildConditions();
    assertEquals("3 nominal values therefore there must be 3 children.", 3, children.length);
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals("Wrong value.", expectedValues[child], children[child].getValue());
    }
    // only the child for "B" is expected to accept missing values
    assertFalse("Missings should go with majority", children[0].acceptsMissings());
    assertTrue("Missings should go with majority", children[1].acceptsMissings());
    assertFalse("Missings should go with majority", children[2].acceptsMissings());

    // scenario 2: one missing value ("?") during training
    column = generator.createNominalAttributeColumn("A, A, A, B, B, B, B, C, C, ?", "missing", 0);
    target = TestDataGenerator.createNumericTargetColumn("1, 2, 2, 7, 6, 5, 2, 3, 1, 8");
    rowWeights = new double[10];
    Arrays.fill(rowWeights, 1.0);
    rowIndices = new int[10];
    for (int row = 0; row < rowIndices.length; row++) {
        rowIndices[row] = row;
    }
    memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);
    candidate = column.calcBestSplitRegression(memberships, target.getPriors(rowWeights, config), target, random);
    assertNotNull("SplitCandidate may not be null.", candidate);
    assertThat(candidate, instanceOf(NominalMultiwaySplitCandidate.class));
    // NOTE: the gain check for this scenario is disabled in the original test
    // assertEquals("Wrong gain.", 36.233333333, split.getGainValue(), 1e-5);
    assertTrue("Conditions should handle missing values therefore the missedRows BitSet must be empty.",
        candidate.getMissedRows().isEmpty());
    multiway = (NominalMultiwaySplitCandidate) candidate;
    children = multiway.getChildConditions();
    assertEquals("3 values (not counting missing values) therefore there must be 3 children.", 3,
        children.length);
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals("Wrong value.", expectedValues[child], children[child].getValue());
    }
    assertFalse("Missings should go with majority", children[0].acceptsMissings());
    assertTrue("Missings should go with majority", children[1].acceptsMissings());
    assertFalse("Missings should go with majority", children[2].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) Test(org.junit.Test)

Example 3 with NominalMultiwaySplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitRegressionMultiway.

/**
 * Checks that
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * produces a multiway split candidate once binary nominal splits are switched off in the configuration.
 *
 * @throws Exception
 */
@Test
public void testCalcBestSplitRegressionMultiway() throws Exception {
    TreeEnsembleLearnerConfiguration learnerConfig = createConfig(true);
    learnerConfig.setUseBinaryNominalSplits(false);
    Pair<TreeNominalColumnData, TreeTargetNumericColumnData> columnAndTarget = tennisDataRegression(learnerConfig);
    TreeNominalColumnData nominalColumn = columnAndTarget.getFirst();
    TreeTargetNumericColumnData target = columnAndTarget.getSecond();
    TreeData data = createTreeDataRegression(columnAndTarget);

    // every row takes part with unit weight
    double[] unitWeights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(unitWeights, 1.0);
    IDataIndexManager indexManager = new DefaultDataIndexManager(data);
    DataMemberships rootMemberships = new RootDataMemberships(unitWeights, data, indexManager);
    RegressionPriors priors = target.getPriors(unitWeights, learnerConfig);

    SplitCandidate best = nominalColumn.calcBestSplitRegression(rootMemberships, priors, target, null);
    assertNotNull(best);
    assertThat(best, instanceOf(NominalMultiwaySplitCandidate.class));
    assertFalse(best.canColumnBeSplitFurther());
    assertEquals(36.9643, best.getGainValue(), 0.0001);

    NominalMultiwaySplitCandidate multiway = (NominalMultiwaySplitCandidate) best;
    TreeNodeNominalCondition[] children = multiway.getChildConditions();
    String[] expectedValues = {"S", "O", "R"};
    assertEquals(3, children.length);
    // one child per nominal value, in this exact order
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals(expectedValues[child], children[child].getValue());
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) Test(org.junit.Test)

Example 4 with NominalMultiwaySplitCandidate

use of org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling.

/**
 * Exercises the XGBoost missing value handling for a classification target with multiway nominal splits: first on
 * training data without missing values, then on training data that contains one missing value.
 *
 * @throws Exception
 */
@Test
public void testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setUseBinaryNominalSplits(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final RandomData random = config.createRandomData();
    // child values expected in both scenarios, in this order
    final String[] expectedValues = {"a", "b", "c"};

    // scenario 1: training data without missing values
    TreeNominalColumnData column =
        generator.createNominalAttributeColumn("a, a, a, b, b, b, b, c, c", "noMissings", 0);
    TreeTargetNominalColumnData target =
        TestDataGenerator.createNominalTargetColumn("A, B, B, C, C, C, B, A, B");
    DataMemberships memberships = createMockDataMemberships(target.getNrRows());
    SplitCandidate candidate = column.calcBestSplitClassification(memberships,
        target.getDistribution(memberships, config), target, random);
    assertNotNull("There is a possible split.", candidate);
    assertEquals("Incorrect gain.", 0.216, candidate.getGainValue(), 1e-3);
    assertThat(candidate, instanceOf(NominalMultiwaySplitCandidate.class));
    NominalMultiwaySplitCandidate multiway = (NominalMultiwaySplitCandidate) candidate;
    assertTrue("No missing values in the column.", multiway.getMissedRows().isEmpty());
    TreeNodeNominalCondition[] children = multiway.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, children.length);
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals("Wrong value in child condition.", expectedValues[child], children[child].getValue());
    }
    // only the child for "b" is expected to accept missing values
    assertFalse("Missing values should be sent to the majority child (i.e. b)", children[0].acceptsMissings());
    assertTrue("Missing values should be sent to the majority child (i.e. b)", children[1].acceptsMissings());
    assertFalse("Missing values should be sent to the majority child (i.e. b)", children[2].acceptsMissings());

    // scenario 2: training data with one missing value ("?")
    column = generator.createNominalAttributeColumn("a, a, a, b, b, b, b, c, c, ?", "missings", 0);
    target = TestDataGenerator.createNominalTargetColumn("A, B, B, C, C, C, B, A, B, C");
    memberships = createMockDataMemberships(target.getNrRows());
    candidate = column.calcBestSplitClassification(memberships, target.getDistribution(memberships, config),
        target, random);
    assertNotNull("There is a possible split.", candidate);
    assertEquals("Incorrect gain.", 0.2467, candidate.getGainValue(), 1e-3);
    assertThat(candidate, instanceOf(NominalMultiwaySplitCandidate.class));
    multiway = (NominalMultiwaySplitCandidate) candidate;
    assertTrue("Split should handle missing values.", multiway.getMissedRows().isEmpty());
    children = multiway.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, children.length);
    for (int child = 0; child < expectedValues.length; child++) {
        assertEquals("Wrong value in child condition.", expectedValues[child], children[child].getValue());
    }
    assertFalse("Missing values should be sent to b", children[0].acceptsMissings());
    assertTrue("Missing values should be sent to b", children[1].acceptsMissings());
    assertFalse("Missing values should be sent to b", children[2].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) Test(org.junit.Test)

Aggregations

Test (org.junit.Test): 4; DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships): 4; RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 4; NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate): 4; NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate): 4; SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate): 4; TreeNodeNominalCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition): 4; TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 4; RandomData (org.apache.commons.math.random.RandomData): 2; DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager): 2; IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 2