Search in sources:

Example 1 with ScoreDistribution

use of org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution in project knime-core by knime.

the class FromDecisionTreeNodeModel method addRules.

/**
 * Adds one {@link SimpleRule} to {@code rs} for every leaf found at or below {@code node}.
 * <p>
 * Inner nodes are only traversed: the node is appended to {@code parents}, all children are visited
 * recursively, and the node is removed from {@code parents} again. For a leaf, a rule is created whose
 * condition is an {@code AND} compound of the split predicates collected while walking from the leaf up to
 * the root, and whose score is the leaf's majority class. Depending on the settings in
 * {@code m_rulesToTable}, score distributions (record counts, optionally probabilities) and rule statistics
 * are written as well.
 * <p>
 * A leaf without a parent (a tree consisting of a single node) contributes no split predicate; its condition
 * then consists only of the {@code True} padding described in the method body.
 *
 * @param rs The output {@link RuleSet}.
 * @param parents The parent stack; inner nodes are on it while their children are processed.
 * @param node The actual node.
 */
private void addRules(final RuleSet rs, final List<DecisionTreeNode> parents, final DecisionTreeNode node) {
    if (!node.isLeaf()) {
        parents.add(node);
        for (int i = 0; i < node.getChildCount(); ++i) {
            addRules(rs, parents, node.getChildAt(i));
        }
        parents.remove(node);
        return;
    }

    final SimpleRule rule = rs.addNewSimpleRule();
    if (m_rulesToTable.getScorePmmlRecordCount().getBooleanValue()) {
        // This increases the PMML quite significantly
        final MathContext mc = new MathContext(7, RoundingMode.HALF_EVEN);
        final boolean computeProbability = m_rulesToTable.getScorePmmlProbability().getBooleanValue();
        BigDecimal sum = BigDecimal.ZERO;
        if (computeProbability) {
            // The total is only needed as the denominator of the per-class probabilities.
            sum = new BigDecimal(
                node.getClassCounts().entrySet().stream().mapToDouble(e -> e.getValue().doubleValue()).sum(), mc);
        }
        for (final Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
            final ScoreDistribution scoreDistrib = rule.addNewScoreDistribution();
            scoreDistrib.setValue(entry.getKey().toString());
            scoreDistrib.setRecordCount(entry.getValue());
            if (computeProbability) {
                final double count = entry.getValue().doubleValue();
                // "== 0.0" also matches -0.0; the zero-sum check keeps divide() from throwing
                // an ArithmeticException when there is nothing to normalize by.
                if (count == 0.0 || sum.signum() == 0) {
                    scoreDistrib.setProbability(BigDecimal.ZERO);
                } else {
                    scoreDistrib.setProbability(new BigDecimal(count, mc).divide(sum, mc));
                }
            }
        }
    }

    final CompoundPredicate and = rule.addNewCompoundPredicate();
    and.setBooleanOperator(BooleanOperator.AND);
    // Walk from the leaf towards the root. The predicate leading to a node is stored in its parent's
    // split-predicate array at the node's child index. Testing the parent before using it (instead of
    // a do-while) avoids a NullPointerException when the leaf is the root itself.
    DecisionTreeNode n = node;
    while (n.getParent() != null) {
        final DecisionTreeNode parent = n.getParent();
        final PMMLPredicate pmmlPredicate = ((DecisionTreeNodeSplitPMML) parent).getSplitPred()[parent.getIndex(n)];
        if (pmmlPredicate instanceof PMMLSimplePredicate) {
            copy(and.addNewSimplePredicate(), (PMMLSimplePredicate) pmmlPredicate);
        } else if (pmmlPredicate instanceof PMMLCompoundPredicate) {
            copy(and.addNewCompoundPredicate(), (PMMLCompoundPredicate) pmmlPredicate);
        } else if (pmmlPredicate instanceof PMMLSimpleSetPredicate) {
            copy(and.addNewSimpleSetPredicate(), (PMMLSimpleSetPredicate) pmmlPredicate);
        } else if (pmmlPredicate instanceof PMMLTruePredicate) {
            and.addNewTrue();
        } else if (pmmlPredicate instanceof PMMLFalsePredicate) {
            and.addNewFalse();
        }
        // Any other predicate type is skipped without adding a condition.
        n = parent;
    }
    // Simple fix for the case when a single condition was used: pad with True until the compound has at
    // least two operands (presumably because a PMML CompoundPredicate requires two or more - verify
    // against the PMML schema).
    while (and.getFalseList().size() + and.getCompoundPredicateList().size() + and.getSimplePredicateList().size()
        + and.getSimpleSetPredicateList().size() + and.getTrueList().size() < 2) {
        and.addNewTrue();
    }

    if (m_rulesToTable.getProvideStatistics().getBooleanValue()) {
        rule.setNbCorrect(node.getOwnClassCount());
        rule.setRecordCount(node.getEntireClassCount());
    }
    rule.setScore(node.getMajorityClass().toString());
}
Also used : PMMLDocument(org.dmg.pmml.PMMLDocument) NodeSettingsRO(org.knime.core.node.NodeSettingsRO) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) BooleanOperator(org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate.BooleanOperator) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) PMMLMiningSchemaTranslator(org.knime.core.node.port.pmml.PMMLMiningSchemaTranslator) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) BigDecimal(java.math.BigDecimal) PMML(org.dmg.pmml.PMMLDocument.PMML) PMMLFalsePredicate(org.knime.base.node.mine.decisiontree2.PMMLFalsePredicate) RuleSetToTable(org.knime.base.node.rules.engine.totable.RuleSetToTable) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) RoundingMode(java.math.RoundingMode) PortType(org.knime.core.node.port.PortType) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate) MathContext(java.math.MathContext) NodeModel(org.knime.core.node.NodeModel) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode) List(java.util.List) BufferedDataTable(org.knime.core.node.BufferedDataTable) PMMLTruePredicate(org.knime.base.node.mine.decisiontree2.PMMLTruePredicate) PMMLDataDictionaryTranslator(org.knime.core.node.port.pmml.PMMLDataDictionaryTranslator) Entry(java.util.Map.Entry) PortObject(org.knime.core.node.port.PortObject) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) MININGFUNCTION(org.dmg.pmml.MININGFUNCTION) RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) 
PMMLOperator(org.knime.base.node.mine.decisiontree2.PMMLOperator) Criterion(org.dmg.pmml.RuleSelectionMethodDocument.RuleSelectionMethod.Criterion) RuleSet(org.dmg.pmml.RuleSetDocument.RuleSet) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) Enum(org.dmg.pmml.SimplePredicateDocument.SimplePredicate.Operator.Enum) ArrayList(java.util.ArrayList) ExecutionContext(org.knime.core.node.ExecutionContext) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) DataCell(org.knime.core.data.DataCell) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) PMMLPortObjectSpecCreator(org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator) PMMLPredicateTranslator(org.knime.base.node.mine.decisiontree2.PMMLPredicateTranslator) PortObjectSpec(org.knime.core.node.port.PortObjectSpec) IOException(java.io.IOException) File(java.io.File) NodeSettingsWO(org.knime.core.node.NodeSettingsWO) DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) PMMLFalsePredicate(org.knime.base.node.mine.decisiontree2.PMMLFalsePredicate) BigDecimal(java.math.BigDecimal) MathContext(java.math.MathContext) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) PMMLTruePredicate(org.knime.base.node.mine.decisiontree2.PMMLTruePredicate) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) DataCell(org.knime.core.data.DataCell) 
PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 2 with ScoreDistribution

