
Example 11 with TestDataGenerator

use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.

the class RootDescendantDataMembershipsTest method testCreateChildDataMemberships.

@Test
public void testCreateChildDataMemberships() {
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeData data = dataGen.createTennisData();
    DefaultDataIndexManager indexManager = new DefaultDataIndexManager(data);
    int nrRows = data.getNrRows();
    RowSample rowSample = new DefaultRowSample(nrRows);
    RootDataMemberships rootMemberships = new RootDataMemberships(rowSample, data, indexManager);
    BitSet firstHalf = new BitSet(nrRows);
    firstHalf.set(0, nrRows / 2);
    DataMemberships firstHalfChildMemberships = rootMemberships.createChildMemberships(firstHalf);
    assertThat(firstHalfChildMemberships, instanceOf(BitSetDescendantDataMemberships.class));
    BitSetDescendantDataMemberships bitSetFirstHalfChildMemberships = (BitSetDescendantDataMemberships) firstHalfChildMemberships;
    assertEquals(firstHalf, bitSetFirstHalfChildMemberships.getBitSet());
    BitSet firstQuarter = new BitSet(nrRows);
    firstQuarter.set(0, nrRows / 4);
    DataMemberships firstQuarterGrandChild = firstHalfChildMemberships.createChildMemberships(firstQuarter);
    assertThat(firstQuarterGrandChild, instanceOf(BitSetDescendantDataMemberships.class));
    BitSetDescendantDataMemberships bitSetFirstQuarterGrandChild = (BitSetDescendantDataMemberships) firstQuarterGrandChild;
    assertEquals(firstQuarter, bitSetFirstQuarterGrandChild.getBitSet());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) DefaultRowSample(org.knime.base.node.mine.treeensemble2.sample.row.DefaultRowSample) BitSet(java.util.BitSet) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) Test(org.junit.Test)
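
The child memberships in the test above are selected with BitSet.set(fromIndex, toIndex), where toIndex is exclusive. The following standalone sketch shows which rows end up in the "first half" and "first quarter" sets. The row count of 14 is an assumption here (it matches the 14-entry index array used in the column memberships example), and the class and method names are hypothetical helpers, not part of KNIME.

```java
import java.util.BitSet;

final class BitSetRangeSketch {

    /** Returns a BitSet in which the bits from..to-1 are set (to is exclusive). */
    static BitSet range(int nrRows, int from, int to) {
        BitSet bits = new BitSet(nrRows);
        bits.set(from, to);
        return bits;
    }

    /** True if every bit set in inner is also set in outer. */
    static boolean isSubset(BitSet inner, BitSet outer) {
        BitSet rest = (BitSet) inner.clone();
        rest.andNot(outer);
        return rest.isEmpty();
    }

    public static void main(String[] args) {
        int nrRows = 14; // assumed row count of the tennis data
        BitSet firstHalf = range(nrRows, 0, nrRows / 2);
        BitSet firstQuarter = range(nrRows, 0, nrRows / 4);
        System.out.println(firstHalf);    // {0, 1, 2, 3, 4, 5, 6}
        System.out.println(firstQuarter); // {0, 1, 2}
        System.out.println(isSubset(firstQuarter, firstHalf)); // true
    }
}
```

Because nrRows / 4 is an integer division, the "quarter" of 14 rows holds three rows. The grandchild set has to be a subset of the child set for the nested createChildMemberships call to make sense.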

Example 12 with TestDataGenerator

use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.

the class RootDescendantDataMembershipsTest method testGetColumnMemberships.

@Test
public void testGetColumnMemberships() {
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeData data = dataGen.createTennisData();
    DefaultDataIndexManager indexManager = new DefaultDataIndexManager(data);
    int nrRows = data.getNrRows();
    RowSample rowSample = new DefaultRowSample(nrRows);
    RootDataMemberships rootMemberships = new RootDataMemberships(rowSample, data, indexManager);
    ColumnMemberships rootColMem = rootMemberships.getColumnMemberships(0);
    assertThat(rootColMem, instanceOf(IntArrayColumnMemberships.class));
    assertEquals(nrRows, rootColMem.size());
    int[] expectedOriginalIndices = new int[] { 0, 1, 7, 8, 10, 2, 6, 11, 12, 3, 4, 5, 9, 13 };
    for (int i = 0; rootColMem.next(); i++) {
        // in this case originalIndex and indexInDataMemberships are the same
        assertEquals(expectedOriginalIndices[i], rootColMem.getOriginalIndex());
        assertEquals(expectedOriginalIndices[i], rootColMem.getIndexInDataMemberships());
        assertEquals(i, rootColMem.getIndexInColumn());
    }
    BitSet lastHalf = new BitSet(nrRows);
    lastHalf.set(nrRows / 2, nrRows);
    DataMemberships lastHalfChild = rootMemberships.createChildMemberships(lastHalf);
    ColumnMemberships childColMem = lastHalfChild.getColumnMemberships(0);
    assertThat(childColMem, instanceOf(DescendantColumnMemberships.class));
    assertEquals(nrRows / 2, childColMem.size());
    expectedOriginalIndices = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    int[] expectedIndexInColumn = new int[] { 2, 3, 4, 7, 8, 12, 13 };
    int[] expectedIndexInDataMemberships = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    for (int i = 0; childColMem.next(); i++) {
        assertEquals(expectedOriginalIndices[i], childColMem.getOriginalIndex());
        assertEquals(expectedIndexInColumn[i], childColMem.getIndexInColumn());
        assertEquals(expectedIndexInDataMemberships[i], childColMem.getIndexInDataMemberships());
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) DefaultRowSample(org.knime.base.node.mine.treeensemble2.sample.row.DefaultRowSample) BitSet(java.util.BitSet) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) Test(org.junit.Test)
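
The expected arrays for the child in the test above can be reproduced by hand: walk the root column in its sorted order and keep only the positions whose original row index lies in the child's BitSet. The sketch below does exactly that with plain arrays. It only illustrates the relationship between the arrays in the test; it is not how DescendantColumnMemberships is implemented, and the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;

final class DescendantColumnSketch {

    /** Positions in the sorted column whose original row index is contained in subset. */
    static int[] positionsInColumn(int[] sortedOriginalIndices, BitSet subset) {
        List<Integer> positions = new ArrayList<>();
        for (int pos = 0; pos < sortedOriginalIndices.length; pos++) {
            if (subset.get(sortedOriginalIndices[pos])) {
                positions.add(pos);
            }
        }
        return positions.stream().mapToInt(Integer::intValue).toArray();
    }

    /** Looks up the original row index for each column position. */
    static int[] originalIndices(int[] sortedOriginalIndices, int[] positions) {
        return Arrays.stream(positions).map(pos -> sortedOriginalIndices[pos]).toArray();
    }

    public static void main(String[] args) {
        // sorted order of column 0 at the root, taken from the test above
        int[] sorted = {0, 1, 7, 8, 10, 2, 6, 11, 12, 3, 4, 5, 9, 13};
        BitSet lastHalf = new BitSet(sorted.length);
        lastHalf.set(sorted.length / 2, sorted.length); // rows 7..13

        int[] positions = positionsInColumn(sorted, lastHalf);
        System.out.println(Arrays.toString(positions));                          // [2, 3, 4, 7, 8, 12, 13]
        System.out.println(Arrays.toString(originalIndices(sorted, positions))); // [7, 8, 10, 11, 12, 9, 13]
    }
}
```

The first printed array corresponds to expectedIndexInColumn and the second to expectedOriginalIndices in the test.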

Example 13 with TestDataGenerator

use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.

the class TreeNodeNominalConditionTest method testToPMMLPredicate.

/**
 * This method tests the {@link TreeNodeNominalCondition#toPMMLPredicate()} method.
 *
 * @throws Exception
 */
@Test
public void testToPMMLPredicate() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNominalColumnData col = dataGen.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);
    TreeNodeNominalCondition cond = new TreeNodeNominalCondition(col.getMetaData(), 3, false);
    PMMLPredicate predicate = cond.toPMMLPredicate();
    assertThat(predicate, instanceOf(PMMLSimplePredicate.class));
    PMMLSimplePredicate simplePredicate = (PMMLSimplePredicate) predicate;
    assertEquals("Wrong operator", PMMLOperator.EQUAL, simplePredicate.getOperator());
    assertEquals("Wrong split value", "D", simplePredicate.getThreshold());
    cond = new TreeNodeNominalCondition(col.getMetaData(), 0, true);
    predicate = cond.toPMMLPredicate();
    assertThat(predicate, instanceOf(PMMLCompoundPredicate.class));
    PMMLCompoundPredicate compound = (PMMLCompoundPredicate) predicate;
    assertEquals("Wrong boolean operator.", PMMLBooleanOperator.OR, compound.getBooleanOperator());
    List<PMMLPredicate> preds;
    preds = compound.getPredicates();
    assertEquals("Wrong number of predicates in compound predicate.", 2, preds.size());
    assertThat(preds.get(0), instanceOf(PMMLSimplePredicate.class));
    simplePredicate = (PMMLSimplePredicate) preds.get(0);
    assertEquals("Wrong operator", PMMLOperator.EQUAL, simplePredicate.getOperator());
    assertEquals("Wrong split value", "A", simplePredicate.getThreshold());
    assertEquals("Wrong attribute.", col.getMetaData().getAttributeName(), simplePredicate.getSplitAttribute());
    assertThat(preds.get(1), instanceOf(PMMLSimplePredicate.class));
    simplePredicate = (PMMLSimplePredicate) preds.get(1);
    assertEquals("Should be isMissing", PMMLOperator.IS_MISSING, simplePredicate.getOperator());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) TreeNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) Test(org.junit.Test)

Example 14 with TestDataGenerator

use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.

the class TreeNodeNominalConditionTest method testTestCondition.

/**
 * This method tests the
 * {@link TreeNodeNominalCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * method.
 *
 * @throws Exception
 */
@Test
public void testTestCondition() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNominalColumnData col = dataGen.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);
    TreeNodeNominalCondition cond = new TreeNodeNominalCondition(col.getMetaData(), 3, false);
    final Map<String, Object> map = Maps.newHashMap();
    final String colName = col.getMetaData().getAttributeName();
    map.put(colName, 0);
    final PredictorRecord record = new PredictorRecord(map);
    assertFalse("The value A was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertFalse("The value B was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertFalse("The value C was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("The value D was falsely rejected", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted", cond.testCondition(record));
    cond = new TreeNodeNominalCondition(col.getMetaData(), 0, true);
    map.clear();
    map.put(colName, 0);
    assertTrue("The value A was falsely rejected", cond.testCondition(record));
    map.clear();
    map.put(colName, 1);
    assertFalse("The value B was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 2);
    assertFalse("The value C was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("The value D was falsely accepted", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected", cond.testCondition(record));
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) Test(org.junit.Test)
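
The assertions above boil down to a small decision rule: a record passes if its value index equals the index the condition was built with, and a missing value passes only if the condition accepts missings. The sketch below restates that rule without any KNIME types. The MISSING sentinel stands in for PredictorRecord.NULL, and the class and method names are hypothetical; the real TreeNodeNominalCondition may be implemented differently.

```java
import java.util.HashMap;
import java.util.Map;

final class NominalConditionSketch {

    /** Stand-in for PredictorRecord.NULL (an assumption for this sketch). */
    static final Object MISSING = new Object();

    /**
     * Returns true if the value stored for column matches acceptedIndex,
     * or if the value is missing and acceptsMissings is set.
     */
    static boolean accepts(Map<String, Object> record, String column,
            int acceptedIndex, boolean acceptsMissings) {
        Object value = record.get(column);
        if (value == MISSING) {
            return acceptsMissings;
        }
        return value instanceof Integer && ((Integer) value).intValue() == acceptedIndex;
    }

    public static void main(String[] args) {
        Map<String, Object> record = new HashMap<>();
        // value indices for "A,A,B,C,C,D": A=0, B=1, C=2, D=3
        record.put("testcol", 3);
        System.out.println(accepts(record, "testcol", 3, false)); // true
        record.put("testcol", MISSING);
        System.out.println(accepts(record, "testcol", 3, false)); // false
        System.out.println(accepts(record, "testcol", 0, true));  // true
    }
}
```

Note that the test creates the PredictorRecord once and then keeps changing the map, so it appears to rely on the record reading through to the map it was constructed with.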

Example 15 with TestDataGenerator

use of org.knime.base.node.mine.treeensemble2.data.TestDataGenerator in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1.

/**
 * Tests the XGBoost missing value handling variant, where for each split it is tried which direction for missing
 * values provides the better gain.
 *
 * @throws Exception
 */
@Test
public void testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    // check correct behavior if no missing values are encountered during split search
    Pair<TreeNominalColumnData, TreeTargetNominalColumnData> twoClassTennisData = twoClassTennisData(config);
    String dataContainingMissingsCSV = "S,?,O,R,S,R,S,?,O,?";
    final TreeNominalColumnData columnData = dataGen.createNominalAttributeColumn(dataContainingMissingsCSV, "column containing missing values", 0);
    final TreeTargetNominalColumnData target = twoClassTennisData.getSecond();
    double[] rowWeights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(rowWeights, 1.0);
    // based on the ordering in the columnData
    final int[] originalIndex = new int[] { 0, 4, 6, 2, 8, 3, 5, 1, 7, 9 };
    final int[] columnIndex = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    final DataMemberships dataMem = new MockDataColMem(originalIndex, columnIndex, rowWeights);
    final SplitCandidate split = columnData.calcBestSplitClassification(dataMem, target.getDistribution(rowWeights, config), target, TestDataGenerator.createRandomData());
    assertThat(split, instanceOf(NominalBinarySplitCandidate.class));
    final NominalBinarySplitCandidate nomSplit = (NominalBinarySplitCandidate) split;
    TreeNodeNominalBinaryCondition[] childConditions = nomSplit.getChildConditions();
    assertEquals("Wrong gain value.", 0.18, nomSplit.getGainValue(), 1e-8);
    final String[] conditionValues = new String[] { "S", "R" };
    assertArrayEquals("Values in nominal condition did not match", conditionValues, childConditions[0].getValues());
    assertArrayEquals("Values in nominal condition did not match", conditionValues, childConditions[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, childConditions[1].getSetLogic());
    assertTrue("Missing values are not sent to the correct child.", childConditions[0].acceptsMissings());
    assertFalse("Missing values are not sent to the correct child.", childConditions[1].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) Test(org.junit.Test)
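
The Javadoc of the test above describes the XGBoost-style handling as trying, for each split, which direction for missing values gives the better gain. A rough sketch of that idea for one binary split follows. It uses Gini gain on per-class weight counts as the criterion, which is an assumption: the impurity measure, tie-breaking, and data structures KNIME actually uses are not shown in the test, and all names below are hypothetical.

```java
final class MissingDirectionSketch {

    static double sum(double[] counts) {
        double total = 0.0;
        for (double c : counts) {
            total += c;
        }
        return total;
    }

    /** Gini impurity of a class distribution given as per-class weights. */
    static double gini(double[] counts) {
        double total = sum(counts);
        if (total == 0.0) {
            return 0.0;
        }
        double sumSquares = 0.0;
        for (double c : counts) {
            sumSquares += (c / total) * (c / total);
        }
        return 1.0 - sumSquares;
    }

    static double[] add(double[] a, double[] b) {
        double[] result = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            result[i] = a[i] + b[i];
        }
        return result;
    }

    /** Impurity reduction achieved by splitting the parent into left and right. */
    static double gain(double[] left, double[] right) {
        double[] parent = add(left, right);
        double n = sum(parent);
        if (n == 0.0) {
            return 0.0;
        }
        return gini(parent) - sum(left) / n * gini(left) - sum(right) / n * gini(right);
    }

    /** Tries both directions for the missing rows; ties go to the left child in this sketch. */
    static boolean missingsGoLeft(double[] left, double[] right, double[] missing) {
        double gainLeft = gain(add(left, missing), right);
        double gainRight = gain(left, add(right, missing));
        return gainLeft >= gainRight;
    }

    public static void main(String[] args) {
        double[] left = {3, 0};    // class counts of non-missing rows in the left child
        double[] right = {0, 3};   // class counts of non-missing rows in the right child
        double[] missing = {2, 0}; // class counts of rows with a missing value
        System.out.println(missingsGoLeft(left, right, missing)); // true
    }
}
```

In the example, sending the two missing rows left keeps both children pure (gain 0.46875), while sending them right mixes the right child (gain 0.16875), so the left direction wins.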

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) 21
Test (org.junit.Test) 19
DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) 11
RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) 11
TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) 9
SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) 8
RandomData (org.apache.commons.math.random.RandomData) 7
NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) 6
NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) 6
BitSet (java.util.BitSet) 5
TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) 5
TreeNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData) 4
DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) 4
PMMLCompoundPredicate (org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) 3
PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate) 3
PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) 3
PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord) 3
TreeNodeNominalCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) 3
TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) 3
TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData) 2