Search in sources :

Example 1 with RootDataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships in project knime-core by knime.

the class TreeLearnerRegression method learnSingleTree.

/**
 * {@inheritDoc}
 *
 * <p>Wraps the current row sample in {@link RootDataMemberships}, derives the regression priors
 * of the target column from them and builds the root node via {@code buildTreeNode}. If the
 * configuration is a {@link GradientBoostingLearnerConfiguration}, {@code m_leafs} is replaced by
 * a fresh list before the tree is built and is passed on to the returned model.
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeTargetNumericColumnData target = getTargetData();
    final TreeData treeData = getData();
    final RowSample rows = getRowSampling();
    final TreeEnsembleLearnerConfiguration configuration = getConfig();
    final IDataIndexManager indices = getIndexManager();

    final DataMemberships rootMemberships = new RootDataMemberships(rows, treeData, indices);
    final RegressionPriors rootPriors = target.getPriors(rootMemberships, configuration);
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    final boolean gradientBoosting = configuration instanceof GradientBoostingLearnerConfiguration;
    if (gradientBoosting) {
        // start each tree with an empty leaf collection
        m_leafs = new ArrayList<TreeNodeRegression>();
    }

    final ColumnSample rootColumns =
        getColSamplingStrategy().getColumnSampleForTreeNode(TreeNodeSignature.ROOT_SIGNATURE);
    final TreeNodeRegression root = buildTreeNode(exec, 0, rootMemberships, rootColumns,
        getSignatureFactory().getRootSignature(), rootPriors, forbiddenColumns);

    // no column may remain marked as forbidden once the root has been built
    assert forbiddenColumns.isEmpty();
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);

    return gradientBoosting ? new TreeModelRegression(root, m_leafs) : new TreeModelRegression(root);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample)

Example 2 with RootDataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitClassificationMultiWay.

/**
 * Exercises
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * with binary nominal splits switched off, so that a multiway split candidate is expected.
 *
 * @throws Exception if anything in the fixture setup or the split search throws
 */
@Test
public void testCalcBestSplitClassificationMultiWay() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = createConfig(false);
    cfg.setUseBinaryNominalSplits(false);

    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> tennis = tennisData(cfg);
    final TreeNominalColumnData attribute = tennis.getFirst();
    final TreeTargetNominalColumnData target = tennis.getSecond();
    final TreeData tree = createTreeData(tennis);
    assertEquals(SplitCriterion.Gini, cfg.getSplitCriterion());

    // every row participates with weight 1
    final double[] weights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(weights, 1.0);
    final IDataIndexManager indices = new DefaultDataIndexManager(tree);
    final DataMemberships memberships = new RootDataMemberships(weights, tree, indices);
    final ClassificationPriors classPriors = target.getDistribution(weights, cfg);

    final SplitCandidate candidate = attribute.calcBestSplitClassification(memberships, classPriors, target, null);
    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalMultiwaySplitCandidate.class));
    assertFalse(candidate.canColumnBeSplitFurther());
    // expected gain was computed by hand in LibreOffice Calc
    assertEquals(0.0744897959, candidate.getGainValue(), 0.00001);

    final TreeNodeNominalCondition[] conditions = ((NominalMultiwaySplitCandidate) candidate).getChildConditions();
    final String[] expectedValues = {"S", "O", "R"};
    assertEquals(expectedValues.length, conditions.length);
    for (int i = 0; i < expectedValues.length; i++) {
        assertEquals(expectedValues[i], conditions[i].getValue());
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) Test(org.junit.Test)

Example 3 with RootDataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships in project knime-core by knime.

the class TreeNominalColumnDataTest method testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue.

/**
 * Tests the XGBoost missing value handling in case of a two class problem. <br>
 * Currently not executed (the {@code @Test} annotation is commented out) because missing value
 * handling will probably be implemented differently.
 *
 * <p>The first half checks the split found on data without missing values; the second half
 * repeats the split search on a column that contains missing values ({@code "?"}).
 *
 * @throws Exception if anything in the fixture setup or the split search throws
 */