use of org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution in project knime-core by knime.

the class PMMLRuleTranslator method addRules.

/**
 * Converts every entry of {@code rules} into a {@link SimpleRule} appended to {@code ruleSet}.
 * <p>
 * Per rule the score and the predicate are always written. The number of correct records and the record
 * count are written only when statistics are enabled and the respective value is not {@code NaN}; weight
 * and confidence only when they are non-{@code null}. Finally one {@link ScoreDistribution} is added per
 * entry of the rule's score distribution.
 *
 * @param ruleSet An xml {@link RuleSet}.
 * @param rules The simplified {@link Rule}s to add.
 */
private void addRules(final RuleSet ruleSet, final List<Rule> rules) {
    for (final Rule source : rules) {
        final SimpleRule target = ruleSet.addNewSimpleRule();
        target.setScore(source.getOutcome());

        // Statistics are optional and NaN marks a missing value.
        if (m_provideStatistics) {
            if (!Double.isNaN(source.getNbCorrect())) {
                target.setNbCorrect(source.getNbCorrect());
            }
            if (!Double.isNaN(source.getRecordCount())) {
                target.setRecordCount(source.getRecordCount());
            }
        }

        setPredicate(target, source.getCondition());

        if (source.getWeight() != null) {
            target.setWeight(source.getWeight());
        }
        if (source.getConfidence() != null) {
            target.setConfidence(source.getConfidence());
        }

        for (final Entry<String, ScoreProbabilityAndRecordCount> distEntry : source.getScoreDistribution().entrySet()) {
            final ScoreProbabilityAndRecordCount stats = distEntry.getValue();
            final ScoreDistribution distribution = target.addNewScoreDistribution();
            distribution.setValue(distEntry.getKey());
            // The probability is optional, the record count is always written.
            if (stats.getProbability() != null) {
                distribution.setProbability(stats.getProbability());
            }
            distribution.setRecordCount(stats.getRecordCount());
        }
    }
}
Also used : ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) CompoundRule(org.dmg.pmml.CompoundRuleDocument.CompoundRule) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule)

