Search in sources :

Example 1 with SimpleRule

use of org.dmg.pmml.SimpleRuleDocument.SimpleRule in project knime-core by knime.

the class FromDecisionTreeNodeModel method addRules.

/**
 * Adds the rules to {@code rs} (recursively on each leaf).
 * <p>
 * For every leaf below (or equal to) {@code node} a new {@link SimpleRule} is appended to {@code rs}. Its condition
 * is an {@code AND} {@link CompoundPredicate} of the split predicates found while walking from the leaf up to the
 * root; its score is the majority class of the leaf. Depending on the settings in {@code m_rulesToTable}, score
 * distributions (record counts, optionally probabilities) and rule statistics are written as well.
 *
 * @param rs The output {@link RuleSet}.
 * @param parents The parent stack; inner nodes are pushed before and removed after their children are visited.
 * @param node The actual node.
 */
private void addRules(final RuleSet rs, final List<DecisionTreeNode> parents, final DecisionTreeNode node) {
    if (node.isLeaf()) {
        SimpleRule rule = rs.addNewSimpleRule();
        if (m_rulesToTable.getScorePmmlRecordCount().getBooleanValue()) {
            // This increases the PMML quite significantly
            BigDecimal sum = BigDecimal.ZERO;
            final MathContext mc = new MathContext(7, RoundingMode.HALF_EVEN);
            final boolean computeProbability = m_rulesToTable.getScorePmmlProbability().getBooleanValue();
            if (computeProbability) {
                // Total of all class counts, the denominator of the per-class probabilities.
                sum = new BigDecimal(node.getClassCounts().entrySet().stream().mapToDouble(e -> e.getValue().doubleValue()).sum(), mc);
            }
            for (final Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
                final ScoreDistribution scoreDistrib = rule.addNewScoreDistribution();
                scoreDistrib.setValue(entry.getKey().toString());
                scoreDistrib.setRecordCount(entry.getValue());
                if (computeProbability) {
                    if (Double.compare(entry.getValue().doubleValue(), 0.0) == 0) {
                        // Zero counts are written directly, which also avoids 0/0 when all counts are zero.
                        scoreDistrib.setProbability(BigDecimal.ZERO);
                    } else {
                        scoreDistrib.setProbability(new BigDecimal(entry.getValue().doubleValue(), mc).divide(sum, mc));
                    }
                }
            }
        }
        CompoundPredicate and = rule.addNewCompoundPredicate();
        and.setBooleanOperator(BooleanOperator.AND);
        DecisionTreeNode n = node;
        // Walk from the leaf up to the root collecting the split predicate that leads to each node.
        // The parent is checked *before* it is dereferenced so that a tree consisting of a single
        // root leaf (no parent at all) does not cause a NullPointerException.
        while (n.getParent() != null) {
            PMMLPredicate pmmlPredicate = ((DecisionTreeNodeSplitPMML) n.getParent()).getSplitPred()[n.getParent().getIndex(n)];
            if (pmmlPredicate instanceof PMMLSimplePredicate) {
                PMMLSimplePredicate simple = (PMMLSimplePredicate) pmmlPredicate;
                SimplePredicate predicate = and.addNewSimplePredicate();
                copy(predicate, simple);
            } else if (pmmlPredicate instanceof PMMLCompoundPredicate) {
                PMMLCompoundPredicate compound = (PMMLCompoundPredicate) pmmlPredicate;
                CompoundPredicate predicate = and.addNewCompoundPredicate();
                copy(predicate, compound);
            } else if (pmmlPredicate instanceof PMMLSimpleSetPredicate) {
                PMMLSimpleSetPredicate simpleSet = (PMMLSimpleSetPredicate) pmmlPredicate;
                copy(and.addNewSimpleSetPredicate(), simpleSet);
            } else if (pmmlPredicate instanceof PMMLTruePredicate) {
                and.addNewTrue();
            } else if (pmmlPredicate instanceof PMMLFalsePredicate) {
                and.addNewFalse();
            }
            n = n.getParent();
        }
        // Simple fix for the case when fewer than two conditions were collected: pad with True predicates
        // until the compound predicate has at least two operands.
        while (and.getFalseList().size() + and.getCompoundPredicateList().size() + and.getSimplePredicateList().size() + and.getSimpleSetPredicateList().size() + and.getTrueList().size() < 2) {
            and.addNewTrue();
        }
        if (m_rulesToTable.getProvideStatistics().getBooleanValue()) {
            rule.setNbCorrect(node.getOwnClassCount());
            rule.setRecordCount(node.getEntireClassCount());
        }
        rule.setScore(node.getMajorityClass().toString());
    } else {
        parents.add(node);
        for (int i = 0; i < node.getChildCount(); ++i) {
            addRules(rs, parents, node.getChildAt(i));
        }
        parents.remove(node);
    }
}
Also used : PMMLDocument(org.dmg.pmml.PMMLDocument) NodeSettingsRO(org.knime.core.node.NodeSettingsRO) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) BooleanOperator(org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate.BooleanOperator) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) PMMLMiningSchemaTranslator(org.knime.core.node.port.pmml.PMMLMiningSchemaTranslator) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) BigDecimal(java.math.BigDecimal) PMML(org.dmg.pmml.PMMLDocument.PMML) PMMLFalsePredicate(org.knime.base.node.mine.decisiontree2.PMMLFalsePredicate) RuleSetToTable(org.knime.base.node.rules.engine.totable.RuleSetToTable) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) RoundingMode(java.math.RoundingMode) PortType(org.knime.core.node.port.PortType) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate) MathContext(java.math.MathContext) NodeModel(org.knime.core.node.NodeModel) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode) List(java.util.List) BufferedDataTable(org.knime.core.node.BufferedDataTable) PMMLTruePredicate(org.knime.base.node.mine.decisiontree2.PMMLTruePredicate) PMMLDataDictionaryTranslator(org.knime.core.node.port.pmml.PMMLDataDictionaryTranslator) Entry(java.util.Map.Entry) PortObject(org.knime.core.node.port.PortObject) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) MININGFUNCTION(org.dmg.pmml.MININGFUNCTION) RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) 
PMMLOperator(org.knime.base.node.mine.decisiontree2.PMMLOperator) Criterion(org.dmg.pmml.RuleSelectionMethodDocument.RuleSelectionMethod.Criterion) RuleSet(org.dmg.pmml.RuleSetDocument.RuleSet) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) Enum(org.dmg.pmml.SimplePredicateDocument.SimplePredicate.Operator.Enum) ArrayList(java.util.ArrayList) ExecutionContext(org.knime.core.node.ExecutionContext) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) DataCell(org.knime.core.data.DataCell) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) PMMLPortObjectSpecCreator(org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator) PMMLPredicateTranslator(org.knime.base.node.mine.decisiontree2.PMMLPredicateTranslator) PortObjectSpec(org.knime.core.node.port.PortObjectSpec) IOException(java.io.IOException) File(java.io.File) NodeSettingsWO(org.knime.core.node.NodeSettingsWO) DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) PMMLFalsePredicate(org.knime.base.node.mine.decisiontree2.PMMLFalsePredicate) BigDecimal(java.math.BigDecimal) MathContext(java.math.MathContext) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) PMMLTruePredicate(org.knime.base.node.mine.decisiontree2.PMMLTruePredicate) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) DataCell(org.knime.core.data.DataCell) 
PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLCompoundPredicate(org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 2 with SimpleRule

use of org.dmg.pmml.SimpleRuleDocument.SimpleRule in project knime-core by knime.

the class PMMLRuleEditorNodeModel method createRearranger.

/**
 * Creates the {@link ColumnRearranger} that can compute the new column.
 * <p>
 * Each non-comment rule text from the settings is parsed into a condition and a constant outcome. The rule is
 * added to {@code ruleSet} as a {@link SimpleRule} and also kept locally so that the returned rearranger can
 * evaluate the rules row by row: the outcome of the first rule whose condition evaluates to {@code true} is the
 * cell value, a missing cell is produced when no rule matches.
 *
 * @param tableSpec The spec of the input table.
 * @param ruleSet The {@link RuleSet} xml object where the rules should be added.
 * @param parser The parser for the rules.
 * @return The {@link ColumnRearranger}.
 * @throws ParseException Problem during parsing.
 * @throws InvalidSettingsException if settings are invalid
 */
private ColumnRearranger createRearranger(final DataTableSpec tableSpec, final RuleSet ruleSet, final PMMLRuleParser parser) throws ParseException, InvalidSettingsException {
    if (m_settings.isAppendColumn() && m_settings.getNewColName().isEmpty()) {
        throw new InvalidSettingsException("No name for prediction column provided");
    }
    Set<String> outcomes = new LinkedHashSet<String>();
    List<DataType> outcomeTypes = new ArrayList<DataType>();
    // 1-based line number of the current rule text (comment lines are counted too), used for error context.
    int line = 0;
    final List<Pair<PMMLPredicate, Expression>> rules = new ArrayList<Pair<PMMLPredicate, Expression>>();
    for (String ruleText : m_settings.rules()) {
        ++line;
        if (RuleSupport.isComment(ruleText)) {
            continue;
        }
        try {
            ParseState state = new ParseState(ruleText);
            PMMLPredicate expression = parser.parseBooleanExpression(state);
            SimpleRule simpleRule = ruleSet.addNewSimpleRule();
            setCondition(simpleRule, expression);
            state.skipWS();
            state.consumeText("=>");
            state.skipWS();
            Expression outcome = parser.parseOutcomeOperand(state, null);
            // Only constants are allowed in the outcomes.
            assert outcome.isConstant() : outcome;
            rules.add(new Pair<PMMLPredicate, Expression>(expression, outcome));
            outcomeTypes.add(outcome.getOutputType());
            simpleRule.setScore(outcome.toString());
            // simpleRule.setConfidence(confidenceForRule(simpleRule, line, ruleText));
            simpleRule.setWeight(weightForRule(simpleRule, line, ruleText));
            outcomes.add(simpleRule.getScore());
        } catch (ParseException e) {
            throw Util.addContext(e, ruleText, line);
        }
    }
    DataType outcomeType = RuleEngineNodeModel.computeOutputType(outcomeTypes, true);
    ColumnRearranger rearranger = new ColumnRearranger(tableSpec);
    DataColumnSpecCreator specProto = new DataColumnSpecCreator(m_settings.isAppendColumn() ? DataTableSpec.getUniqueColumnName(tableSpec, m_settings.getNewColName()) : m_settings.getReplaceColumn(), outcomeType);
    specProto.setDomain(new DataColumnDomainCreator(toCells(outcomes, outcomeType)).createDomain());
    SingleCellFactory cellFactory = new SingleCellFactory(true, specProto.createSpec()) {

        @Override
        public DataCell getCell(final DataRow row) {
            for (Pair<PMMLPredicate, Expression> pair : rules) {
                // equals() instead of ==: the result is a boxed Boolean, so identity comparison would
                // only work by accident for the cached Boolean.TRUE instance. A null result is a non-match.
                if (Boolean.TRUE.equals(pair.getFirst().evaluate(row, tableSpec))) {
                    return pair.getSecond().evaluate(row, null).getValue();
                }
            }
            return DataType.getMissingCell();
        }
    };
    if (m_settings.isAppendColumn()) {
        rearranger.append(cellFactory);
    } else {
        rearranger.replace(cellFactory, m_settings.getReplaceColumn());
    }
    return rearranger;
}
Also used : LinkedHashSet(java.util.LinkedHashSet) DataColumnSpecCreator(org.knime.core.data.DataColumnSpecCreator) ArrayList(java.util.ArrayList) DataColumnDomainCreator(org.knime.core.data.DataColumnDomainCreator) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) ParseState(org.knime.base.node.rules.engine.BaseRuleParser.ParseState) DataRow(org.knime.core.data.DataRow) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) Expression(org.knime.base.node.rules.engine.Expression) DataType(org.knime.core.data.DataType) ParseException(java.text.ParseException) SingleCellFactory(org.knime.core.data.container.SingleCellFactory) Pair(org.knime.core.util.Pair)

