Search in sources:

Example 16 with TreeTargetNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.

Used in the class TreeNumericColumnDataTest, method testCalcBestSplitClassificationMissingValStrategy1.

/**
 * Checks the numeric split search on a column whose last three rows are missing (NaN). The missing rows carry the
 * same class ("Y") as the low values, so they are expected to end up in the first child. That child uses the
 * {@code LessThanOrEqualOrMissing} operator with split value 4.5.
 * <p>
 * This test is outdated and will likely be removed soon, which is why its {@code @Test} annotation is commented out.
 *
 * @throws Exception
 */
// @Test
public void testCalcBestSplitClassificationMissingValStrategy1() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = createConfig();
    final double[] values = asDataArray("1, 2, 3, 4, 5, 6, 7, NaN, NaN, NaN");
    final String[] classes = asStringArray("Y, Y, Y, Y, N, N, N, Y, Y, Y");
    final Pair<TreeOrdinaryNumericColumnData, TreeTargetNominalColumnData> pair = exampleData(cfg, values, classes);
    // every row gets the same weight
    final double[] weights = new double[values.length];
    Arrays.fill(weights, 1.0);
    final RandomData random = cfg.createRandomData();
    final TreeNumericColumnData column = pair.getFirst();
    final TreeTargetNominalColumnData targetColumn = pair.getSecond();
    final TreeData tree = createTreeDataClassification(pair);
    final IDataIndexManager idxManager = new DefaultDataIndexManager(tree);
    final DataMemberships memberships = new RootDataMemberships(weights, tree, idxManager);
    final ClassificationPriors classPriors = targetColumn.getDistribution(weights, cfg);

    final SplitCandidate candidate = column.calcBestSplitClassification(memberships, classPriors, targetColumn, random);
    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NumericMissingSplitCandidate.class));
    assertTrue(candidate.canColumnBeSplitFurther());
    // 0.42 = 1 - 0.7^2 - 0.3^2, the (presumably Gini) impurity of the parent (7 x Y, 3 x N); both children of the
    // split {1..4, NaN} | {5..7} are pure, so the whole parent impurity is gained
    assertEquals(0.42, candidate.getGainValue(), 0.0001);

    final NumericMissingSplitCandidate missingCandidate = (NumericMissingSplitCandidate) candidate;
    final TreeNodeNumericCondition[] conditions = missingCandidate.getChildConditions();
    assertEquals(2, conditions.length);
    assertEquals(NumericOperator.LessThanOrEqualOrMissing, conditions[0].getNumericOperator());
    assertEquals(NumericOperator.LargerThan, conditions[1].getNumericOperator());
    assertEquals(4.5, conditions[0].getSplitValue(), 0.0);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)

Example 17 with TreeTargetNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.

Used in the class TreeNumericColumnDataTest, method testCalcBestSplitClassificationSplitAtStart.

/**
 * Tests that the split search picks the first possible split position (between the values 1 and 2) even though no
 * change in the target can be observed at that position (see the example data in the method body).
 *
 * @throws Exception
 */
@Test
public void testCalcBestSplitClassificationSplitAtStart() throws Exception {
    // Index:  1 2 3 4 5 6 7
    // Value:  1 1 1|2 2|3 3
    // Target: A A A|A B|A B
    // Splitting between 1 and 2 yields a pure left child (A A A) and a balanced right child (A B A B). This has a
    // lower weighted impurity than splitting between 2 and 3.
    double[] data = asDataArray("1,1,1,2,2,3,3");
    String[] target = asStringArray("A,A,A,A,B,A,B");
    TreeEnsembleLearnerConfiguration config = createConfig();
    Pair<TreeOrdinaryNumericColumnData, TreeTargetNominalColumnData> exampleData = exampleData(config, data, target);
    TreeNumericColumnData columnData = exampleData.getFirst();
    TreeTargetNominalColumnData targetData = exampleData.getSecond();
    double[] rowWeights = new double[data.length];
    Arrays.fill(rowWeights, 1.0);
    TreeData treeData = createTreeDataClassification(exampleData);
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);
    RandomData rd = config.createRandomData();
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
    // fail with an assertion message rather than an NPE / ClassCastException further down
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NumericSplitCandidate.class));
    // manually calculated: impurity of the parent (5 x A, 2 x B) minus the weighted impurity of the left child
    // (pure, hence 0.0) and of the right child (4 of 7 rows, 2 x A and 2 x B)
    double gain = (1.0 - Math.pow(5.0 / 7.0, 2.0) - Math.pow(2.0 / 7.0, 2.0)) - 0.0 - 4.0 / 7.0 * (1.0 - Math.pow(2.0 / 4.0, 2.0) - Math.pow(2.0 / 4.0, 2.0));
    assertEquals(gain, splitCandidate.getGainValue(), 0.000001);
    NumericSplitCandidate numSplitCandidate = (NumericSplitCandidate) splitCandidate;
    TreeNodeNumericCondition[] childConditions = numSplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    // the split value is the midpoint between the first two distinct values
    assertEquals((1.0 + 2.0) / 2.0, childConditions[0].getSplitValue(), 0.0);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) Test(org.junit.Test)