Example 3 with ScoreDistribution

use of org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution in project knime-core by knime.

the class ClassificationContentParser method parseDistribution.

/**
 * Reads the score distribution entries of {@code node} into a {@code float} array of record counts.
 * <p>
 * The array has as many slots as the node has score distribution entries. Each record count is stored at
 * the integer that {@code m_nomValMapper} assigns to the entry's value.
 * <p>
 * NOTE(review): the array is sized by the number of entries but indexed by the mapper's assigned integer;
 * this looks like it relies on every node listing all target values - confirm against the mapper.
 *
 * @param node the PMML node whose score distribution is read
 * @return the record counts, indexed by the assigned integer of the corresponding value
 */
private float[] parseDistribution(final Node node) {
    final List<ScoreDistribution> entries = node.getScoreDistributionList();
    final int entryCount = entries.size();
    final float[] counts = new float[entryCount];
    for (int i = 0; i < entryCount; i++) {
        final ScoreDistribution current = entries.get(i);
        final int slot = m_nomValMapper.getRepresentation(current.getValue()).getAssignedInteger();
        counts[slot] = (float) current.getRecordCount();
    }
    return counts;
}
Also used : ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution)

Example 4 with ScoreDistribution

use of org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution in project knime-core by knime.

the class ClassificationTreeModelExporter method addNodeContent.

/**
 * {@inheritDoc}
 * <p>
 * Sets the majority class name as score, the sum of the target distribution as record count, and adds one
 * {@link ScoreDistribution} per entry of the target distribution.
 */
