Search in sources :

Example 6 with ScoreDistribution

use of org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution in project knime-core by knime.

the class PMMLDecisionTreeTranslator method addTreeNode.

/**
 * A recursive function which converts each KNIME Tree node to a
 * corresponding PMML element.
 *
 * <p>For every node this writes, in this order: the id, the score and (when
 * positive) the record count; the default child for PMML split nodes; the
 * predicate derived from the parent's split; one ScoreDistribution element
 * per class count; and finally the child nodes (for non-leaf nodes).</p>
 *
 * @param pmmlNode the desired PMML element
 * @param node A KNIME DecisionTree node
 * @param mapper supplies the derived field name that is written into the
 *            predicates for a split attribute
 * @throws IllegalArgumentException if the type of the parent node is not
 *             supported, or if the node belongs to neither partition of a
 *             binary nominal split
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // The record count is only written when it is positive
    // (presumably it can be missing when a model was read in and is exported again - TODO confirm).
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        if (defaultChild > -1) {
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }
    // adding the predicate, which is derived from the split of the parent
    DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // When the parent is null, it is the root Node.
        // For root node, the predicate is always True.
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case
        DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        int nodeIndex = splitNode.getIndex(node);
        if (nodeIndex == 0) {
            SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (nodeIndex == 1) {
            pmmlNode.addNewTrue();
        }
        // NOTE(review): for any other index no predicate is written at all - confirm this cannot occur.
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        DataCell[] splitValues = splitNode.getSplitValues();
        int partition = splitNode.getIndex(node);
        final List<Integer> indices;
        if (partition == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (partition == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // the array content is the space separated list of the split values of this partition
        StringBuilder classSet = new StringBuilder();
        for (Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(" ");
            }
            classSet.append(splitValues[i].toString());
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        XmlCursor xmlCursor = pmmlArray.newCursor();
        try {
            xmlCursor.setTextValue(classSet.toString());
        } finally {
            // release the cursor even if writing the text value fails
            xmlCursor.dispose();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(splitNode.getSplitValues()[nodeIndex].toString());
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // NOTE(review): this overwrites the split attribute of the predicate held by the
            // parent node, i.e. the exported tree model itself is modified - confirm this is intended.
            predicate.setSplitAttribute(mapper.getDerivedFieldName(predicate.getSplitAttribute()));
            // delegate the writing to the predicate translator
            PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }
    // adding score distribution (class counts)
    for (Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(entry.getKey().toString());
        pmmlScoreDist.setRecordCount(entry.getValue());
    }
    // adding children
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
Also used : DecisionTreeNodeSplitNominal(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitNominal) ArrayType(org.dmg.pmml.ArrayType) Entry(java.util.Map.Entry) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate) XmlCursor(org.apache.xmlbeans.XmlCursor) BigInteger(java.math.BigInteger) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DecisionTreeNodeSplitNominalBinary(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitNominalBinary) DataCell(org.knime.core.data.DataCell) DecisionTreeNodeSplitContinuous(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitContinuous) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 7 with ScoreDistribution

use of org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution in project knime-core by knime.

the class PMMLDecisionTreeTranslator method getClassCount.

/**
 * Reads the score distribution elements of a PMML node into a map from
 * class value to record count. The map keeps the order in which the
 * elements are returned by the node.
 *
 * @param node the PMML node whose score distribution is read
 * @return the record count per class, keyed by a {@link StringCell} holding
 *         the class value; empty if the node has no score distribution
 */
private LinkedHashMap<DataCell, Double> getClassCount(final Node node) {
    final LinkedHashMap<DataCell, Double> countsByClass = new LinkedHashMap<>();
    for (final ScoreDistribution distribution : node.getScoreDistributionArray()) {
        final String classValue = distribution.getValue();
        final Double count = distribution.getRecordCount();
        countsByClass.put(new StringCell(classValue), count);
    }
    return countsByClass;
}
Also used : ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) StringCell(org.knime.core.data.def.StringCell) DataCell(org.knime.core.data.DataCell) LinkedHashMap(java.util.LinkedHashMap)

Example 8 with ScoreDistribution

use of org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution in project knime-core by knime.

the class TreeModelPMMLTranslator method addTreeNode.