Example 3 with SimpleRule

use of org.dmg.pmml.SimpleRuleDocument.SimpleRule in project knime-core by knime.

the class PMMLRuleSetPredictorNodeModel method createRearranger.

/**
 * Constructs the {@link ColumnRearranger} for computing the new columns.
 * <p>
 * The rules are taken from the single RuleSetModel contained in {@code obj}. For each row the rule selection
 * method of the model (first hit, weighted max or weighted sum) determines the outcome; if no rule matches the
 * default score and default confidence of the model are used.
 *
 * @param obj The {@link PMMLPortObject} of the preprocessing model.
 * @param spec The {@link DataTableSpec} of the table.
 * @param replaceColumn Should replace the {@code outputColumnName}?
 * @param outputColumnName The output column name (which might be an existing).
 * @param addConfidence Should add the confidence values to a column?
 * @param confidenceColumnName The name of the confidence column.
 * @param validationColumnIdx Index of the validation column, {@code -1} if not specified.
 * @param processConcurrently Should be {@code false} when the statistics are to be computed.
 * @return The {@link ColumnRearranger} computing the result.
 * @throws InvalidSettingsException Problem with rules.
 */
private static ColumnRearranger createRearranger(final PMMLPortObject obj, final DataTableSpec spec, final boolean replaceColumn, final String outputColumnName, final boolean addConfidence, final String confidenceColumnName, final int validationColumnIdx, final boolean processConcurrently) throws InvalidSettingsException {
    List<Node> models = obj.getPMMLValue().getModels(PMMLModelType.RuleSetModel);
    if (models.size() != 1) {
        throw new InvalidSettingsException("Expected exactly one RuleSetModel, but got: " + models.size());
    }
    final PMMLRuleTranslator translator = new PMMLRuleTranslator();
    obj.initializeModelTranslator(translator);
    if (!translator.isScorable()) {
        throw new UnsupportedOperationException("The model is not scorable.");
    }
    final List<PMMLRuleTranslator.Rule> rules = translator.getRules();
    ColumnRearranger ret = new ColumnRearranger(spec);
    final List<DataColumnSpec> targetCols = obj.getSpec().getTargetCols();
    // Without a target column the outcomes are treated as strings.
    final DataType dataType = targetCols.isEmpty() ? StringCell.TYPE : targetCols.get(0).getType();
    DataColumnSpecCreator specCreator = new DataColumnSpecCreator(outputColumnName, dataType);
    // Collect the possible outcomes for the domain of the output column; outcomes that cannot be parsed
    // as the numeric target type are left out of the domain.
    Set<DataCell> outcomes = new LinkedHashSet<>();
    for (Rule rule : rules) {
        DataCell outcome;
        if (dataType.equals(BooleanCell.TYPE)) {
            outcome = BooleanCellFactory.create(rule.getOutcome());
        } else if (dataType.equals(StringCell.TYPE)) {
            outcome = new StringCell(rule.getOutcome());
        } else if (dataType.equals(DoubleCell.TYPE)) {
            try {
                outcome = new DoubleCell(Double.parseDouble(rule.getOutcome()));
            } catch (NumberFormatException e) {
                // ignore
                continue;
            }
        } else if (dataType.equals(IntCell.TYPE)) {
            try {
                outcome = new IntCell(Integer.parseInt(rule.getOutcome()));
            } catch (NumberFormatException e) {
                // ignore
                continue;
            }
        } else if (dataType.equals(LongCell.TYPE)) {
            try {
                outcome = new LongCell(Long.parseLong(rule.getOutcome()));
            } catch (NumberFormatException e) {
                // ignore
                continue;
            }
        } else {
            throw new UnsupportedOperationException("Unknown outcome type: " + dataType);
        }
        outcomes.add(outcome);
    }
    specCreator.setDomain(new DataColumnDomainCreator(outcomes).createDomain());
    DataColumnSpec colSpec = specCreator.createSpec();
    final RuleSelectionMethod ruleSelectionMethod = translator.getSelectionMethodList().get(0);
    final String defaultScore = translator.getDefaultScore();
    final Double defaultConfidence = translator.getDefaultConfidence();
    // When confidences are requested the confidence column comes first, followed by the result column.
    final DataColumnSpec[] specs;
    if (addConfidence) {
        specs = new DataColumnSpec[] { new DataColumnSpecCreator(DataTableSpec.getUniqueColumnName(ret.createSpec(), confidenceColumnName), DoubleCell.TYPE).createSpec(), colSpec };
    } else {
        specs = new DataColumnSpec[] { colSpec };
    }
    final int oldColumnIndex = replaceColumn ? ret.indexOf(outputColumnName) : -1;
    ret.append(new AbstractCellFactory(processConcurrently, specs) {

        /**
         * The values of the target column according to the data dictionary; {@code null} when there is no
         * target column or the data dictionary has no entry for it.
         */
        private final List<String> m_values;

        {
            Map<String, List<String>> dd = translator.getDataDictionary();
            // An empty target column list is allowed above (outcomes default to strings), so it must not
            // be dereferenced unconditionally here.
            m_values = targetCols.isEmpty() ? null : dd.get(targetCols.get(0).getName());
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public DataCell[] getCells(final DataRow row) {
            // See http://www.dmg.org/v4-1/RuleSet.html#Rule
            switch(ruleSelectionMethod.getCriterion().intValue()) {
                case RuleSelectionMethod.Criterion.INT_FIRST_HIT:
                    {
                        Pair<DataCell, Double> resultAndConfidence = selectFirstHit(row);
                        return toCells(resultAndConfidence);
                    }
                case RuleSelectionMethod.Criterion.INT_WEIGHTED_MAX:
                    {
                        Pair<DataCell, Double> resultAndConfidence = selectWeightedMax(row);
                        return toCells(resultAndConfidence);
                    }
                case RuleSelectionMethod.Criterion.INT_WEIGHTED_SUM:
                    {
                        Pair<DataCell, Double> resultAndConfidence = selectWeightedSum(row);
                        return toCells(resultAndConfidence);
                    }
                default:
                    throw new UnsupportedOperationException(ruleSelectionMethod.getCriterion().toString());
            }
        }

        /**
         * Converts the pair to a {@link DataCell} array.
         *
         * @param resultAndConfidence The {@link Pair}.
         * @return The result and possibly the confidence (confidence first, missing if unknown).
         */
        private DataCell[] toCells(final Pair<DataCell, Double> resultAndConfidence) {
            if (!addConfidence) {
                return new DataCell[] { resultAndConfidence.getFirst() };
            }
            if (resultAndConfidence.getSecond() == null) {
                return new DataCell[] { DataType.getMissingCell(), resultAndConfidence.getFirst() };
            }
            return new DataCell[] { new DoubleCell(resultAndConfidence.getSecond()), resultAndConfidence.getFirst() };
        }

        /**
         * Computes the result and the confidence using the weighted sum method.
         *
         * @param row A {@link DataRow}
         * @return The result and the confidence.
         */
        private Pair<DataCell, Double> selectWeightedSum(final DataRow row) {
            final Map<String, Double> scoreToSumWeight = new LinkedHashMap<String, Double>();
            if (m_values != null) {
                for (String val : m_values) {
                    scoreToSumWeight.put(val, 0.0);
                }
            }
            int matchedRuleCount = 0;
            for (final PMMLRuleTranslator.Rule rule : rules) {
                if (Boolean.TRUE.equals(rule.getCondition().evaluate(row, spec))) {
                    ++matchedRuleCount;
                    Double sumWeight = scoreToSumWeight.get(rule.getOutcome());
                    if (sumWeight == null) {
                        throw new IllegalStateException("The score value: " + rule.getOutcome() + " is not in the data dictionary.");
                    }
                    // A missing weight contributes nothing to the sum.
                    final Double wRaw = rule.getWeight();
                    final double w = wRaw == null ? 0.0 : wRaw.doubleValue();
                    scoreToSumWeight.put(rule.getOutcome(), sumWeight + w);
                }
            }
            double maxSumWeight = Double.NEGATIVE_INFINITY;
            String bestScore = null;
            for (Entry<String, Double> entry : scoreToSumWeight.entrySet()) {
                final double d = entry.getValue().doubleValue();
                if (d > maxSumWeight) {
                    maxSumWeight = d;
                    bestScore = entry.getKey();
                }
            }
            if (bestScore == null || matchedRuleCount == 0) {
                return pair(result(defaultScore), defaultConfidence);
            }
            return pair(result(bestScore), maxSumWeight / matchedRuleCount);
        }

        /**
         * Helper method to create {@link Pair}s.
         *
         * @param f The first element.
         * @param s The second element.
         * @return The new pair.
         */
        private <F, S> Pair<F, S> pair(final F f, final S s) {
            return new Pair<F, S>(f, s);
        }

        /**
         * Computes the result and the confidence using the weighted max method.
         *
         * @param row A {@link DataRow}
         * @return The result and the confidence.
         */
        private Pair<DataCell, Double> selectWeightedMax(final DataRow row) {
            double maxWeight = Double.NEGATIVE_INFINITY;
            PMMLRuleTranslator.Rule bestRule = null;
            for (final PMMLRuleTranslator.Rule rule : rules) {
                if (Boolean.TRUE.equals(rule.getCondition().evaluate(row, spec))) {
                    // The weight is optional; treat a missing weight as 0.0 (same as in the weighted sum
                    // method) instead of failing with a NullPointerException on unboxing.
                    final Double wRaw = rule.getWeight();
                    final double w = wRaw == null ? 0.0 : wRaw.doubleValue();
                    if (w > maxWeight) {
                        maxWeight = w;
                        bestRule = rule;
                    }
                }
            }
            if (bestRule == null) {
                return pair(result(defaultScore), defaultConfidence);
            }
            bestRule.setRecordCount(bestRule.getRecordCount() + 1);
            DataCell result = result(bestRule);
            if (validationColumnIdx >= 0) {
                if (row.getCell(validationColumnIdx).equals(result)) {
                    bestRule.setNbCorrect(bestRule.getNbCorrect() + 1);
                }
            }
            Double confidence = bestRule.getConfidence();
            return pair(result, confidence == null ? defaultConfidence : confidence);
        }

        /**
         * Selects the outcome of the rule and converts it to the proper outcome type.
         *
         * @param rule A {@link Rule}.
         * @return The {@link DataCell} representing the result. (May be missing.)
         */
        private DataCell result(final PMMLRuleTranslator.Rule rule) {
            String outcome = rule.getOutcome();
            return result(outcome);
        }

        /**
         * Constructs the {@link DataCell} from its {@link String} representation ({@code outcome}) using the
         * type of the output column.
         *
         * @param outcome The {@link String} representation, may be {@code null}.
         * @return The {@link DataCell}; a missing cell if {@code outcome} is {@code null} or cannot be parsed
         *         as a number of the expected type.
         */
        private DataCell result(final String outcome) {
            if (outcome == null) {
                return DataType.getMissingCell();
            }
            try {
                if (dataType.isCompatible(BooleanValue.class)) {
                    return BooleanCellFactory.create(outcome);
                }
                if (IntCell.TYPE.isASuperTypeOf(dataType)) {
                    return new IntCell(Integer.parseInt(outcome));
                }
                if (LongCell.TYPE.isASuperTypeOf(dataType)) {
                    return new LongCell(Long.parseLong(outcome));
                }
                if (DoubleCell.TYPE.isASuperTypeOf(dataType)) {
                    return new DoubleCell(Double.parseDouble(outcome));
                }
                return new StringCell(outcome);
            } catch (NumberFormatException e) {
                return new MissingCell(outcome + "\n" + e.getMessage());
            }
        }

        /**
         * Selects the first rule that matches and computes the confidence and result for the {@code row}.
         *
         * @param row A {@link DataRow}.
         * @return The result and the confidence.
         */
        private Pair<DataCell, Double> selectFirstHit(final DataRow row) {
            for (final PMMLRuleTranslator.Rule rule : rules) {
                Boolean eval = rule.getCondition().evaluate(row, spec);
                if (Boolean.TRUE.equals(eval)) {
                    rule.setRecordCount(rule.getRecordCount() + 1);
                    DataCell result = result(rule);
                    if (validationColumnIdx >= 0) {
                        if (row.getCell(validationColumnIdx).equals(result)) {
                            rule.setNbCorrect(rule.getNbCorrect() + 1);
                        }
                    }
                    Double confidence = rule.getConfidence();
                    return pair(result, confidence == null ? defaultConfidence : confidence);
                }
            }
            return pair(result(defaultScore), defaultConfidence);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public void afterProcessing() {
            super.afterProcessing();
            obj.getPMMLValue();
            RuleSetModel ruleSet = translator.getOriginalRuleSetModel();
            assert rules.size() == ruleSet.getRuleSet().getSimpleRuleList().size() + ruleSet.getRuleSet().getCompoundRuleList().size();
            // The statistics are only written back when the model consists of simple rules exclusively,
            // because only then the i-th rule corresponds to the i-th xml SimpleRule.
            if (ruleSet.getRuleSet().getSimpleRuleList().size() == rules.size()) {
                for (int i = 0; i < rules.size(); ++i) {
                    Rule rule = rules.get(i);
                    final SimpleRule simpleRuleArray = ruleSet.getRuleSet().getSimpleRuleArray(i);
                    synchronized (simpleRuleArray) /*synchronized fixes AP-6766 */
                    {
                        simpleRuleArray.setRecordCount(rule.getRecordCount());
                        if (validationColumnIdx >= 0) {
                            simpleRuleArray.setNbCorrect(rule.getNbCorrect());
                        } else if (simpleRuleArray.isSetNbCorrect()) {
                            simpleRuleArray.unsetNbCorrect();
                        }
                    }
                }
            }
        }
    });
    if (replaceColumn) {
        ret.remove(outputColumnName);
        ret.move(ret.getColumnCount() - 1 - (addConfidence ? 1 : 0), oldColumnIndex);
    }
    return ret;
}
Also used : LinkedHashSet(java.util.LinkedHashSet) RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) DataColumnSpecCreator(org.knime.core.data.DataColumnSpecCreator) DoubleCell(org.knime.core.data.def.DoubleCell) Node(org.w3c.dom.Node) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) DataRow(org.knime.core.data.DataRow) IntCell(org.knime.core.data.def.IntCell) Entry(java.util.Map.Entry) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) ColumnRearranger(org.knime.core.data.container.ColumnRearranger) DataColumnSpec(org.knime.core.data.DataColumnSpec) BooleanValue(org.knime.core.data.BooleanValue) DataType(org.knime.core.data.DataType) SettingsModelBoolean(org.knime.core.node.defaultnodesettings.SettingsModelBoolean) Pair(org.knime.core.util.Pair) AbstractCellFactory(org.knime.core.data.container.AbstractCellFactory) DataColumnDomainCreator(org.knime.core.data.DataColumnDomainCreator) RuleSelectionMethod(org.dmg.pmml.RuleSelectionMethodDocument.RuleSelectionMethod) Rule(org.knime.base.node.rules.engine.pmml.PMMLRuleTranslator.Rule) LongCell(org.knime.core.data.def.LongCell) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) StringCell(org.knime.core.data.def.StringCell) MissingCell(org.knime.core.data.MissingCell) DataCell(org.knime.core.data.DataCell) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) Rule(org.knime.base.node.rules.engine.pmml.PMMLRuleTranslator.Rule) Map(java.util.Map) LinkedHashMap(java.util.LinkedHashMap)

