use of org.dmg.pmml.RuleSetDocument.RuleSet in project knime-core by knime.
the class FromDecisionTreeNodeModel method addRules.
/**
 * Adds the rules to {@code rs} (recursively on each leaf).
 * <p>
 * For a leaf a new {@link SimpleRule} is created. Its condition is an {@code AND} compound predicate that
 * collects the split predicates found while walking from the leaf up to the root. Depending on the settings
 * held in {@code m_rulesToTable}, score distributions (optionally with probabilities) and the record
 * statistics of the leaf are written as well. For a non-leaf node the method descends into every child.
 * <p>
 * A tree that consists of a single leaf (the root has no parent) yields a rule whose condition contains
 * only {@code True} predicates.
 *
 * @param rs The output {@link RuleSet}.
 * @param parents The parent stack.
 * @param node The actual node.
 */
private void addRules(final RuleSet rs, final List<DecisionTreeNode> parents, final DecisionTreeNode node) {
    if (node.isLeaf()) {
        SimpleRule rule = rs.addNewSimpleRule();
        if (m_rulesToTable.getScorePmmlRecordCount().getBooleanValue()) {
            // This increases the PMML quite significantly
            BigDecimal sum = BigDecimal.ZERO;
            final MathContext mc = new MathContext(7, RoundingMode.HALF_EVEN);
            final boolean computeProbability = m_rulesToTable.getScorePmmlProbability().getBooleanValue();
            if (computeProbability) {
                // Total of all class counts; used as the denominator of the per-class probabilities.
                sum = new BigDecimal(
                    node.getClassCounts().entrySet().stream().mapToDouble(e -> e.getValue().doubleValue()).sum(), mc);
            }
            for (final Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
                final ScoreDistribution scoreDistrib = rule.addNewScoreDistribution();
                scoreDistrib.setValue(entry.getKey().toString());
                scoreDistrib.setRecordCount(entry.getValue());
                if (computeProbability) {
                    if (Double.compare(entry.getValue().doubleValue(), 0.0) == 0) {
                        // A zero count is written directly; this also avoids dividing by a zero sum
                        // when every class count is zero.
                        scoreDistrib.setProbability(BigDecimal.ZERO);
                    } else {
                        scoreDistrib.setProbability(
                            new BigDecimal(entry.getValue().doubleValue(), mc).divide(sum, mc));
                    }
                }
            }
        }
        CompoundPredicate and = rule.addNewCompoundPredicate();
        and.setBooleanOperator(BooleanOperator.AND);
        // Walk from the leaf towards the root and copy the predicate of each traversed branch.
        // The parent is checked BEFORE it is dereferenced, so a leaf without parent (single-node tree)
        // contributes no predicate instead of failing with a NullPointerException.
        for (DecisionTreeNode n = node; n.getParent() != null; n = n.getParent()) {
            PMMLPredicate pmmlPredicate =
                ((DecisionTreeNodeSplitPMML) n.getParent()).getSplitPred()[n.getParent().getIndex(n)];
            if (pmmlPredicate instanceof PMMLSimplePredicate) {
                PMMLSimplePredicate simple = (PMMLSimplePredicate) pmmlPredicate;
                SimplePredicate predicate = and.addNewSimplePredicate();
                copy(predicate, simple);
            } else if (pmmlPredicate instanceof PMMLCompoundPredicate) {
                PMMLCompoundPredicate compound = (PMMLCompoundPredicate) pmmlPredicate;
                CompoundPredicate predicate = and.addNewCompoundPredicate();
                copy(predicate, compound);
            } else if (pmmlPredicate instanceof PMMLSimpleSetPredicate) {
                PMMLSimpleSetPredicate simpleSet = (PMMLSimpleSetPredicate) pmmlPredicate;
                copy(and.addNewSimpleSetPredicate(), simpleSet);
            } else if (pmmlPredicate instanceof PMMLTruePredicate) {
                and.addNewTrue();
            } else if (pmmlPredicate instanceof PMMLFalsePredicate) {
                and.addNewFalse();
            }
            // NOTE(review): any other predicate type is silently skipped - confirm this is intended.
        }
        // Simple fix for the case when fewer than two conditions were collected:
        // pad the conjunction with True predicates until it has at least two operands.
        while (and.getFalseList().size() + and.getCompoundPredicateList().size()
            + and.getSimplePredicateList().size() + and.getSimpleSetPredicateList().size()
            + and.getTrueList().size() < 2) {
            and.addNewTrue();
        }
        if (m_rulesToTable.getProvideStatistics().getBooleanValue()) {
            rule.setNbCorrect(node.getOwnClassCount());
            rule.setRecordCount(node.getEntireClassCount());
        }
        rule.setScore(node.getMajorityClass().toString());
    } else {
        parents.add(node);
        for (int i = 0; i < node.getChildCount(); ++i) {
            addRules(rs, parents, node.getChildAt(i));
        }
        parents.remove(node);
    }
}
use of org.dmg.pmml.RuleSetDocument.RuleSet in project knime-core by knime.
the class FromDecisionTreeNodeModel method execute.
/**
 * {@inheritDoc}
 * <p>
 * Reads the decision tree from the PMML port object at {@code inData[0]}, builds a new PMML document
 * (version {@code PMMLPortObject.PMML_V4_2}) containing a classification RuleSet model with the
 * {@code FIRST_HIT} rule selection criterion, and fills it with one rule per leaf of the tree.
 * The result array holds the new PMML port object followed by the object produced by
 * {@code RuleSetToTable} for that PMML.
 *
 * @throws CanceledExecutionException Execution cancelled.
 * @throws InvalidSettingsException No or more than one RuleSet model is in the PMML input.
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec)
    throws CanceledExecutionException, InvalidSettingsException {
    // Extract the decision tree from the incoming PMML.
    final PMMLPortObject decTreeModel = (PMMLPortObject) inData[0];
    final PMMLDecisionTreeTranslator treeTranslator = new PMMLDecisionTreeTranslator();
    decTreeModel.initializeModelTranslator(treeTranslator);
    final DecisionTree decisionTree = treeTranslator.getDecisionTree();

    // Only used as the source of the spec for the resulting port object.
    final PMMLPortObject specHolder = new PMMLPortObject(decTreeModel.getSpec());

    // Set up the skeleton of the output document: header, version, data dictionary.
    final PMMLDocument document = PMMLDocument.Factory.newInstance();
    final PMML pmml = document.addNewPMML();
    PMMLPortObjectSpec.writeHeader(pmml);
    pmml.setVersion(PMMLPortObject.PMML_V4_2);
    new PMMLDataDictionaryTranslator().exportTo(document, decTreeModel.getSpec());

    // Create the RuleSet model and its (still empty) rule set.
    final RuleSetModel ruleSetModel = pmml.addNewRuleSetModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(decTreeModel.getSpec(), ruleSetModel);
    ruleSetModel.setFunctionName(MININGFUNCTION.CLASSIFICATION);
    ruleSetModel.setAlgorithmName("RuleSet");
    final RuleSet ruleSet = ruleSetModel.addNewRuleSet();
    ruleSet.addNewRuleSelectionMethod().setCriterion(Criterion.FIRST_HIT);

    // Convert the tree, starting at the root with an empty parent stack.
    addRules(ruleSet, new ArrayList<DecisionTreeNode>(), decisionTree.getRootNode());

    final PMMLPortObject pmmlPortObject = new PMMLPortObject(specHolder.getSpec(), document);
    return new PortObject[]{pmmlPortObject, new RuleSetToTable(m_rulesToTable).execute(exec, pmmlPortObject)};
}
use of org.dmg.pmml.RuleSetDocument.RuleSet in project knime-core by knime.
the class PMMLRuleEditorNodeModel method createStreamableOperator.
/**
 * {@inheritDoc}
 * <p>
 * The column rearranger is built once from the first input spec ({@code inSpecs[0]}) when the operator is
 * created; a {@link ParseException} raised while doing so is rethrown as {@link InvalidSettingsException}.
 * The returned operator streams the data through that rearranger.
 */
