Search in sources :

Example 11 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

In the class TreeNumericColumnDataTest, method testXGBoostMissingValueHandling.

/**
 * Verifies the child conditions produced by a classification split when the missing value handling is set to
 * {@link MissingValueHandling#XGBoost}.
 * <p>
 * The attribute column holds the values 1,2,2,3,4,5,6,7 followed by two NaN entries. Two target columns are checked
 * against it: in the first one the NaN rows carry class "A" (like the low attribute values) and the left child is
 * expected to accept missings; in the second one they carry class "B" (like the high attribute values) and the right
 * child is expected to accept missings. In both cases the expected split point is 3.5 and the expected gain is 0.48.
 *
 * @throws Exception propagated from the configuration, data generation or split calculation
 */
@Test
public void testXGBoostMissingValueHandling() throws Exception {
    final double gainExpected = 0.48;
    final double splitExpected = 3.5;
    final double tolerance = 1e-8;
    final int rowCount = 10;

    TreeEnsembleLearnerConfiguration learnerConfig = createConfig();
    learnerConfig.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final RandomData randomData = learnerConfig.createRandomData();

    // identity row indices 0..9, every row with weight 1
    final int[] rowIndices = new int[rowCount];
    for (int row = 0; row < rowCount; row++) {
        rowIndices[row] = row;
    }
    final double[] rowWeights = new double[rowCount];
    Arrays.fill(rowWeights, 1.0);
    final MockDataColMem memberships = new MockDataColMem(rowIndices, rowIndices, rowWeights);

    final TreeNumericColumnData column =
        generator.createNumericAttributeColumn("1,2,2,3,4,5,6,7,NaN,NaN", "testCol", 0);

    // scenario 1: the NaN rows share class "A" with the low values -> missings expected on the left
    final TreeTargetNominalColumnData targetMissingsA =
        TestDataGenerator.createNominalTargetColumn("A,A,A,A,B,B,B,B,A,A");
    final SplitCandidate splitMissingsLeft = column.calcBestSplitClassification(memberships,
        targetMissingsA.getDistribution(rowWeights, learnerConfig), targetMissingsA, randomData);
    assertEquals("Wrong gain.", gainExpected, splitMissingsLeft.getGainValue(), tolerance);
    final TreeNodeCondition[] conditionsMissingsLeft = splitMissingsLeft.getChildConditions();
    final TreeNodeNumericCondition leftWithMissings = (TreeNodeNumericCondition) conditionsMissingsLeft[0];
    assertEquals("Wrong split point.", splitExpected, leftWithMissings.getSplitValue(), tolerance);
    assertTrue("Missings were not sent in the correct direction.", leftWithMissings.acceptsMissings());
    final TreeNodeNumericCondition rightWithoutMissings = (TreeNodeNumericCondition) conditionsMissingsLeft[1];
    assertEquals("Wrong split point.", splitExpected, rightWithoutMissings.getSplitValue(), tolerance);
    assertFalse("Missings were not sent in the correct direction.", rightWithoutMissings.acceptsMissings());

    // scenario 2: the NaN rows share class "B" with the high values -> missings expected on the right
    final TreeTargetNominalColumnData targetMissingsB =
        TestDataGenerator.createNominalTargetColumn("A,A,A,A,B,B,B,B,B,B");
    final SplitCandidate splitMissingsRight = column.calcBestSplitClassification(memberships,
        targetMissingsB.getDistribution(rowWeights, learnerConfig), targetMissingsB, randomData);
    assertEquals("Wrong gain.", gainExpected, splitMissingsRight.getGainValue(), tolerance);
    final TreeNodeCondition[] conditionsMissingsRight = splitMissingsRight.getChildConditions();
    final TreeNodeNumericCondition leftWithoutMissings = (TreeNodeNumericCondition) conditionsMissingsRight[0];
    assertEquals("Wrong split point.", splitExpected, leftWithoutMissings.getSplitValue(), tolerance);
    assertFalse("Missings were not sent in the correct direction.", leftWithoutMissings.acceptsMissings());
    final TreeNodeNumericCondition rightWithMissings = (TreeNodeNumericCondition) conditionsMissingsRight[1];
    assertEquals("Wrong split point.", splitExpected, rightWithMissings.getSplitValue(), tolerance);
    assertTrue("Missings were not sent in the correct direction.", rightWithMissings.acceptsMissings());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) Test(org.junit.Test)

Example 12 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

In the class TreeNumericColumnDataTest, method testCalcBestSplitRegression.