Example 4 with SimpleRule

use of org.dmg.pmml.SimpleRuleDocument.SimpleRule in project knime-core by knime.

the class PMMLRuleTranslator method findFirst.

/**
 * Finds the first xml {@link SimpleRule} within the {@code rule} {@link CompoundRule}.
 * <p>
 * The children of {@code rule} are visited in document order; nested {@link CompoundRule}s are searched
 * recursively (depth first).
 *
 * @param rule A {@link CompoundRule}.
 * @return The first {@link SimpleRule} that should provide the outcome, or {@code null} (after a failed
 *         assertion when assertions are enabled) if {@code rule} contains no {@link SimpleRule} at all.
 */
private SimpleRule findFirst(final CompoundRule rule) {
    final XmlCursor newCursor = rule.newCursor();
    try {
        if (newCursor.toFirstChild()) {
            do {
                XmlObject object = newCursor.getObject();
                if (object instanceof SimpleRuleDocument.SimpleRule) {
                    SimpleRuleDocument.SimpleRule sr = (SimpleRuleDocument.SimpleRule) object;
                    return sr;
                }
                if (object instanceof CompoundRule) {
                    CompoundRule cp = (CompoundRule) object;
                    SimpleRule first = findFirst(cp);
                    if (first != null) {
                        return first;
                    }
                }
            } while (newCursor.toNextSibling());
        }
    } finally {
        // Cursors have to be released explicitly; this also covers the early returns above.
        newCursor.dispose();
    }
    assert false : rule;
    return null;
}
Also used : CompoundRule(org.dmg.pmml.CompoundRuleDocument.CompoundRule) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) XmlObject(org.apache.xmlbeans.XmlObject) SimpleRuleDocument(org.dmg.pmml.SimpleRuleDocument) XmlCursor(org.apache.xmlbeans.XmlCursor)