Example 18 with TreeTargetNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.

Used in the class TreeNumericColumnDataTest, method testXGBoostMissingValueHandling.

/**
 * Verifies the child conditions produced with {@link MissingValueHandling#XGBoost} missing value handling. The test
 * uses two target columns that differ only in the class of the two rows with missing values. Both must yield the
 * same gain (0.48) and split point (3.5). The missing values must go to the child whose rows share their class:
 * the first child in the first scenario, the second child in the second scenario.
 *
 * @throws Exception
 */
@Test
public void testXGBoostMissingValueHandling() throws Exception {
    TreeEnsembleLearnerConfiguration config = createConfig();
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final RandomData random = config.createRandomData();
    final int[] rowIndices = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    final double[] rowWeights = new double[10];
    Arrays.fill(rowWeights, 1.0);
    final MockDataColMem memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);
    final TreeNumericColumnData column = generator.createNumericAttributeColumn("1,2,2,3,4,5,6,7,NaN,NaN", "testCol", 0);

    final double expectedGain = 0.48;
    final double expectedSplitValue = 3.5;
    // one entry per scenario: the target column and whether the first child is expected to accept the missings
    final String[] targetCSVs = { "A,A,A,A,B,B,B,B,A,A", "A,A,A,A,B,B,B,B,B,B" };
    final boolean[] firstChildTakesMissings = { true, false };

    for (int scenario = 0; scenario < targetCSVs.length; scenario++) {
        final TreeTargetNominalColumnData targetColumn =
            TestDataGenerator.createNominalTargetColumn(targetCSVs[scenario]);
        final SplitCandidate candidate = column.calcBestSplitClassification(memberships,
            targetColumn.getDistribution(rowWeights, config), targetColumn, random);
        assertEquals("Wrong gain.", expectedGain, candidate.getGainValue(), 1e-8);
        final TreeNodeCondition[] conditions = candidate.getChildConditions();

        final TreeNodeNumericCondition firstCondition = (TreeNodeNumericCondition) conditions[0];
        assertEquals("Wrong split point.", expectedSplitValue, firstCondition.getSplitValue(), 1e-8);
        assertEquals("Missings were not sent in the correct direction.", firstChildTakesMissings[scenario],
            firstCondition.acceptsMissings());

        final TreeNodeNumericCondition secondCondition = (TreeNodeNumericCondition) conditions[1];
        assertEquals("Wrong split point.", expectedSplitValue, secondCondition.getSplitValue(), 1e-8);
        assertEquals("Missings were not sent in the correct direction.", !firstChildTakesMissings[scenario],
            secondCondition.acceptsMissings());
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) Test(org.junit.Test)

Example 19 with TreeTargetNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.

Used in the class TreeNominalColumnDataTest, method testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1.

/**
 * Tests the XGBoost missing value handling variant, where each split tries both directions for the missing values
 * and keeps the one with the better gain.
 * <p>
 * The nominal column contains missing values ("?"). The expected best split separates {S, R} from the remaining
 * values and sends the missing values to the {@code IS_NOT_IN} child.
 *
 * @throws Exception
 */