/**
 * Checks the best regression split of a single numeric attribute on two consecutive levels: first on the root
 * memberships (all ten rows), then on the memberships of the right child of the first split.
 *
 * @throws InvalidSettingsException declared because the configuration and data setup used here may throw it
 */
@Test
public void testCalcBestSplitRegression() throws InvalidSettingsException {
    final String attributeValues = "1,2,3,4,5,6,7,8,9,10";
    final String targetValues = "1,5,4,4.3,6.5,6.5,4,3,3,4";

    // single model, no row or column sampling
    TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(true);
    learnerConfig.setNrModels(1);
    learnerConfig.setDataSelectionWithReplacement(false);
    learnerConfig.setUseDifferentAttributesAtEachNode(false);
    learnerConfig.setDataFractionPerTree(1.0);
    learnerConfig.setColumnSamplingMode(ColumnSamplingMode.None);
    TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    RandomData randomData = learnerConfig.createRandomData();

    TreeTargetNumericColumnData targetColumn = TestDataGenerator.createNumericTargetColumn(targetValues);
    TreeNumericColumnData attributeColumn = generator.createNumericAttributeColumn(attributeValues, "test-col", 0);
    TreeData treeData =
        new TreeData(new TreeAttributeColumnData[] { attributeColumn }, targetColumn, TreeType.Ordinary);
    double[] rowWeights = new double[10];
    Arrays.fill(rowWeights, 1.0);
    DataMemberships rootMemberships =
        new RootDataMemberships(rowWeights, treeData, new DefaultDataIndexManager(treeData));

    // --- first level: split on the root ---
    SplitCandidate rootSplit = attributeColumn.calcBestSplitRegression(rootMemberships,
        targetColumn.getPriors(rootMemberships, learnerConfig), targetColumn, randomData);
    // calculated via OpenOffice calc
    assertEquals(10.885444, rootSplit.getGainValue(), 1e-5);
    TreeNodeCondition[] rootConditions = rootSplit.getChildConditions();
    assertEquals(2, rootConditions.length);
    for (TreeNodeCondition condition : rootConditions) {
        assertThat(condition, instanceOf(TreeNodeNumericCondition.class));
        TreeNodeNumericCondition numericCondition = (TreeNodeNumericCondition) condition;
        assertEquals(1.5, numericCondition.getSplitValue(), 0.0);
    }

    // --- second level: the left child contains only one row, therefore only the right child is examined ---
    BitSet expectedRightChildRows = new BitSet(10);
    expectedRightChildRows.set(1, 10);
    BitSet rightChildRows = attributeColumn.updateChildMemberships(rootConditions[1], rootMemberships);
    assertEquals(expectedRightChildRows, rightChildRows);
    DataMemberships rightChildMemberships = rootMemberships.createChildMemberships(rightChildRows);
    SplitCandidate childSplit = attributeColumn.calcBestSplitRegression(rightChildMemberships,
        targetColumn.getPriors(rightChildMemberships, learnerConfig), targetColumn, randomData);
    assertEquals(6.883555, childSplit.getGainValue(), 1e-5);
    for (TreeNodeCondition condition : childSplit.getChildConditions()) {
        assertThat(condition, instanceOf(TreeNodeNumericCondition.class));
        TreeNodeNumericCondition numericCondition = (TreeNodeNumericCondition) condition;
        assertEquals(6.5, numericCondition.getSplitValue(), 0.0);
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) Test(org.junit.Test)

Example 13 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

In the class TreeNodeClassification, method createDecisionTreeNode.

/**
 * Converts this node, and recursively all of its children, into the {@link DecisionTreeNode} representation.
 * <p>
 * Every created node carries the majority class and a score distribution that maps each target class value to the
 * corresponding entry of this node's target distribution. A node without children becomes a
 * {@link DecisionTreeNodeLeaf}; any other node becomes a {@link DecisionTreeNodeSplitPMML} whose child predicates
 * are derived from the children's conditions.
 *
 * @param idGenerator source of node ids; incremented once per created node, the parent before its children
 * @param metaData used to look up the name of the split attribute
 * @return a DecisionTreeNode
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell majorityClass = new StringCell(getMajorityClassName());
    final float[] distribution = getTargetDistribution();
    // sized so that all entries fit without a rehash at the default load factor of 0.75
    final int initialCapacity = (int) (distribution.length / 0.75 + 1.0);
    final LinkedHashMap<DataCell, Double> scoreByClass = new LinkedHashMap<DataCell, Double>(initialCapacity);
    final NominalValueRepresentation[] classValues = getTargetMetaData().getValues();
    for (int classIdx = 0; classIdx < distribution.length; classIdx++) {
        final String className = classValues[classIdx].getNominalValue();
        // widen float -> double so the value can be boxed as Double
        final double score = distribution[classIdx];
        scoreByClass.put(new StringCell(className), score);
    }

    final int childCount = getNrChildren();
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(idGenerator.inc(), majorityClass, scoreByClass);
    }

    // the id of this node is drawn before any child draws its own
    final int nodeId = idGenerator.inc();
    final DecisionTreeNode[] children = new DecisionTreeNode[childCount];
    final int splitIndex = getSplitAttributeIndex();
    assert splitIndex >= 0 : "non-leaf node has no split";
    final String splitAttributeName = metaData.getAttributeMetaData(splitIndex).getAttributeName();
    final PMMLPredicate[] predicates = new PMMLPredicate[childCount];
    for (int childIdx = 0; childIdx < childCount; childIdx++) {
        final TreeNodeClassification child = getChild(childIdx);
        final TreeNodeCondition childCondition = child.getCondition();
        predicates[childIdx] = childCondition.toPMMLPredicate();
        children[childIdx] = child.createDecisionTreeNode(idGenerator, metaData);
    }
    return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, scoreByClass, splitAttributeName, predicates,
        children);
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) LinkedHashMap(java.util.LinkedHashMap) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) StringCell(org.knime.core.data.def.StringCell) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) DataCell(org.knime.core.data.DataCell) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 14 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

In the class AbstractTreeModelExporter, method addTreeNode.

/**
 * Fills the given PMML node from the given tree node and then recurses into the tree node's children, adding one new
 * PMML child node per child.
 * <p>
 * The PMML node id is the current value of the running node index, which is advanced by one for every node visited.
 *
 * @param pmmlNode the PMML node to fill
 * @param node the tree node whose content, condition and children are exported
 */
private void addTreeNode(final Node pmmlNode, final T node) {
    final int nodeId = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(nodeId));
    addNodeContent(nodeId, pmmlNode, node);
    final TreeNodeCondition nodeCondition = node.getCondition();
    m_conditionExporter.exportCondition(nodeCondition, pmmlNode);
    for (int childIdx = 0; childIdx < node.getNrChildren(); childIdx++) {
        final Node childPmmlNode = pmmlNode.addNewNode();
        // the cast to T is unchecked, hence the warning is suppressed for this declaration only
        @SuppressWarnings("unchecked")
        final T childNode = (T) node.getChild(childIdx);
        addTreeNode(childPmmlNode, childNode);
    }
}
Also used : TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 15 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