Example 5 with SimpleRule

use of org.dmg.pmml.SimpleRuleDocument.SimpleRule in project knime-core by knime.

the class PMMLRuleTranslator method addRules.

/**
 * Appends one xml {@link SimpleRule} to {@code ruleSet} for every entry of {@code rules}.
 * <p>
 * Score, condition and score distributions are always written; weight and confidence only when they are
 * present, and the statistics (number of correct records, record count) only when statistics are requested
 * and the respective value is a number.
 *
 * @param ruleSet An xml {@link RuleSet}.
 * @param rules The simplified {@link Rule}s to add.
 */
private void addRules(final RuleSet ruleSet, final List<Rule> rules) {
    for (final Rule source : rules) {
        final SimpleRule target = ruleSet.addNewSimpleRule();
        target.setScore(source.getOutcome());

        // Statistics are optional and NaN marks a value that is not available.
        if (m_provideStatistics) {
            final double nbCorrect = source.getNbCorrect();
            if (!Double.isNaN(nbCorrect)) {
                target.setNbCorrect(nbCorrect);
            }
            final double recordCount = source.getRecordCount();
            if (!Double.isNaN(recordCount)) {
                target.setRecordCount(recordCount);
            }
        }

        setPredicate(target, source.getCondition());

        final Double weight = source.getWeight();
        if (weight != null) {
            target.setWeight(weight);
        }
        final Double confidence = source.getConfidence();
        if (confidence != null) {
            target.setConfidence(confidence);
        }

        // One ScoreDistribution element per score value; the probability is optional.
        for (final Entry<String, ScoreProbabilityAndRecordCount> distribution : source.getScoreDistribution().entrySet()) {
            final ScoreProbabilityAndRecordCount stats = distribution.getValue();
            final ScoreDistribution xmlDistribution = target.addNewScoreDistribution();
            xmlDistribution.setValue(distribution.getKey());
            if (stats.getProbability() != null) {
                xmlDistribution.setProbability(stats.getProbability());
            }
            xmlDistribution.setRecordCount(stats.getRecordCount());
        }
    }
}
Also used : ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule) CompoundRule(org.dmg.pmml.CompoundRuleDocument.CompoundRule) SimpleRule(org.dmg.pmml.SimpleRuleDocument.SimpleRule)