@Override
public StreamableOperator createStreamableOperator(final PartitionInfo partitionInfo,
    final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final DataTableSpec tableSpec = (DataTableSpec) inSpecs[0];

    // Build the rearranger up front against a fresh, empty PMML RuleSet.
    final ColumnRearranger rearranger;
    try {
        final PMMLDocument pmmlDocument = PMMLDocument.Factory.newInstance();
        final PMML pmmlRoot = pmmlDocument.addNewPMML();
        final RuleSetModel model = pmmlRoot.addNewRuleSetModel();
        final RuleSet rules = model.addNewRuleSet();
        final PMMLRuleParser ruleParser = new PMMLRuleParser(tableSpec, getAvailableInputFlowVariables());
        rearranger = createRearranger(tableSpec, rules, ruleParser);
    } catch (ParseException e) {
        throw new InvalidSettingsException(e);
    }

    return new StreamableOperator() {
        /** Port object exchanged through the operator internals; only assigned in {@code loadInternals}. */
        private PMMLPortObject m_pmmlObject;

        @Override
        public void runFinal(final PortInput[] inputs, final PortOutput[] outputs, final ExecutionContext exec)
            throws Exception {
            rearranger.createStreamableFunction(0, 0).runFinal(inputs, outputs, exec);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public void loadInternals(final StreamableOperatorInternals internals) {
            super.loadInternals(internals);
            final StreamInternalForPMMLPortObject pmmlInternals = (StreamInternalForPMMLPortObject) internals;
            m_pmmlObject = pmmlInternals.getObject();
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public StreamableOperatorInternals saveInternals() {
            return createInitialStreamableOperatorInternals().setObject(m_pmmlObject);
        }
    };
}
use of org.dmg.pmml.RuleSetDocument.RuleSet in project knime-core by knime.
the class PMMLRuleTranslator method exportTo.
/**
 * {@inheritDoc}
 * <p>
 * Appends a new classification RuleSet model (model name {@code "RuleSet"}) to the PMML element of
 * {@code pmmlDoc}, writes its mining schema from {@code spec} and adds the rules held in {@code m_rules}.
 * The rule selection criterion is taken from the first selection method of the original rule model when
 * one is available, otherwise {@code FIRST_HIT} is used. Record count, number of correct records and
 * default confidence are only written when they are not {@code NaN}; the default score only when it is
 * not {@code null}.
 *
 * @return {@code RuleSetModel.type}
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);

    final PMML root = pmmlDoc.getPMML();
    final RuleSetModel model = root.addNewRuleSetModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, model);
    model.setModelName("RuleSet");
    model.setFunctionName(MININGFUNCTION.CLASSIFICATION);

    final RuleSet target = model.addNewRuleSet();
    final RuleSelectionMethod selection = target.addNewRuleSelectionMethod();

    // Look up the selection methods of the original model, if there is one.
    RuleSet originalRuleSet = null;
    if (m_originalRuleModel != null) {
        originalRuleSet = m_originalRuleModel.getRuleSet();
    }
    List<RuleSelectionMethod> originalMethods = Collections.<RuleSelectionMethod>emptyList();
    if (originalRuleSet != null) {
        originalMethods = originalRuleSet.getRuleSelectionMethodList();
    }
    if (originalMethods.isEmpty()) {
        selection.setCriterion(Criterion.FIRST_HIT);
    } else {
        selection.setCriterion(originalMethods.get(0).getCriterion());
    }

    // Optional attributes: NaN (respectively null) marks "not set".
    if (!Double.isNaN(m_recordCount)) {
        target.setRecordCount(m_recordCount);
    }
    if (!Double.isNaN(m_nbCorrect)) {
        target.setNbCorrect(m_nbCorrect);
    }
    if (!Double.isNaN(m_defaultConfidence)) {
        target.setDefaultConfidence(m_defaultConfidence);
    }
    if (m_defaultScore != null) {
        target.setDefaultScore(m_defaultScore);
    }

    // NOTE(review): the created mapper is discarded; kept only in case the constructor has side effects -
    // confirm whether this was meant to refresh m_nameMapper or can be removed.
    new DerivedFieldMapper(pmmlDoc);

    addRules(target, m_rules);
    return RuleSetModel.type;
}
use of org.dmg.pmml.RuleSetDocument.RuleSet in project knime-core by knime.
the class RuleEngine2PortsNodeModel method computeRearrangerWithPMML.
/**
 * Reads all rules from the {@code rules} input, converts them into a PMML RuleSet and creates the
 * {@link ColumnRearranger} that applies this RuleSet.
 * <p>
 * Each row is expected to hold the rule text as a string cell; line breaks inside the text are replaced by
 * a single space and comment lines are skipped. The created PMML port object is additionally stored
 * (as a copy) in {@code m_copy}.
 *
 * @param spec the table spec handed to the rule parser and to the rearranger
 * @param rules the input providing the rules, polled until it returns {@code null}
 * @param flowVars the flow variables handed to the rule parser
 * @param ruleIdx index of the cell holding the rule text
 * @param outcomeIdx index of the cell holding the outcome that is appended to the rule text after
 *            {@code " => "}; a negative value means no such cell is used
 * @param confidenceIdx index of the cell holding the rule confidence; a negative value means no such cell
 *            is used, missing or non-numeric cells are ignored
 * @param weightIdx index of the cell holding the rule weight; a negative value means no such cell is used.
 *            For missing or non-numeric cells the default weight from the settings is applied if one is
 *            configured
 * @param validationIdx passed on unchanged to {@code PMMLRuleSetPredictorNodeModel.createRearranger}
 *            (presumably the index of a validation column - confirm against that method)
 * @param outputColumnName name of the output column
 * @return the rearranger together with the created PMML port object
 * @throws InterruptedException declared for reading the rules input (presumably raised by {@code poll()})
 * @throws InvalidSettingsException if a rule cell is missing or not a string, a rule cannot be parsed, or
 *             an outcome is not constant
 */
private Pair<ColumnRearranger, PortObject> computeRearrangerWithPMML(final DataTableSpec spec,
    final RowInput rules, final Map<String, FlowVariable> flowVars, final int ruleIdx, final int outcomeIdx,
    final int confidenceIdx, final int weightIdx, final int validationIdx, final String outputColumnName)
    throws InterruptedException, InvalidSettingsException {
    final PMMLDocument pmmlDocument = PMMLDocument.Factory.newInstance();
    final PMML pmmlRoot = pmmlDocument.addNewPMML();
    final RuleSetModel ruleSetModel = pmmlRoot.addNewRuleSetModel();
    final RuleSet ruleSet = ruleSetModel.addNewRuleSet();
    final List<DataType> outcomeTypes = new ArrayList<>();
    final PMMLRuleParser parser = new PMMLRuleParser(spec, flowVars);

    int lineNo = 0;
    for (DataRow ruleRow = rules.poll(); ruleRow != null; ruleRow = rules.poll()) {
        // Comment rows are counted as well, so the number matches the row position in the input.
        ++lineNo;
        final DataCell rule = ruleRow.getCell(ruleIdx);
        CheckUtils.checkSetting(!rule.isMissing(), "Missing rule in row: " + ruleRow.getKey());
        if (!(rule instanceof StringValue)) {
            CheckUtils.checkSetting(false,
                "Wrong type (" + rule.getType() + ") of rule: " + rule + "\nin row: " + ruleRow.getKey());
            continue;
        }

        final String condensed = ((StringValue) rule).getStringValue().replaceAll("[\r\n]+", " ");
        if (RuleSupport.isComment(condensed)) {
            continue;
        }
        final String ruleText;
        if (outcomeIdx >= 0) {
            // The outcome lives in its own cell: glue it to the condition to get a complete rule.
            ruleText = condensed + " => " + m_settings.asStringFailForMissing(ruleRow.getCell(outcomeIdx));
        } else {
            ruleText = condensed;
        }

        final ParseState state = new ParseState(ruleText);
        try {
            final PMMLPredicate condition = parser.parseBooleanExpression(state);
            final SimpleRule simpleRule = ruleSet.addNewSimpleRule();
            setCondition(simpleRule, condition);
            state.skipWS();
            state.consumeText("=>");
            state.skipWS();
            final Expression outcome = parser.parseOutcomeOperand(state, null);
            simpleRule.setScore(outcome.toString());

            if (confidenceIdx >= 0) {
                final DataCell confidenceCell = ruleRow.getCell(confidenceIdx);
                if (!confidenceCell.isMissing() && confidenceCell instanceof DoubleValue) {
                    simpleRule.setConfidence(((DoubleValue) confidenceCell).getDoubleValue());
                }
            }

            if (weightIdx >= 0) {
                final DataCell weightCell = ruleRow.getCell(weightIdx);
                if (!weightCell.isMissing() && weightCell instanceof DoubleValue) {
                    simpleRule.setWeight(((DoubleValue) weightCell).getDoubleValue());
                } else if (m_settings.isHasDefaultWeight()) {
                    // No usable weight in the row: fall back to the configured default.
                    simpleRule.setWeight(m_settings.getDefaultWeight());
                }
            }

            CheckUtils.checkSetting(outcome.isConstant(),
                "Outcome is not constant in line " + lineNo + " (" + ruleRow.getKey() + ") for rule: " + rule);
            outcomeTypes.add(outcome.getOutputType());
        } catch (ParseException e) {
            final ParseException error = Util.addContext(e, ruleText, lineNo);
            throw new InvalidSettingsException(
                "Wrong rule in line: " + ruleRow.getKey() + "\n" + error.getMessage(), error);
        }
    }

    // This rearranger is only used to derive specs; its cell factory never produces real values.
    final ColumnRearranger dummy = new ColumnRearranger(spec);
    if (!m_settings.isReplaceColumn()) {
        final DataType outputType = RuleEngineNodeModel.computeOutputType(outcomeTypes,
            computeOutcomeType(rules.getDataTableSpec()), true, m_settings.isDisallowLongOutputForCompatibility());
        dummy.append(new SingleCellFactory(new DataColumnSpecCreator(outputColumnName, outputType).createSpec()) {
            @Override
            public DataCell getCell(final DataRow row) {
                return null;
            }
        });
    }

    final PMMLPortObject pmml = createPMMLPortObject(pmmlDocument, ruleSetModel, ruleSet, parser, dummy.createSpec());
    final PortObject portObject = pmml;
    m_copy = copy(pmml);

    // Fall back to the default name when no confidence column name is configured.
    final String configuredName = m_settings.getPredictionConfidenceColumn();
    final String confidenceColumn = (configuredName == null || configuredName.isEmpty())
        ? RuleEngine2PortsSettings.DEFAULT_PREDICTION_CONFIDENCE_COLUMN : configuredName;

    final ColumnRearranger rearranger = PMMLRuleSetPredictorNodeModel.createRearranger(pmml, spec,
        m_settings.isReplaceColumn(), outputColumnName, m_settings.isComputeConfidence(),
        DataTableSpec.getUniqueColumnName(dummy.createSpec(), confidenceColumn), validationIdx);
    return Pair.create(rearranger, portObject);
}
Aggregations