Search in sources :

Example 6 with PMMLSimpleSetPredicate

use of org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate in project knime-core by knime.

the class TreeModelPMMLTranslator method setValuesFromPMMLCompoundPredicate.

/**
 * Copies the boolean operator and every child predicate of a KNIME compound
 * predicate into a PMML {@code CompoundPredicate} element. Nested compound
 * predicates are handled by recursion.
 *
 * @param to the PMML element that receives the operator and the children
 * @param from the KNIME compound predicate that is read
 * @throws IllegalStateException if the boolean operator or the type of a
 *             child predicate is not one of the cases handled here
 */
private static void setValuesFromPMMLCompoundPredicate(final CompoundPredicate to, final PMMLCompoundPredicate from) {
    // translate the operator first; an unknown operator aborts before any child is added
    final PMMLBooleanOperator operator = from.getBooleanOperator();
    switch(operator) {
        case XOR:
            to.setBooleanOperator(CompoundPredicate.BooleanOperator.XOR);
            break;
        case SURROGATE:
            to.setBooleanOperator(CompoundPredicate.BooleanOperator.SURROGATE);
            break;
        case OR:
            to.setBooleanOperator(CompoundPredicate.BooleanOperator.OR);
            break;
        case AND:
            to.setBooleanOperator(CompoundPredicate.BooleanOperator.AND);
            break;
        default:
            throw new IllegalStateException("Unknown boolean predicate \"" + operator + "\".");
    }
    // append one PMML child per KNIME child, in list order
    for (final PMMLPredicate child : from.getPredicates()) {
        if (child instanceof PMMLSimplePredicate) {
            setValuesFromPMMLSimplePredicate(to.addNewSimplePredicate(), (PMMLSimplePredicate) child);
            continue;
        }
        if (child instanceof PMMLSimpleSetPredicate) {
            setValuesFromPMMLSimpleSetPredicate(to.addNewSimpleSetPredicate(), (PMMLSimpleSetPredicate) child);
            continue;
        }
        if (child instanceof PMMLTruePredicate) {
            to.addNewTrue();
            continue;
        }
        if (child instanceof PMMLFalsePredicate) {
            to.addNewFalse();
            continue;
        }
        if (!(child instanceof PMMLCompoundPredicate)) {
            throw new IllegalStateException("Unknown predicate type \"" + child + "\".");
        }
        // nested compound predicate: create the PMML element, then fill it recursively
        final CompoundPredicate nestedTarget = to.addNewCompoundPredicate();
        setValuesFromPMMLCompoundPredicate(nestedTarget, (PMMLCompoundPredicate) child);
    }
}
Also used : PMMLTruePredicate(org.knime.base.node.mine.decisiontree2.PMMLTruePredicate) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) PMMLFalsePredicate(org.knime.base.node.mine.decisiontree2.PMMLFalsePredicate) PMMLBooleanOperator(org.knime.base.node.mine.decisiontree2.PMMLBooleanOperator) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate)

Example 7 with PMMLSimpleSetPredicate

use of org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate in project knime-core by knime.

the class TreeNodeNominalBinaryConditionTest method testToPMMLPredicate.

/**
 * Checks {@link TreeNodeNominalBinaryCondition#toPMMLPredicate()} for two
 * conditions built on the same nominal test column: the first is expected
 * to yield a simple set predicate (IS_IN with value "A"), the second a
 * compound OR of a set predicate (IS_NOT_IN with value "B") and an
 * IS_MISSING simple predicate.
 *
 * @throws Exception
 */
