Search in sources:

Example 6 with TreeNumericColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData in project knime-core by knime.

the class TreeNumericColumnDataTest method testXGBoostMissingValueHandling.

/**
 * Checks the child conditions of the best classification split when the configuration uses
 * {@link MissingValueHandling#XGBoost}.
 * <p>
 * The attribute column holds eight regular values followed by two {@code NaN} entries. Two target columns are
 * evaluated that differ only in the class assigned to the two {@code NaN} rows: with class A the left child
 * condition is expected to accept missing values, with class B the right child condition is expected to do so.
 * In both cases the split point is expected at 3.5 with a gain of 0.48.
 *
 * @throws Exception
 */
@Test
public void testXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration learnerConfig = createConfig();
    learnerConfig.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final RandomData random = learnerConfig.createRandomData();

    // ten rows, each with weight 1.0; the same index array is passed for both index arguments
    final int[] rowIndices = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    final double[] rowWeights = new double[10];
    Arrays.fill(rowWeights, 1.0);
    final MockDataColMem memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);

    final String attributeValues = "1,2,2,3,4,5,6,7,NaN,NaN";
    // the last two entries (the NaN rows) are class A here ...
    final String targetMissingsAreA = "A,A,A,A,B,B,B,B,A,A";
    // ... and class B here
    final String targetMissingsAreB = "A,A,A,A,B,B,B,B,B,B";
    final double gain = 0.48;
    final double splitPoint = 3.5;
    final double tolerance = 1e-8;
    final String wrongGain = "Wrong gain.";
    final String wrongSplit = "Wrong split point.";
    final String wrongDirection = "Missings were not sent in the correct direction.";

    final TreeNumericColumnData attribute = generator.createNumericAttributeColumn(attributeValues, "testCol", 0);

    // first target: missing values are expected in the left child
    final TreeTargetNominalColumnData targetA = TestDataGenerator.createNominalTargetColumn(targetMissingsAreA);
    final SplitCandidate splitA = attribute.calcBestSplitClassification(memberships,
        targetA.getDistribution(rowWeights, learnerConfig), targetA, random);
    assertEquals(wrongGain, gain, splitA.getGainValue(), tolerance);
    final TreeNodeCondition[] conditionsA = splitA.getChildConditions();
    final TreeNodeNumericCondition leftA = (TreeNodeNumericCondition) conditionsA[0];
    assertEquals(wrongSplit, splitPoint, leftA.getSplitValue(), tolerance);
    assertTrue(wrongDirection, leftA.acceptsMissings());
    final TreeNodeNumericCondition rightA = (TreeNodeNumericCondition) conditionsA[1];
    assertEquals(wrongSplit, splitPoint, rightA.getSplitValue(), tolerance);
    assertFalse(wrongDirection, rightA.acceptsMissings());

    // second target: missing values are expected in the right child
    final TreeTargetNominalColumnData targetB = TestDataGenerator.createNominalTargetColumn(targetMissingsAreB);
    final SplitCandidate splitB = attribute.calcBestSplitClassification(memberships,
        targetB.getDistribution(rowWeights, learnerConfig), targetB, random);
    assertEquals(wrongGain, gain, splitB.getGainValue(), tolerance);
    final TreeNodeCondition[] conditionsB = splitB.getChildConditions();
    final TreeNodeNumericCondition leftB = (TreeNodeNumericCondition) conditionsB[0];
    assertEquals(wrongSplit, splitPoint, leftB.getSplitValue(), tolerance);
    assertFalse(wrongDirection, leftB.acceptsMissings());
    final TreeNodeNumericCondition rightB = (TreeNodeNumericCondition) conditionsB[1];
    assertEquals(wrongSplit, splitPoint, rightB.getSplitValue(), tolerance);
    assertTrue(wrongDirection, rightB.acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) Test(org.junit.Test)

Example 7 with TreeNumericColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData in project knime-core by knime.

the class TreeNumericColumnDataTest method testCalcBestSplitRegression.