Aggregations

SimpleRule (org.dmg.pmml.SimpleRuleDocument.SimpleRule)9 PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate)6 ArrayList (java.util.ArrayList)5 CompoundRule (org.dmg.pmml.CompoundRuleDocument.CompoundRule)5 RuleSetModel (org.dmg.pmml.RuleSetModelDocument.RuleSetModel)4 PMMLCompoundPredicate (org.knime.base.node.mine.decisiontree2.PMMLCompoundPredicate)4 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)4 Entry (java.util.Map.Entry)3 XmlCursor (org.apache.xmlbeans.XmlCursor)3 XmlObject (org.apache.xmlbeans.XmlObject)3 CompoundPredicate (org.dmg.pmml.CompoundPredicateDocument.CompoundPredicate)3 PMMLDocument (org.dmg.pmml.PMMLDocument)3 PMML (org.dmg.pmml.PMMLDocument.PMML)3 RuleSet (org.dmg.pmml.RuleSetDocument.RuleSet)3 ScoreDistribution (org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution)3 SimplePredicate (org.dmg.pmml.SimplePredicateDocument.SimplePredicate)3 PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate)3 DataCell (org.knime.core.data.DataCell)3 DataColumnSpecCreator (org.knime.core.data.DataColumnSpecCreator)3 DataRow (org.knime.core.data.DataRow)3