@Test
public void testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    // Only the target of the tennis data is reused. The attribute column is replaced by one containing missing
    // values ("?"), so the direction for the missings has to be decided during the split search.
    Pair<TreeNominalColumnData, TreeTargetNominalColumnData> twoClassTennisData = twoClassTennisData(config);
    String dataContainingMissingsCSV = "S,?,O,R,S,R,S,?,O,?";
    final TreeNominalColumnData columnData = dataGen.createNominalAttributeColumn(dataContainingMissingsCSV, "column containing missing values", 0);
    final TreeTargetNominalColumnData target = twoClassTennisData.getSecond();
    double[] rowWeights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(rowWeights, 1.0);
    // Row indices grouped by attribute value, following the ordering in the columnData:
    // S -> 0,4,6; O -> 2,8; R -> 3,5; ? -> 1,7,9
    final int[] originalIndex = new int[] { 0, 4, 6, 2, 8, 3, 5, 1, 7, 9 };
    final int[] columnIndex = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    final DataMemberships dataMem = new MockDataColMem(originalIndex, columnIndex, rowWeights);
    final SplitCandidate split = columnData.calcBestSplitClassification(dataMem, target.getDistribution(rowWeights, config), target, TestDataGenerator.createRandomData());
    assertThat(split, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate nomSplit = (NominalBinarySplitCandidate) split;
    TreeNodeNominalBinaryCondition[] childConditions = nomSplit.getChildConditions();
    assertEquals("Wrong gain value.", 0.18, nomSplit.getGainValue(), 1e-8);
    // a binary split must produce exactly two children; check before indexing into the array
    assertEquals("Wrong number of child conditions.", 2, childConditions.length);
    final String[] conditionValues = new String[] { "S", "R" };
    assertArrayEquals("Values in nominal condition did not match", conditionValues, childConditions[0].getValues());
    assertArrayEquals("Values in nominal condition did not match", conditionValues, childConditions[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, childConditions[1].getSetLogic());
    assertTrue("Missing values are not sent to the correct child.", childConditions[0].acceptsMissings());
    assertFalse("Missing values are not sent to the correct child.", childConditions[1].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Example 20 with TreeTargetNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData in project knime-core by knime.

Used in the class TreeNominalColumnDataTest, method testCalcBestSplitClassificationBinaryTwoClass.

/**
 * Tests the method
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * for a two class problem with {@link MissingValueHandling#Surrogate} missing value handling. The best binary split
 * is expected to separate the value "R" from all other values, and neither child may accept missing values.
 *
 * @throws Exception
 */
@Test
public void testCalcBestSplitClassificationBinaryTwoClass() throws Exception {
    TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.Surrogate);
    Pair<TreeNominalColumnData, TreeTargetNominalColumnData> twoClassTennisData = twoClassTennisData(config);
    TreeNominalColumnData columnData = twoClassTennisData.getFirst();
    TreeTargetNominalColumnData targetData = twoClassTennisData.getSecond();
    TreeData twoClassTennisTreeData = twoClassTennisTreeData(config);
    IDataIndexManager indexManager = new DefaultDataIndexManager(twoClassTennisTreeData);
    // the expected gain below was computed for the Gini criterion
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());
    double[] rowWeights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(rowWeights, 1.0);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, twoClassTennisTreeData, indexManager);
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);
    // NOTE(review): no RandomData is passed (null). Presumably the binary split search for two classes draws no
    // random numbers; confirm if the implementation changes.
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, null);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(splitCandidate.canColumnBeSplitFurther());
    // expected gain calculated manually with OpenOffice Calc
    assertEquals(0.1371428, splitCandidate.getGainValue(), 0.00001);
    NominalBinarySplitCandidate binSplitCandidate = (NominalBinarySplitCandidate) splitCandidate;
    TreeNodeNominalBinaryCondition[] childConditions = binSplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertArrayEquals(new String[] { "R" }, childConditions[0].getValues());
    assertArrayEquals(new String[] { "R" }, childConditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, childConditions[1].getSetLogic());
    assertFalse(childConditions[0].acceptsMissings());
    assertFalse(childConditions[1].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)23 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)16 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)16 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)14 RandomData (org.apache.commons.math.random.RandomData)13 Test (org.junit.Test)13 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)12 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)11 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)10 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)9 BitSet (java.util.BitSet)8 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)7 TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition)7 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)6 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)6 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)5 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)5 BigInteger (java.math.BigInteger)4 ArrayList (java.util.ArrayList)4 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)4