@Test
public void testToPMMLPredicate() throws Exception {
    final TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator generator = new TestDataGenerator(learnerConfig);
    final TreeNominalColumnData col = generator.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);

    // first condition: expected to translate into a single set predicate
    final TreeNodeNominalBinaryCondition firstCondition = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(1), true, false);
    final PMMLPredicate firstPredicate = firstCondition.toPMMLPredicate();
    assertThat(firstPredicate, instanceOf(PMMLSimpleSetPredicate.class));
    final PMMLSimpleSetPredicate firstSet = (PMMLSimpleSetPredicate) firstPredicate;
    assertEquals("Wrong attribute", col.getMetaData().getAttributeName(), firstSet.getSplitAttribute());
    assertEquals("Wrong set predicate", PMMLSetOperator.IS_IN, firstSet.getSetOperator());
    assertArrayEquals("Wrong values", new String[] { "A" }, firstSet.getValues().toArray(new String[1]));

    // second condition: expected to translate into a compound predicate
    final TreeNodeNominalBinaryCondition secondCondition = new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), false, true);
    final PMMLPredicate secondPredicate = secondCondition.toPMMLPredicate();
    assertEquals("Wrong attribute", col.getMetaData().getAttributeName(), secondPredicate.getSplitAttribute());
    assertThat(secondPredicate, instanceOf(PMMLCompoundPredicate.class));
    final PMMLCompoundPredicate compound = (PMMLCompoundPredicate) secondPredicate;
    assertEquals("Wrong boolean operator", PMMLBooleanOperator.OR, compound.getBooleanOperator());
    final LinkedList<PMMLPredicate> children = compound.getPredicates();
    assertEquals("Number of predicates did not match.", 2, children.size());

    // child 0 of the compound: the set predicate
    assertThat(children.get(0), instanceOf(PMMLSimpleSetPredicate.class));
    final PMMLSimpleSetPredicate nestedSet = (PMMLSimpleSetPredicate) children.get(0);
    assertEquals("Wrong attribute", col.getMetaData().getAttributeName(), nestedSet.getSplitAttribute());
    assertEquals("Wrong set predicate", PMMLSetOperator.IS_NOT_IN, nestedSet.getSetOperator());
    assertArrayEquals("Wrong values", new String[] { "B" }, nestedSet.getValues().toArray(new String[1]));

    // child 1 of the compound: the missing-value check
    assertThat(children.get(1), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate missingCheck = (PMMLSimplePredicate) children.get(1);
    assertEquals("Should be isMissing", PMMLOperator.IS_MISSING, missingCheck.getOperator());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) TreeNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) Test(org.junit.Test)

Example 8 with PMMLSimpleSetPredicate

use of org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate in project knime-core by knime.

the class PMMLDecisionTreeTranslator method addTreeNode.

/**
 * A recursive function which converts each KNIME Tree node to a
 * corresponding PMML element. For every node it writes id, score, the
 * optional record count and default child, the predicate derived from the
 * node's parent, the score distribution, and finally recurses into the
 * children.
 *
 * @param pmmlNode the desired PMML element
 * @param node A KNIME DecisionTree node
 * @param mapper supplies the field name written to PMML for a split
 *            attribute (via {@code getDerivedFieldName})
 * @throws IllegalArgumentException if the parent's node type is not
 *             supported, or if {@code node} is neither the left nor the
 *             right child of a binary nominal split
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // the record count is only written when the node reports a positive one
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        if (defaultChild > -1) {
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }
    // adding the predicate, which is derived from the parent's split
    DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // When the parent is null, it is the root Node.
        // For root node, the predicate is always True.
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case
        DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        int childIndex = splitNode.getIndex(node);
        if (childIndex == 0) {
            SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (childIndex == 1) {
            pmmlNode.addNewTrue();
        }
        // NOTE(review): any other index leaves the node without a predicate;
        // kept as in the original - confirm that cannot happen.
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        DataCell[] splitValues = splitNode.getSplitValues();
        // pick the value indices that belong to this node's side of the split
        int partition = splitNode.getIndex(node);
        final List<Integer> indices;
        if (partition == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (partition == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // the array content is the blank-separated list of split values
        // NOTE(review): values are written unquoted; a value containing a
        // blank would be split on reading - confirm against the reader.
        StringBuilder classSet = new StringBuilder();
        for (Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(" ");
            }
            classSet.append(splitValues[i].toString());
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        XmlCursor xmlCursor = pmmlArray.newCursor();
        try {
            xmlCursor.setTextValue(classSet.toString());
        } finally {
            // release the cursor even if writing the text fails
            xmlCursor.dispose();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        // one child per nominal value: equality test on this child's value
        DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(splitNode.getSplitValues()[nodeIndex].toString());
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // note: this rewrites the split attribute on the predicate object itself
            predicate.setSplitAttribute(mapper.getDerivedFieldName(predicate.getSplitAttribute()));
            // delegate the writing to the predicate translator
            PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }
    // adding score distribution (class counts)
    for (Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
        DataCell cell = entry.getKey();
        Double freq = entry.getValue();
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(cell.toString());
        pmmlScoreDist.setRecordCount(freq);
    }
    // adding children
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
Also used : DecisionTreeNodeSplitNominal(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitNominal) ArrayType(org.dmg.pmml.ArrayType) Entry(java.util.Map.Entry) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate) XmlCursor(org.apache.xmlbeans.XmlCursor) BigInteger(java.math.BigInteger) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DecisionTreeNodeSplitNominalBinary(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitNominalBinary) DataCell(org.knime.core.data.DataCell) DecisionTreeNodeSplitContinuous(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitContinuous) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 9 with PMMLSimpleSetPredicate

use of org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate in project knime-core by knime.

the class DecisionTreeLearnerNodeModel method buildTree.

/**
 * Recursively induces the decision tree.
 *
 * @param table the {@link InMemoryTable} representing the data for this
 *            node to determine the split and after that perform
 *            partitioning
 * @param exec the execution context for progress information
 * @param depth the current recursion depth
 * @param splitQualityMeasure the quality measure handed to the
 *            {@code SplitFinder}
 * @param parallelProcessing asked whether a thread is available before a
 *            partition is built in a separate build thread
 * @return a {@code DecisionTreeNodeLeaf} if the table is pure enough or no
 *         valid/useful split exists, otherwise a
 *         {@code DecisionTreeNodeSplitPMML} holding the built children
 * @throws CanceledExecutionException propagated from the calls made here
 *             (presumably {@code exec.checkCanceled()})
 * @throws IllegalAccessException propagated from the calls made here
 * @throws RuntimeException if a parallel build thread reported an
 *             exception; that exception is attached as the cause
 */