@Override
protected void addNodeContent(final int nodeId, final Node pmmlNode, final TreeNodeClassification node) {
    pmmlNode.setScore(node.getMajorityClassName());

    final float[] classCounts = node.getTargetDistribution();
    final NominalValueRepresentation[] classes = node.getTargetMetaData().getValues();

    // The node's record count is the total over all class counts.
    double total = 0.0;
    for (int i = 0; i < classCounts.length; i++) {
        total += classCounts[i];
    }
    pmmlNode.setRecordCount(total);

    // One score distribution entry per class; names and counts share the same index.
    int classIdx = 0;
    for (final float count : classCounts) {
        final String name = classes[classIdx++].getNominalValue();
        final ScoreDistribution distribution = pmmlNode.addNewScoreDistribution();
        distribution.setValue(name);
        distribution.setRecordCount(count);
    }
}
Also used : ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation)

Example 5 with ScoreDistribution

use of org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution in project knime-core by knime.

the class TreeModelPMMLTranslator method addTreeNode.

/**
 * Fills {@code pmmlNode} from {@code node} and then does the same, recursively, for each child of
 * {@code node}, using a freshly added PMML child node per child.
 * <p>
 * The PMML node id is taken from the running counter {@code m_nodeIndex}, which is incremented on every
 * call. Classification nodes get the majority class as score, the summed target distribution as record
 * count and one {@link ScoreDistribution} per class; regression nodes get their mean as score. The node's
 * condition is then translated into the matching PMML predicate.
 *
 * @param pmmlNode the PMML node to populate
 * @param node the tree node providing the content
 * @throws IllegalStateException if the node's condition is {@code null} or of a type that is not handled
 */
private void addTreeNode(final Node pmmlNode, final AbstractTreeNode node) {
    int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    if (node instanceof TreeNodeClassification) {
        final TreeNodeClassification clazNode = (TreeNodeClassification) node;
        pmmlNode.setScore(clazNode.getMajorityClassName());
        float[] targetDistribution = clazNode.getTargetDistribution();
        NominalValueRepresentation[] targetVals = clazNode.getTargetMetaData().getValues();
        double sum = 0.0;
        // primitive loop variable: iterating a float[] as Float would box every element
        for (float v : targetDistribution) {
            sum += v;
        }
        pmmlNode.setRecordCount(sum);
        // adding score distribution (class counts)
        for (int i = 0; i < targetDistribution.length; i++) {
            String className = targetVals[i].getNominalValue();
            double freq = targetDistribution[i];
            ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
            pmmlScoreDist.setValue(className);
            pmmlScoreDist.setRecordCount(freq);
        }
    } else if (node instanceof TreeNodeRegression) {
        final TreeNodeRegression regNode = (TreeNodeRegression) node;
        pmmlNode.setScore(Double.toString(regNode.getMean()));
    }
    TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        handleColumnCondition(colCondition, pmmlNode);
    } else if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogateCond = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogateCond.toPMMLPredicate());
    } else {
        // A null condition fails every instanceof test above; report it with the same exception
        // instead of tripping a NullPointerException while building the message.
        final String conditionName = condition == null ? "null" : condition.getClass().getSimpleName();
        throw new IllegalStateException("Unsupported condition (not implemented): " + conditionName);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution)

Aggregations

ScoreDistribution (org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution): 8; SimplePredicate (org.dmg.pmml.SimplePredicateDocument.SimplePredicate): 3; DataCell (org.knime.core.data.DataCell): 3; Entry (java.util.Map.Entry): 2; SimpleRule (org.dmg.pmml.SimpleRuleDocument.SimpleRule): 2; SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate): 2; DecisionTreeNode (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode): 2; DecisionTreeNodeSplitPMML (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML): 2; NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation): 2; File (java.io.File): 1; IOException (java.io.IOException): 1; BigDecimal (java.math.BigDecimal): 1; BigInteger (java.math.BigInteger): 1; MathContext (java.math.MathContext): 1; RoundingMode (java.math.RoundingMode): 1; ArrayList (java.util.ArrayList): 1; LinkedHashMap (java.util.LinkedHashMap): 1; List (java.util.List): 1; XmlCursor (org.apache.xmlbeans.XmlCursor): 1; ArrayType (org.dmg.pmml.ArrayType): 1