Search in sources :

Example 16 with TreeNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationBinary.

/**
 * Verifies
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * using binary splits: first on the root memberships of the tennis data, then on each of the two children
 * produced by the root split.
 *
 * @throws Exception if the test setup or the split calculation fails
 */
@Test
public void testCalcBestSplitClassificationBinary() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> columns = tennisData(config);
    final TreeNominalColumnData attribute = columns.getFirst();
    final TreeTargetNominalColumnData target = columns.getSecond();
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());

    // every row takes part with weight 1
    final double[] unitWeights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(unitWeights, 1.0);
    final TreeData treeData = tennisTreeData(config);
    final IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    final DataMemberships rootMemberships = new RootDataMemberships(unitWeights, treeData, indexManager);
    final ClassificationPriors rootPriors = target.getDistribution(unitWeights, config);

    // --- split at the root ---
    final SplitCandidate rootSplit = attribute.calcBestSplitClassification(rootMemberships, rootPriors, target, null);
    assertNotNull(rootSplit);
    assertThat(rootSplit, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(rootSplit.canColumnBeSplitFurther());
    // reference gain was calculated by hand in LibreOffice Calc
    assertEquals(0.0689342404, rootSplit.getGainValue(), 0.00001);

    final NominalBinarySplitCandidate binaryRootSplit = (NominalBinarySplitCandidate) rootSplit;
    final TreeNodeNominalBinaryCondition[] rootConditions = binaryRootSplit.getChildConditions();
    assertEquals(2, rootConditions.length);
    final TreeNodeNominalBinaryCondition notInCondition = rootConditions[0];
    final TreeNodeNominalBinaryCondition inCondition = rootConditions[1];
    assertArrayEquals(new String[] { "R" }, notInCondition.getValues());
    assertArrayEquals(new String[] { "R" }, inCondition.getValues());
    assertEquals(SetLogic.IS_NOT_IN, notInCondition.getSetLogic());
    assertEquals(SetLogic.IS_IN, inCondition.getSetLogic());

    // --- first child (IS_NOT_IN "R"): another binary split is expected ---
    final BitSet inFirstChild = attribute.updateChildMemberships(notInCondition, rootMemberships);
    final DataMemberships firstChildMemberships = rootMemberships.createChildMemberships(inFirstChild);
    final ClassificationPriors firstChildPriors = target.getDistribution(firstChildMemberships, config);
    final SplitCandidate firstChildSplit =
        attribute.calcBestSplitClassification(firstChildMemberships, firstChildPriors, target, null);
    assertNotNull(firstChildSplit);
    assertThat(firstChildSplit, instanceOf(NominalBinarySplitCandidate.class));
    // reference gain was calculated by hand in LibreOffice Calc
    assertEquals(0.0086419753, firstChildSplit.getGainValue(), 0.00001);

    // --- second child (IS_IN "R"): no split candidate is expected ---
    final BitSet inSecondChild = attribute.updateChildMemberships(inCondition, rootMemberships);
    final DataMemberships secondChildMemberships = rootMemberships.createChildMemberships(inSecondChild);
    final ClassificationPriors secondChildPriors = target.getDistribution(secondChildMemberships, config);
    final SplitCandidate secondChildSplit =
        attribute.calcBestSplitClassification(secondChildMemberships, secondChildPriors, target, null);
    assertNull(secondChildSplit);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) BitSet(java.util.BitSet) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Example 17 with TreeNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling.

/**
 * Checks the XGBoost missing value handling for classification with multiway splits. Two scenarios are covered:
 * a training column without missing values and a training column that contains a missing value.
 *
 * @throws Exception if the test setup or the split calculation fails
 */