/**
 * Recursively writes a classification tree node and all of its descendants
 * into the given PMML node.
 *
 * <p>The node id is taken from the running {@code m_nodeIndex} counter,
 * which is advanced once per visited node before its children are
 * processed. The record count is the sum of the node's target
 * distribution.</p>
 *
 * @param pmmlNode the PMML node element to fill
 * @param node the tree node to export
 * @throws IllegalStateException if the condition of the node, or the
 *             operator of a numeric condition, is not supported
 */
private void addTreeNode(final Node pmmlNode, final TreeNodeClassification node) {
    int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    pmmlNode.setScore(node.getMajorityClassName());
    double[] targetDistribution = node.getTargetDistribution();
    NominalValueRepresentation[] targetVals = node.getTargetMetaData().getValues();
    double sum = 0.0;
    // primitive loop variable: a boxed one would allocate a Double per element just to unbox it again
    for (double v : targetDistribution) {
        sum += v;
    }
    pmmlNode.setRecordCount(sum);
    TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        // every column condition is exported as a SimplePredicate (field, operator, value)
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        final String colName = colCondition.getColumnMetaData().getAttributeName();
        final Operator.Enum operator;
        final String value;
        if (condition instanceof TreeNodeNominalCondition) {
            final TreeNodeNominalCondition nominalCondition = (TreeNodeNominalCondition) condition;
            operator = Operator.EQUAL;
            value = nominalCondition.getValue();
        } else if (condition instanceof TreeNodeBitCondition) {
            final TreeNodeBitCondition bitCondition = (TreeNodeBitCondition) condition;
            operator = Operator.EQUAL;
            // bit values are written as "1" (set) and "0" (not set)
            value = bitCondition.getValue() ? "1" : "0";
        } else if (condition instanceof TreeNodeNumericCondition) {
            final TreeNodeNumericCondition numCondition = (TreeNodeNumericCondition) condition;
            NumericOperator numOperator = numCondition.getNumericOperator();
            switch(numOperator) {
                case LargerThan:
                    operator = Operator.GREATER_THAN;
                    break;
                case LessThanOrEqual:
                    operator = Operator.LESS_OR_EQUAL;
                    break;
                default:
                    throw new IllegalStateException("Unsupported operator (not " + "implemented): " + numOperator);
            }
            value = Double.toString(numCondition.getSplitValue());
        } else {
            throw new IllegalStateException("Unsupported condition (not " + "implemented): " + condition.getClass().getSimpleName());
        }
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(colName);
        pmmlSimplePredicate.setOperator(operator);
        pmmlSimplePredicate.setValue(value);
    } else {
        throw new IllegalStateException("Unsupported condition (not " + "implemented): " + condition.getClass().getSimpleName());
    }
    // adding score distribution (class counts)
    // assumes targetVals has at least as many entries as targetDistribution - TODO confirm
    for (int i = 0; i < targetDistribution.length; i++) {
        String className = targetVals[i].getNominalValue();
        double freq = targetDistribution[i];
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(className);
        pmmlScoreDist.setRecordCount(freq);
    }
    // adding children
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble.data.NominalValueRepresentation) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) NumericOperator(org.knime.base.node.mine.treeensemble.model.TreeNodeNumericCondition.NumericOperator)

Aggregations

ScoreDistribution (org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution)8 SimplePredicate (org.dmg.pmml.SimplePredicateDocument.SimplePredicate)3 DataCell (org.knime.core.data.DataCell)3 Entry (java.util.Map.Entry)2 SimpleRule (org.dmg.pmml.SimpleRuleDocument.SimpleRule)2 SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate)2 DecisionTreeNode (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)2 DecisionTreeNodeSplitPMML (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML)2 NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation)2 File (java.io.File)1 IOException (java.io.IOException)1 BigDecimal (java.math.BigDecimal)1 BigInteger (java.math.BigInteger)1 MathContext (java.math.MathContext)1 RoundingMode (java.math.RoundingMode)1 ArrayList (java.util.ArrayList)1 LinkedHashMap (java.util.LinkedHashMap)1 List (java.util.List)1 XmlCursor (org.apache.xmlbeans.XmlCursor)1 ArrayType (org.dmg.pmml.ArrayType)1