// @Test
public void testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    // check correct behavior if no missing values are encountered during split search
    Pair<TreeNominalColumnData, TreeTargetNominalColumnData> twoClassTennisData = twoClassTennisData(config);
    TreeData treeData = dataGen.createTreeData(twoClassTennisData.getSecond(), twoClassTennisData.getFirst());
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    double[] rowWeights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(rowWeights, 1.0);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    TreeTargetNominalColumnData targetData = twoClassTennisData.getSecond();
    TreeNominalColumnData columnData = twoClassTennisData.getFirst();
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);
    RandomData rd = TestDataGenerator.createRandomData();
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NominalBinarySplitCandidate.class));
    NominalBinarySplitCandidate binarySplitCandidate = (NominalBinarySplitCandidate) splitCandidate;
    TreeNodeNominalBinaryCondition[] childConditions = binarySplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertArrayEquals(new String[] { "R" }, childConditions[0].getValues());
    assertArrayEquals(new String[] { "R" }, childConditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, childConditions[1].getSetLogic());
    // check if missing values go left
    assertTrue(childConditions[0].acceptsMissings());
    assertFalse(childConditions[1].acceptsMissings());
    // check correct behavior if missing values are encountered during split search
    String dataContainingMissingsCSV = "S,?,O,R,S,R,S,O,O,?";
    columnData = dataGen.createNominalAttributeColumn(dataContainingMissingsCSV, "column containing missing values", 0);
    treeData = dataGen.createTreeData(targetData, columnData);
    indexManager = new DefaultDataIndexManager(treeData);
    dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, null);
    assertNotNull(splitCandidate);
    // verify the type before casting so a wrong candidate type fails as an assertion, not a ClassCastException
    assertThat(splitCandidate, instanceOf(NominalBinarySplitCandidate.class));
    binarySplitCandidate = (NominalBinarySplitCandidate) splitCandidate;
    assertEquals("Gain was not as expected", 0.08, binarySplitCandidate.getGainValue(), 1e-8);
    childConditions = binarySplitCandidate.getChildConditions();
    String[] conditionValues = new String[] { "O", "?" };
    assertArrayEquals("Values in nominal condition did not match", conditionValues, childConditions[0].getValues());
    assertArrayEquals("Values in nominal condition did not match", conditionValues, childConditions[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, childConditions[1].getSetLogic());
    // here the missing values are expected in the second (IS_IN) child
    assertFalse("Missing values are not sent to the correct child.", childConditions[0].acceptsMissings());
    assertTrue("Missing values are not sent to the correct child.", childConditions[1].acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)

Example 4 with RootDataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships in project knime-core by knime.

the class TreeLearnerClassification method learnSingleTreeRecursive.

/**
 * Builds a single classification tree: wraps the current row sample in
 * {@link RootDataMemberships}, computes the class distribution of the target column on them,
 * builds the root node via {@code buildTreeNode} and attaches the always-true condition to it.
 *
 * @param exec execution monitor handed through to {@code buildTreeNode}
 * @param rd random data source (not used directly in this method)
 * @return the model wrapping the learned root node
 * @throws CanceledExecutionException declared for cancellation; presumably raised by
 *             {@code buildTreeNode} via {@code exec} -- confirm against its implementation
 */
private TreeModelClassification learnSingleTreeRecursive(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeData treeData = getData();
    final RowSample rows = getRowSampling();
    final TreeEnsembleLearnerConfiguration configuration = getConfig();
    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData) treeData.getTargetColumn();

    final DataMemberships rootMemberships = new RootDataMemberships(rows, treeData, getIndexManager());
    final ClassificationPriors rootPriors = target.getDistribution(rootMemberships, configuration);
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    final TreeNodeSignature signature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumns = getColSamplingStrategy().getColumnSampleForTreeNode(signature);
    final TreeNodeClassification root =
        buildTreeNode(exec, 0, rootMemberships, rootColumns, signature, rootPriors, forbiddenColumns);

    // no column may remain marked as forbidden once the root has been built
    assert forbiddenColumns.isEmpty();
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelClassification(root);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeClassification(org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) BitSet(java.util.BitSet) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) ClassificationPriors(org.knime.base.node.mine.treeensemble2.data.ClassificationPriors) TreeModelClassification(org.knime.base.node.mine.treeensemble2.model.TreeModelClassification)