In the class ConditionExporter, method exportCondition.

/**
 * Writes the given tree node condition into the given PMML node, dispatching on the concrete condition type.
 * <ul>
 * <li>{@link TreeNodeTrueCondition}: a True predicate is added.</li>
 * <li>{@link TreeNodeColumnCondition}: delegated to {@code exportColumnCondition}.</li>
 * <li>{@link AbstractTreeNodeSurrogateCondition}: a new compound predicate is added and filled from the condition's
 * PMML predicate.</li>
 * </ul>
 *
 * @param condition the condition to export
 * @param pmmlNode the PMML node receiving the predicate
 * @throws IllegalStateException if the condition is of none of the types listed above
 */
void exportCondition(final TreeNodeCondition condition, final Node pmmlNode) {
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
        return;
    }
    if (condition instanceof TreeNodeColumnCondition) {
        exportColumnCondition((TreeNodeColumnCondition) condition, pmmlNode);
        return;
    }
    if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogate = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogate.toPMMLPredicate());
        return;
    }
    final String conditionType = condition.getClass().getSimpleName();
    throw new IllegalStateException("Unsupported condition (not implemented): " + conditionType);
}
Also used : TreeNodeColumnCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition) TreeNodeTrueCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeTrueCondition) AbstractTreeNodeSurrogateCondition(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNodeSurrogateCondition)

Aggregations

TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)13 BitSet (java.util.BitSet)9 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)5 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)5 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)5 ArrayList (java.util.ArrayList)4 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)4 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)4 RandomData (org.apache.commons.math.random.RandomData)3 Test (org.junit.Test)3 NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation)3 TreeNodeColumnCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition)3 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)3 ClassificationPriors (org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)2 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)2 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)2 ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)2 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)2 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)2 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)2