@Test
public void testCalcBestSplitClassificationMultiwayXGBoostMissingValueHandling() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setUseBinaryNominalSplits(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(config);
    final RandomData randomData = config.createRandomData();
    // the expected children are the same in both scenarios
    final String[] expectedValues = { "a", "b", "c" };

    // Scenario 1: the training data contains no missing values
    final String cleanAttributeCSV = "a, a, a, b, b, b, b, c, c";
    final String cleanTargetCSV = "A, B, B, C, C, C, B, A, B";
    final TreeNominalColumnData cleanColumn =
        generator.createNominalAttributeColumn(cleanAttributeCSV, "noMissings", 0);
    final TreeTargetNominalColumnData cleanTarget = TestDataGenerator.createNominalTargetColumn(cleanTargetCSV);
    final DataMemberships cleanMemberships = createMockDataMemberships(cleanTarget.getNrRows());
    final SplitCandidate cleanSplit = cleanColumn.calcBestSplitClassification(cleanMemberships,
        cleanTarget.getDistribution(cleanMemberships, config), cleanTarget, randomData);
    assertNotNull("There is a possible split.", cleanSplit);
    assertEquals("Incorrect gain.", 0.216, cleanSplit.getGainValue(), 1e-3);
    assertThat(cleanSplit, instanceOf(NominalMultiwaySplitCandidate.class));
    final NominalMultiwaySplitCandidate cleanMultiway = (NominalMultiwaySplitCandidate) cleanSplit;
    assertTrue("No missing values in the column.", cleanMultiway.getMissedRows().isEmpty());
    final TreeNodeNominalCondition[] cleanConditions = cleanMultiway.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, cleanConditions.length);
    for (int i = 0; i < expectedValues.length; i++) {
        assertEquals("Wrong value in child condition.", expectedValues[i], cleanConditions[i].getValue());
    }
    // only the child for "b" (index 1) may accept missing values
    for (int i = 0; i < expectedValues.length; i++) {
        if (i == 1) {
            assertTrue("Missing values should be sent to the majority child (i.e. b)",
                cleanConditions[i].acceptsMissings());
        } else {
            assertFalse("Missing values should be sent to the majority child (i.e. b)",
                cleanConditions[i].acceptsMissings());
        }
    }

    // Scenario 2: the training data contains a missing value ("?")
    final String gappyAttributeCSV = "a, a, a, b, b, b, b, c, c, ?";
    final String gappyTargetCSV = "A, B, B, C, C, C, B, A, B, C";
    final TreeNominalColumnData gappyColumn =
        generator.createNominalAttributeColumn(gappyAttributeCSV, "missings", 0);
    final TreeTargetNominalColumnData gappyTarget = TestDataGenerator.createNominalTargetColumn(gappyTargetCSV);
    final DataMemberships gappyMemberships = createMockDataMemberships(gappyTarget.getNrRows());
    final SplitCandidate gappySplit = gappyColumn.calcBestSplitClassification(gappyMemberships,
        gappyTarget.getDistribution(gappyMemberships, config), gappyTarget, randomData);
    assertNotNull("There is a possible split.", gappySplit);
    assertEquals("Incorrect gain.", 0.2467, gappySplit.getGainValue(), 1e-3);
    assertThat(gappySplit, instanceOf(NominalMultiwaySplitCandidate.class));
    final NominalMultiwaySplitCandidate gappyMultiway = (NominalMultiwaySplitCandidate) gappySplit;
    assertTrue("Split should handle missing values.", gappyMultiway.getMissedRows().isEmpty());
    final TreeNodeNominalCondition[] gappyConditions = gappyMultiway.getChildConditions();
    assertEquals("Wrong number of child conditions.", 3, gappyConditions.length);
    for (int i = 0; i < expectedValues.length; i++) {
        assertEquals("Wrong value in child condition.", expectedValues[i], gappyConditions[i].getValue());
    }
    // again only the child for "b" (index 1) may accept missing values
    for (int i = 0; i < expectedValues.length; i++) {
        if (i == 1) {
            assertTrue("Missing values should be sent to b", gappyConditions[i].acceptsMissings());
        } else {
            assertFalse("Missing values should be sent to b", gappyConditions[i].acceptsMissings());
        }
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) Test(org.junit.Test)

Example 18 with TreeNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationBinaryPCA.

/**
 * Verifies
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * using binary splits on data with more than two classes, in which case the PCA based algorithm is used.
 *
 * @throws Exception if the test setup or the split calculation fails
 */