Example 5 with RootDataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships in project knime-core by knime.

the class TreeNumericColumnDataTest method testCalcBestSplitClassification.

/**
 * Checks the best numeric split for a small ten-row example with a nominal target.
 *
 * <p>Expected: the root splits at {@code (95 + 100) / 2} with gain 0.12; the "less than or equal"
 * child splits again at {@code (75 + 85) / 2} with gain 0.5; the "larger than" child yields no
 * further split candidate.
 *
 * @throws Exception if anything in the fixture setup or the split search throws
 */
@Test
public void testCalcBestSplitClassification() throws Exception {
    TreeEnsembleLearnerConfiguration config = createConfig();
    /* data from J. Fuernkranz, Uni Darmstadt:
     * http://www.ke.tu-darmstadt.de/lehre/archiv/ws0809/mldm/dt.pdf */
    final double[] data = asDataArray("60,70,75,85, 90, 95, 100,120,125,220");
    final String[] target = asStringArray("No,No,No,Yes,Yes,Yes,No, No, No, No");
    Pair<TreeOrdinaryNumericColumnData, TreeTargetNominalColumnData> exampleData = exampleData(config, data, target);
    RandomData rd = config.createRandomData();
    TreeNumericColumnData columnData = exampleData.getFirst();
    TreeTargetNominalColumnData targetData = exampleData.getSecond();
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());
    double[] rowWeights = new double[data.length];
    Arrays.fill(rowWeights, 1.0);
    TreeData treeData = createTreeDataClassification(exampleData);
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);

    // split search on the root
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NumericSplitCandidate.class));
    assertTrue(splitCandidate.canColumnBeSplitFurther());
    // libre office calc: 0.42 - 0.300
    assertEquals(0.12, splitCandidate.getGainValue(), 0.00001);
    NumericSplitCandidate numSplitCandidate = (NumericSplitCandidate) splitCandidate;
    TreeNodeNumericCondition[] childConditions = numSplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertEquals((95.0 + 100.0) / 2.0, childConditions[0].getSplitValue(), 0.0);
    assertEquals((95.0 + 100.0) / 2.0, childConditions[1].getSplitValue(), 0.0);
    assertEquals(NumericOperator.LessThanOrEqual, childConditions[0].getNumericOperator());
    assertEquals(NumericOperator.LargerThan, childConditions[1].getNumericOperator());

    // split search on the first child (rows matching the "less than or equal" condition)
    BitSet inChild = columnData.updateChildMemberships(childConditions[0], dataMemberships);
    DataMemberships childMemberships = dataMemberships.createChildMemberships(inChild);
    ClassificationPriors childTargetPriors = targetData.getDistribution(childMemberships, config);
    SplitCandidate splitCandidateChild = columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNotNull(splitCandidateChild);
    assertThat(splitCandidateChild, instanceOf(NumericSplitCandidate.class));
    // manually via libre office calc
    assertEquals(0.5, splitCandidateChild.getGainValue(), 0.00001);
    TreeNodeNumericCondition[] childConditions2 = ((NumericSplitCandidate) splitCandidateChild).getChildConditions();
    assertEquals(2, childConditions2.length);
    assertEquals((75.0 + 85.0) / 2.0, childConditions2[0].getSplitValue(), 0.0);

    // split search on the second child (rows matching the "larger than" condition): no candidate expected
    inChild = columnData.updateChildMemberships(childConditions[1], dataMemberships);
    childMemberships = dataMemberships.createChildMemberships(inChild);
    childTargetPriors = targetData.getDistribution(childMemberships, config);
    splitCandidateChild = columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNull(splitCandidateChild);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) Test(org.junit.Test)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)18 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)16 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)15 Test (org.junit.Test)14 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)14 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)12 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)12 BitSet (java.util.BitSet)8 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)7 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)7 RandomData (org.apache.commons.math.random.RandomData)6 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)5 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)5 TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition)5 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)5 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)4 RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample)4 TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator)2 TreeNodeNominalCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition)2 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)2