/**
 * Checks {@code calcBestSplitRegression} on a single numeric attribute with ten equally weighted rows: first on
 * the root memberships and then on the memberships of the right child produced by the first split.
 *
 * @throws InvalidSettingsException
 */
@Test
public void testCalcBestSplitRegression() throws InvalidSettingsException {
    final String attributeValues = "1,2,3,4,5,6,7,8,9,10";
    final String targetValues = "1,5,4,4.3,6.5,6.5,4,3,3,4";

    // regression configuration: one model, all rows, all columns
    final TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(true);
    learnerConfig.setNrModels(1);
    learnerConfig.setDataSelectionWithReplacement(false);
    learnerConfig.setUseDifferentAttributesAtEachNode(false);
    learnerConfig.setDataFractionPerTree(1.0);
    learnerConfig.setColumnSamplingMode(ColumnSamplingMode.None);

    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final RandomData random = learnerConfig.createRandomData();
    final TreeTargetNumericColumnData targetColumn = TestDataGenerator.createNumericTargetColumn(targetValues);
    final TreeNumericColumnData numericColumn = generator.createNumericAttributeColumn(attributeValues, "test-col", 0);
    final TreeData treeData =
        new TreeData(new TreeAttributeColumnData[] { numericColumn }, targetColumn, TreeType.Ordinary);
    final double[] rowWeights = new double[10];
    Arrays.fill(rowWeights, 1.0);
    final DataMemberships rootMemberships =
        new RootDataMemberships(rowWeights, treeData, new DefaultDataIndexManager(treeData));

    // split on the root
    final SplitCandidate rootSplit = numericColumn.calcBestSplitRegression(rootMemberships,
        targetColumn.getPriors(rootMemberships, learnerConfig), targetColumn, random);
    // calculated via OpenOffice calc
    assertEquals(10.885444, rootSplit.getGainValue(), 1e-5);
    final TreeNodeCondition[] rootConditions = rootSplit.getChildConditions();
    assertEquals(2, rootConditions.length);
    for (final TreeNodeCondition condition : rootConditions) {
        assertThat(condition, instanceOf(TreeNodeNumericCondition.class));
        final TreeNodeNumericCondition numericCondition = (TreeNodeNumericCondition) condition;
        assertEquals(1.5, numericCondition.getSplitValue(), 0);
    }

    // left child contains only one row therefore only look at right child
    final BitSet expectedRightRows = new BitSet(10);
    expectedRightRows.set(1, 10);
    final BitSet actualRightRows = numericColumn.updateChildMemberships(rootConditions[1], rootMemberships);
    assertEquals(expectedRightRows, actualRightRows);

    // split on the right child
    final DataMemberships rightMemberships = rootMemberships.createChildMemberships(actualRightRows);
    final SplitCandidate rightSplit = numericColumn.calcBestSplitRegression(rightMemberships,
        targetColumn.getPriors(rightMemberships, learnerConfig), targetColumn, random);
    assertEquals(6.883555, rightSplit.getGainValue(), 1e-5);
    for (final TreeNodeCondition condition : rightSplit.getChildConditions()) {
        assertThat(condition, instanceOf(TreeNodeNumericCondition.class));
        final TreeNodeNumericCondition numericCondition = (TreeNodeNumericCondition) condition;
        assertEquals(6.5, numericCondition.getSplitValue(), 0);
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) Test(org.junit.Test)

Example 8 with TreeNumericColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData in project knime-core by knime.

the class TreeTargetNominalColumnDataTest method testGetDistribution.

/**
 * Exercises both overloads of {@code getDistribution} on {@link TreeTargetNominalColumnData}, namely
 * {@link TreeTargetNominalColumnData#getDistribution(DataMemberships, TreeEnsembleLearnerConfiguration)} and
 * {@link TreeTargetNominalColumnData#getDistribution(double[], TreeEnsembleLearnerConfiguration)}, for each of the
 * split criteria Gini, InformationGain and InformationGainRatio.
 *
 * @throws InvalidSettingsException
 */
@Test
public void testGetDistribution() throws InvalidSettingsException {
    final String targetValues = "A,A,A,B,B,B,A";
    final String attributeValues = "1,2,3,4,5,6,7";
    final TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final TreeTargetNominalColumnData targetColumn = TestDataGenerator.createNominalTargetColumn(targetValues);
    final TreeNumericColumnData numericColumn =
        generator.createNumericAttributeColumn(attributeValues, "test-col", 0);
    final TreeData treeData =
        new TreeData(new TreeAttributeColumnData[] { numericColumn }, targetColumn, TreeType.Ordinary);
    final double[] rowWeights = new double[7];
    Arrays.fill(rowWeights, 1.0);
    final DataMemberships memberships =
        new RootDataMemberships(rowWeights, treeData, new DefaultDataIndexManager(treeData));

    // four rows of class A and three of class B, independent of the split criterion
    final double[] expectedCounts = { 4.0, 3.0 };
    final double giniImpurity = 0.4897959184;
    final double entropy = 0.985228136;

    // criteria in the order they are checked; the prior impurity for the gain ratio is the same as for IG
    final SplitCriterion[] criteria =
        { SplitCriterion.Gini, SplitCriterion.InformationGain, SplitCriterion.InformationGainRatio };
    final double[] expectedImpurities = { giniImpurity, entropy, entropy };

    for (int c = 0; c < criteria.length; c++) {
        learnerConfig.setSplitCriterion(criteria[c]);

        // overload taking data memberships
        final ClassificationPriors fromMemberships = targetColumn.getDistribution(memberships, learnerConfig);
        assertEquals(expectedImpurities[c], fromMemberships.getPriorImpurity(), DELTA);
        assertArrayEquals(expectedCounts, fromMemberships.getDistribution(), DELTA);

        // overload taking the raw weight array
        final ClassificationPriors fromWeights = targetColumn.getDistribution(rowWeights, learnerConfig);
        assertEquals(expectedImpurities[c], fromWeights.getPriorImpurity(), DELTA);
        assertArrayEquals(expectedCounts, fromWeights.getDistribution(), DELTA);
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) Test(org.junit.Test)

Example 9 with TreeNumericColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData in project knime-core by knime.

the class TreeNodeNumericConditionTest method testTestCondition.

/**
 * This method tests the
 * {@link TreeNodeNumericCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * method.
 * <p>
 * Four conditions are checked: {@code <= 3} and {@code > 4}, each once rejecting and once accepting missing
 * values. Every condition is probed with values below, at and above its threshold as well as with
 * {@link PredictorRecord#NULL}.
 *
 * @throws Exception
 */
@Test
public void testTestCondition() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNumericColumnData col = dataGen.createNumericAttributeColumn("1,2,3,4,4,5,6,7", "testCol", 0);
    final Map<String, Object> map = Maps.newHashMap();
    final String colName = col.getMetaData().getAttributeName();

    // "<= 3", missing values rejected
    TreeNodeNumericCondition cond =
        new TreeNodeNumericCondition(col.getMetaData(), 3, NumericOperator.LessThanOrEqual, false);
    map.put(colName, 2.5);
    // The record is created once; the test relies on it reflecting the later changes made to the map.
    final PredictorRecord record = new PredictorRecord(map);
    assertTrue("2.5 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("3 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));

    // "<= 3", missing values accepted
    cond = new TreeNodeNumericCondition(col.getMetaData(), 3, NumericOperator.LessThanOrEqual, true);
    map.clear();
    map.put(colName, 2.5);
    assertTrue("2.5 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("3 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));

    // "> 4", missing values rejected
    cond = new TreeNodeNumericCondition(col.getMetaData(), 4, NumericOperator.LargerThan, false);
    map.clear();
    map.put(colName, 2.5);
    assertFalse("2.5 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("3 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4.01);
    assertTrue("4.01 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));

    // "> 4", missing values accepted
    cond = new TreeNodeNumericCondition(col.getMetaData(), 4, NumericOperator.LargerThan, true);
    map.clear();
    map.put(colName, 2.5);
    assertFalse("2.5 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("3 was falsely accepted.", cond.testCondition(record));
    // boundary value: the threshold itself must not satisfy a strict "larger than"
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4.01);
    assertTrue("4.01 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) Test(org.junit.Test)

Example 10 with TreeNumericColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData in project knime-core by knime.

the class TreeNodeNumericConditionTest method testToPMML.

/**
 * Verifies {@link TreeNodeNumericCondition#toPMMLPredicate()} for two conditions: one that does not accept missing
 * values, which is expected to yield a simple predicate, and one that accepts them, which is expected to yield an
 * OR compound of the numeric comparison and an is-missing predicate.
 *
 * @throws Exception
 */
@Test
public void testToPMML() throws Exception {
    final TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final TreeNumericColumnData numericColumn =
        generator.createNumericAttributeColumn("1,2,3,4,4,5,6,7", "testCol", 0);

    // "<= 3" without missing values: a single simple predicate
    final TreeNodeNumericCondition withoutMissings =
        new TreeNodeNumericCondition(numericColumn.getMetaData(), 3, NumericOperator.LessThanOrEqual, false);
    final PMMLPredicate plainPredicate = withoutMissings.toPMMLPredicate();
    assertThat(plainPredicate, instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate lessOrEqual = (PMMLSimplePredicate) plainPredicate;
    assertEquals("Wrong attribute", numericColumn.getMetaData().getAttributeName(),
        lessOrEqual.getSplitAttribute());
    assertEquals("Wrong operator", PMMLOperator.LESS_OR_EQUAL, lessOrEqual.getOperator());
    assertEquals("Wrong threshold", Double.toString(3), lessOrEqual.getThreshold());

    // "> 4.5" with missing values: a compound of two simple predicates joined by OR
    final TreeNodeNumericCondition withMissings =
        new TreeNodeNumericCondition(numericColumn.getMetaData(), 4.5, NumericOperator.LargerThan, true);
    final PMMLPredicate wrappedPredicate = withMissings.toPMMLPredicate();
    assertThat(wrappedPredicate, instanceOf(PMMLCompoundPredicate.class));
    final PMMLCompoundPredicate orCompound = (PMMLCompoundPredicate) wrappedPredicate;
    assertEquals("Wrong boolean operator in compound.", PMMLBooleanOperator.OR, orCompound.getBooleanOperator());
    final List<PMMLPredicate> parts = orCompound.getPredicates();
    assertEquals("Wrong number of predicates in compound.", 2, parts.size());

    // first part: the numeric comparison
    assertThat(parts.get(0), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate greaterThan = (PMMLSimplePredicate) parts.get(0);
    assertEquals("Wrong attribute", numericColumn.getMetaData().getAttributeName(),
        greaterThan.getSplitAttribute());
    assertEquals("Wrong operator", PMMLOperator.GREATER_THAN, greaterThan.getOperator());
    assertEquals("Wrong threshold", Double.toString(4.5), greaterThan.getThreshold());

    // second part: the is-missing check
    assertThat(parts.get(1), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate isMissing = (PMMLSimplePredicate) parts.get(1);
    assertEquals("Should be isMissing", PMMLOperator.IS_MISSING, isMissing.getOperator());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) TreeNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) Test(org.junit.Test)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 10; Test (org.junit.Test): 9; DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships): 7; RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 7; TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition): 7; RandomData (org.apache.commons.math.random.RandomData): 6; DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager): 6; NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate): 6; NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate): 6; SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate): 6; BitSet (java.util.BitSet): 4; IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 4; TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator): 2; TreeNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData): 2; TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition): 2; PMMLCompoundPredicate (org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate): 1; PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate): 1; PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate): 1; PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord): 1