@Test
public void testCalcBestSplitClassificationBinaryPCA() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> columns = createPCATestData(config);
    final TreeNominalColumnData attribute = columns.getFirst();
    final TreeTargetNominalColumnData target = columns.getSecond();
    final TreeData data = createTreeData(columns);
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());

    // every row takes part with weight 1
    final double[] unitWeights = new double[target.getNrRows()];
    Arrays.fill(unitWeights, 1.0);
    final IDataIndexManager indexManager = new DefaultDataIndexManager(data);
    final DataMemberships rootMemberships = new RootDataMemberships(unitWeights, data, indexManager);
    final ClassificationPriors rootPriors = target.getDistribution(unitWeights, config);

    final SplitCandidate bestSplit = attribute.calcBestSplitClassification(rootMemberships, rootPriors, target, null);
    assertNotNull(bestSplit);
    assertThat(bestSplit, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(bestSplit.canColumnBeSplitFurther());
    assertEquals(0.0659, bestSplit.getGainValue(), 0.0001);

    // the split is expected to separate value "E" from all other values
    final NominalBinarySplitCandidate binarySplit = (NominalBinarySplitCandidate) bestSplit;
    final TreeNodeNominalBinaryCondition[] conditions = binarySplit.getChildConditions();
    assertEquals(2, conditions.length);
    final TreeNodeNominalBinaryCondition notInCondition = conditions[0];
    final TreeNodeNominalBinaryCondition inCondition = conditions[1];
    assertArrayEquals(new String[] { "E" }, notInCondition.getValues());
    assertArrayEquals(new String[] { "E" }, inCondition.getValues());
    assertEquals(SetLogic.IS_NOT_IN, notInCondition.getSetLogic());
    assertEquals(SetLogic.IS_IN, inCondition.getSetLogic());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Example 19 with TreeNominalColumnData

use of org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData in project knime-core by knime.

the class TreeNominalColumnDataTest method testUpdateChildMemberships.

/**
 * Verifies
 * {@link TreeNominalColumnData#updateChildMemberships(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition, DataMemberships)}
 * for binary and multiway nominal conditions, each with and without acceptance of missing values.
 *
 * @throws Exception if the test setup or the membership update fails
 */
@Test
public void testUpdateChildMemberships() throws Exception {
    // regression vs. classification (and binary vs. multiway splits) makes no difference for this method
    final TreeEnsembleLearnerConfiguration config = createConfig(true);
    final TestDataGenerator generator = new TestDataGenerator(config);
    // row layout: 0-3 are "A", 4-6 are "B", 7-9 are "C", 10-11 are "?" (presumably the missing-value marker)
    final String dataCSV = "A, A, A, A, B, B, B, C, C, C, ?, ?";
    final TreeNominalColumnData column = generator.createNominalAttributeColumn(dataCSV, "test-col", 0);
    final int rowCount = 12;
    final int[] rowIndices = new int[rowCount];
    final double[] rowWeights = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        rowIndices[row] = row;
        rowWeights[row] = 1.0;
    }
    final DataMemberships memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);

    // --- binary conditions; NOTE(review): the BigInteger looks like a bit mask over value indices - confirm ---

    // mask 2, flags (true, false): rows 4-6 expected
    TreeNodeNominalBinaryCondition binaryCondition =
        new TreeNodeNominalBinaryCondition(column.getMetaData(), BigInteger.valueOf(2), true, false);
    final BitSet onlyB = new BitSet(rowCount);
    onlyB.set(4, 7);
    assertEquals("The produced BitSet is incorrect.", onlyB,
        column.updateChildMemberships(binaryCondition, memberships));

    // mask 2, flags (true, true): rows 4-6 and 10-11 expected
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), BigInteger.valueOf(2), true, true);
    final BitSet bAndMissing = new BitSet(rowCount);
    bAndMissing.set(4, 7);
    bAndMissing.set(10, 12);
    assertEquals("The produced BitSet is incorrect.", bAndMissing,
        column.updateChildMemberships(binaryCondition, memberships));

    // mask 2, flags (false, false): rows 0-3 and 7-9 expected
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), BigInteger.valueOf(2), false, false);
    final BitSet notB = new BitSet(rowCount);
    notB.set(0, 4);
    notB.set(7, 10);
    assertEquals("The produced BitSet is incorrect.", notB,
        column.updateChildMemberships(binaryCondition, memberships));

    // mask 2, flags (false, true): rows 0-3 and 7-11 expected
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), BigInteger.valueOf(2), false, true);
    final BitSet notBOrMissing = new BitSet(rowCount);
    notBOrMissing.set(0, 4);
    notBOrMissing.set(7, 12);
    assertEquals("The produced BitSet is incorrect.", notBOrMissing,
        column.updateChildMemberships(binaryCondition, memberships));

    // mask 5, flags (true, false): rows 0-3 and 7-9 expected
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), BigInteger.valueOf(5), true, false);
    final BitSet aAndC = new BitSet(rowCount);
    aAndC.set(0, 4);
    aAndC.set(7, 10);
    assertEquals("The produced BitSet is incorrect.", aAndC,
        column.updateChildMemberships(binaryCondition, memberships));

    // mask 5, flags (true, true): rows 0-3 and 7-11 expected
    binaryCondition = new TreeNodeNominalBinaryCondition(column.getMetaData(), BigInteger.valueOf(5), true, true);
    final BitSet aCAndMissing = new BitSet(rowCount);
    aCAndMissing.set(0, 4);
    aCAndMissing.set(7, 12);
    assertEquals("The produced BitSet is incorrect.", aCAndMissing,
        column.updateChildMemberships(binaryCondition, memberships));

    // --- multiway conditions ---

    // index 0, flag false: rows 0-3 expected
    TreeNodeNominalCondition multiwayCondition = new TreeNodeNominalCondition(column.getMetaData(), 0, false);
    final BitSet onlyA = new BitSet(rowCount);
    onlyA.set(0, 4);
    assertEquals("The produced BitSet is incorrect.", onlyA,
        column.updateChildMemberships(multiwayCondition, memberships));

    // index 0, flag true: rows 0-3 and 10-11 expected
    multiwayCondition = new TreeNodeNominalCondition(column.getMetaData(), 0, true);
    final BitSet aAndMissing = new BitSet(rowCount);
    aAndMissing.set(0, 4);
    aAndMissing.set(10, 12);
    assertEquals("The produced BitSet is incorrect.", aAndMissing,
        column.updateChildMemberships(multiwayCondition, memberships));

    // index 2, flag false: rows 7-9 expected
    multiwayCondition = new TreeNodeNominalCondition(column.getMetaData(), 2, false);
    final BitSet onlyC = new BitSet(rowCount);
    onlyC.set(7, 10);
    assertEquals("The produced BitSet is incorrect.", onlyC,
        column.updateChildMemberships(multiwayCondition, memberships));

    // index 2, flag true: rows 7-11 expected
    multiwayCondition = new TreeNodeNominalCondition(column.getMetaData(), 2, true);
    final BitSet cAndMissing = new BitSet(rowCount);
    cAndMissing.set(7, 12);
    assertEquals("The produced BitSet is incorrect.", cAndMissing,
        column.updateChildMemberships(multiwayCondition, memberships));
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) BitSet(java.util.BitSet) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) Test(org.junit.Test)

Aggregations

- TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 19
- Test (org.junit.Test): 18
- DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships): 15
- RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 15
- NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate): 13
- NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate): 13
- SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate): 13
- TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition): 10
- DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager): 9
- IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 8
- RandomData (org.apache.commons.math.random.RandomData): 6
- TreeNodeNominalCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition): 6
- TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator): 4
- TreeNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData): 4
- BitSet (java.util.BitSet): 3
- PMMLCompoundPredicate (org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate): 2
- PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate): 2
- PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate): 2
- PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord): 2
- BigInteger (java.math.BigInteger): 1