private DecisionTreeNode buildTree(final InMemoryTable table, final ExecutionContext exec, final int depth, final SplitQualityMeasure splitQualityMeasure, final ParallelProcessing parallelProcessing) throws CanceledExecutionException, IllegalAccessException {
    exec.checkCanceled();
    // derive this node's id from the counter
    int nodeId = m_counter.getAndIncrement();
    DataCell majorityClass = table.getMajorityClassAsCell();
    LinkedHashMap<DataCell, Double> frequencies = table.getClassFrequencies();
    // if the distribution allows for a leaf
    if (table.isPureEnough()) {
        // free memory
        table.freeUnderlyingDataRows();
        double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
        exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
        return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
    }
    // find the best splits for all attributes
    SplitFinder splittFinder = new SplitFinder(table, splitQualityMeasure, m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue(), m_binaryNominalSplitMode.getBooleanValue(), m_maxNumNominalsForCompleteComputation.getIntValue());
    // check for enough memory
    checkMemory();
    // get the best split among the best attribute splits
    Split split = splittFinder.getSplit();
    // if no best split could be evaluated, create a leaf node
    if (split == null || !split.isValidSplit()) {
        table.freeUnderlyingDataRows();
        double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
        exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
        return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
    }
    // partition the attribute lists according to this split
    Partitioner partitioner = new Partitioner(table, split, m_minNumberRecordsPerNode.getIntValue());
    if (!partitioner.couldBeUsefulPartitioned()) {
        table.freeUnderlyingDataRows();
        double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
        exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
        return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
    }
    // get the just created partitions
    InMemoryTable[] partitionTables = partitioner.getPartitionTables();
    // recursively build the child nodes; a partition is built in the calling
    // thread when rows * attributes is below 10000 or no thread is available,
    // otherwise it is handed to its own build thread
    DecisionTreeNode[] children = new DecisionTreeNode[partitionTables.length];
    ArrayList<ParallelBuilding> threads = new ArrayList<ParallelBuilding>();
    int i = 0;
    for (InMemoryTable partitionTable : partitionTables) {
        exec.checkCanceled();
        if (partitionTable.getNumberDataRows() * m_numberAttributes < 10000 || !parallelProcessing.isThreadAvailable()) {
            children[i] = buildTree(partitionTable, exec, depth + 1, splitQualityMeasure, parallelProcessing);
        } else {
            String threadName = "Build thread, node: " + nodeId + "." + i;
            ParallelBuilding buildThread = new ParallelBuilding(threadName, partitionTable, exec, depth + 1, i, splitQualityMeasure, parallelProcessing);
            LOGGER.debug("Start new parallel building thread: " + threadName);
            threads.add(buildThread);
            buildThread.start();
        }
        i++;
    }
    // collect the results of the build threads; children built in the
    // calling thread are already assigned to the child array
    for (ParallelBuilding buildThread : threads) {
        children[buildThread.getThreadIndex()] = buildThread.getResultNode();
        exec.checkCanceled();
        final Throwable failure = buildThread.getException();
        if (failure != null) {
            // NOTE(review): if ParallelBuilding is a java.lang.Thread this is the
            // deprecated Thread.stop(); consider cooperative cancellation instead.
            for (ParallelBuilding buildThread2 : threads) {
                buildThread2.stop();
            }
            // keep the original exception as cause so its stack trace is not lost
            throw new RuntimeException(failure.getMessage(), failure);
        }
    }
    threads.clear();
    // all split kinds are represented as PMML predicates on a DecisionTreeNodeSplitPMML
    if (split instanceof SplitContinuous) {
        double splitValue = ((SplitContinuous) split).getBestSplitValue();
        String splitAttribute = split.getSplitAttributeName();
        // two children: attribute <= splitValue and attribute > splitValue
        PMMLPredicate[] splitPredicates = new PMMLPredicate[] { new PMMLSimplePredicate(splitAttribute, PMMLOperator.LESS_OR_EQUAL, Double.toString(splitValue)), new PMMLSimplePredicate(splitAttribute, PMMLOperator.GREATER_THAN, Double.toString(splitValue)) };
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    } else if (split instanceof SplitNominalNormal) {
        // else the attribute is nominal: one equality predicate per child
        DataCell[] splitValues = ((SplitNominalNormal) split).getSplitValues();
        int num = children.length;
        PMMLPredicate[] splitPredicates = new PMMLPredicate[num];
        String splitAttribute = split.getSplitAttributeName();
        for (int j = 0; j < num; j++) {
            splitPredicates[j] = new PMMLSimplePredicate(splitAttribute, PMMLOperator.EQUAL, splitValues[j].toString());
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    } else {
        // binary nominal: one IS_IN set predicate per partition
        SplitNominalBinary splitNominalBinary = (SplitNominalBinary) split;
        DataCell[] splitValues = splitNominalBinary.getSplitValues();
        String splitAttribute = split.getSplitAttributeName();
        // indices[0] = value indices of the left partition, indices[1] = right partition
        int[][] indices = new int[][] { splitNominalBinary.getIntMappingsLeftPartition(), splitNominalBinary.getIntMappingsRightPartition() };
        PMMLPredicate[] splitPredicates = new PMMLPredicate[2];
        for (int j = 0; j < splitPredicates.length; j++) {
            PMMLSimpleSetPredicate pred = new PMMLSimpleSetPredicate(splitAttribute, PMMLSetOperator.IS_IN);
            pred.setArrayType(PMMLArrayType.STRING);
            LinkedHashSet<String> values = new LinkedHashSet<String>();
            for (int index : indices[j]) {
                values.add(splitValues[index].toString());
            }
            pred.setValues(values);
            splitPredicates[j] = pred;
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    }
}
Also used : LinkedHashSet(java.util.LinkedHashSet) ArrayList(java.util.ArrayList) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DataCell(org.knime.core.data.DataCell) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 10 with PMMLSimpleSetPredicate

use of org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate in project knime-core by knime.

the class DecisionTreeLearnerNodeModel2 method buildTree.

/**
 * Recursively induces the decision tree.
 *
 * @param table the {@link InMemoryTable} representing the data for this
 *            node to determine the split and after that perform
 *            partitioning
 * @param exec the execution context for progress information
 * @param depth the current recursion depth
 * @param splitQualityMeasure the quality measure used to evaluate splits
 * @param parallelProcessing asked whether a thread is available before a
 *            partition is built in a separate build thread
 * @param firstSplitCol the column whose split is used at {@code depth == 0}
 *            when {@code m_useFirstSplitCol} is enabled
 * @return a {@code DecisionTreeNodeLeaf} if the table is pure enough or no
 *         valid/useful split exists, otherwise a
 *         {@code DecisionTreeNodeSplitPMML} holding the built children
 * @throws CanceledExecutionException propagated from the calls made here
 *             (presumably {@code exec.checkCanceled()})
 * @throws IllegalAccessException propagated from the calls made here
 * @throws RuntimeException if a parallel build thread reported an
 *             exception; that exception is attached as the cause
 */
private DecisionTreeNode buildTree(final InMemoryTable table, final ExecutionContext exec, final int depth, final SplitQualityMeasure splitQualityMeasure, final ParallelProcessing parallelProcessing, final int firstSplitCol) throws CanceledExecutionException, IllegalAccessException {
    exec.checkCanceled();
    // derive this node's id from the counter
    int nodeId = m_counter.getAndIncrement();
    DataCell majorityClass = table.getMajorityClassAsCell();
    LinkedHashMap<DataCell, Double> frequencies = table.getClassFrequencies();
    // if the distribution allows for a leaf
    if (table.isPureEnough()) {
        // free memory
        table.freeUnderlyingDataRows();
        double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
        exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
        return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
    }
    Split split = null;
    // find best split in specified column for first split
    if (depth == 0 && m_useFirstSplitCol.getBooleanValue()) {
        if (table.isNominal(firstSplitCol)) {
            if (m_binaryNominalSplitMode.getBooleanValue()) {
                split = new SplitNominalBinary(table, firstSplitCol, splitQualityMeasure, m_minNumberRecordsPerNode.getIntValue(), m_maxNumNominalsForCompleteComputation.getIntValue());
            } else {
                split = new SplitNominalNormal(table, firstSplitCol, splitQualityMeasure, m_minNumberRecordsPerNode.getIntValue());
            }
        } else {
            split = new SplitContinuous(table, firstSplitCol, splitQualityMeasure, m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue());
        }
        // only a warning is recorded here; the split itself is kept and
        // checked with isValidSplit() further down
        double rootQuality = split.getBestQualityMeasure();
        if (Double.isNaN(rootQuality) || rootQuality == 0.0) {
            m_warningMessageSb.append("The specified root split column \"").append(split.getSplitAttributeName()).append("\" does not contain a valid split.");
        }
    }
    if (split == null) {
        // no root split column found or selected
        // find the best splits for all attributes
        SplitFinder splittFinder = new SplitFinder(table, splitQualityMeasure, m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue(), m_binaryNominalSplitMode.getBooleanValue(), m_maxNumNominalsForCompleteComputation.getIntValue());
        // check for enough memory
        checkMemory();
        // get the best split among the best attribute splits
        split = splittFinder.getSplit();
    }
    // if no best split could be evaluated, create a leaf node
    if (split == null || !split.isValidSplit()) {
        table.freeUnderlyingDataRows();
        double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
        exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
        return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
    }
    // partition the attribute lists according to this split
    Partitioner partitioner = new Partitioner(table, split, m_minNumberRecordsPerNode.getIntValue());
    if (!partitioner.couldBeUsefulPartitioned()) {
        table.freeUnderlyingDataRows();
        double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
        exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
        return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
    }
    // get the just created partitions
    InMemoryTable[] partitionTables = partitioner.getPartitionTables();
    // recursively build the child nodes; a partition is built in the calling
    // thread when rows * attributes is below 10000 or no thread is available,
    // otherwise it is handed to its own build thread
    DecisionTreeNode[] children = new DecisionTreeNode[partitionTables.length];
    ArrayList<ParallelBuilding> threads = new ArrayList<ParallelBuilding>();
    int i = 0;
    for (InMemoryTable partitionTable : partitionTables) {
        exec.checkCanceled();
        if (partitionTable.getNumberDataRows() * m_numberAttributes < 10000 || !parallelProcessing.isThreadAvailable()) {
            children[i] = buildTree(partitionTable, exec, depth + 1, splitQualityMeasure, parallelProcessing, firstSplitCol);
        } else {
            String threadName = "Build thread, node: " + nodeId + "." + i;
            ParallelBuilding buildThread = new ParallelBuilding(threadName, partitionTable, exec, depth + 1, i, splitQualityMeasure, parallelProcessing);
            LOGGER.debug("Start new parallel building thread: " + threadName);
            threads.add(buildThread);
            buildThread.start();
        }
        i++;
    }
    // collect the results of the build threads; children built in the
    // calling thread are already assigned to the child array
    for (ParallelBuilding buildThread : threads) {
        children[buildThread.getThreadIndex()] = buildThread.getResultNode();
        exec.checkCanceled();
        final Throwable failure = buildThread.getException();
        if (failure != null) {
            // NOTE(review): if ParallelBuilding is a java.lang.Thread this is the
            // deprecated Thread.stop(); consider cooperative cancellation instead.
            for (ParallelBuilding buildThread2 : threads) {
                buildThread2.stop();
            }
            // keep the original exception as cause so its stack trace is not lost
            throw new RuntimeException(failure.getMessage(), failure);
        }
    }
    threads.clear();
    // all split kinds are represented as PMML predicates on a DecisionTreeNodeSplitPMML
    if (split instanceof SplitContinuous) {
        double splitValue = ((SplitContinuous) split).getBestSplitValue();
        String splitAttribute = split.getSplitAttributeName();
        // two children: attribute <= splitValue and attribute > splitValue
        PMMLPredicate[] splitPredicates = new PMMLPredicate[] { new PMMLSimplePredicate(splitAttribute, PMMLOperator.LESS_OR_EQUAL, Double.toString(splitValue)), new PMMLSimplePredicate(splitAttribute, PMMLOperator.GREATER_THAN, Double.toString(splitValue)) };
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    } else if (split instanceof SplitNominalNormal) {
        // else the attribute is nominal: one equality predicate per child
        DataCell[] splitValues = ((SplitNominalNormal) split).getSplitValues();
        int num = children.length;
        PMMLPredicate[] splitPredicates = new PMMLPredicate[num];
        String splitAttribute = split.getSplitAttributeName();
        for (int j = 0; j < num; j++) {
            splitPredicates[j] = new PMMLSimplePredicate(splitAttribute, PMMLOperator.EQUAL, splitValues[j].toString());
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    } else {
        // binary nominal: one IS_IN set predicate per partition
        SplitNominalBinary splitNominalBinary = (SplitNominalBinary) split;
        DataCell[] splitValues = splitNominalBinary.getSplitValues();
        String splitAttribute = split.getSplitAttributeName();
        // indices[0] = value indices of the left partition, indices[1] = right partition
        int[][] indices = new int[][] { splitNominalBinary.getIntMappingsLeftPartition(), splitNominalBinary.getIntMappingsRightPartition() };
        PMMLPredicate[] splitPredicates = new PMMLPredicate[2];
        for (int j = 0; j < splitPredicates.length; j++) {
            PMMLSimpleSetPredicate pred = new PMMLSimpleSetPredicate(splitAttribute, PMMLSetOperator.IS_IN);
            pred.setArrayType(PMMLArrayType.STRING);
            LinkedHashSet<String> values = new LinkedHashSet<String>();
            for (int index : indices[j]) {
                values.add(splitValues[index].toString());
            }
            pred.setValues(values);
            splitPredicates[j] = pred;
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    }
}
Also used : LinkedHashSet(java.util.LinkedHashSet) ArrayList(java.util.ArrayList) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DataCell(org.knime.core.data.DataCell) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Aggregations

PMMLSimpleSetPredicate (org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate)13 PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate)12 PMMLCompoundPredicate (org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate)9 PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate)9 PMMLFalsePredicate (org.knime.base.node.mine.decisiontree2.PMMLFalsePredicate)8 PMMLTruePredicate (org.knime.base.node.mine.decisiontree2.PMMLTruePredicate)8 CompoundPredicate (org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate)5 ArrayList (java.util.ArrayList)4 SimplePredicate (org.dmg.pmml.SimplePredicateDocument.SimplePredicate)4 SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate)4 DecisionTreeNode (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)4 DecisionTreeNodeSplitPMML (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML)4 DataCell (org.knime.core.data.DataCell)4 ArrayType (org.dmg.pmml.ArrayType)3 DecisionTreeNodeLeaf (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf)3 LinkedHashSet (java.util.LinkedHashSet)2 Entry (java.util.Map.Entry)2 ScoreDistribution (org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution)2 Enum (org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate.BooleanOperator.Enum)2 PMMLArrayType (org.knime.base.node.mine.decisiontree2.PMMLArrayType)2