Use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in the knime-core project by KNIME.
Class TreeNodeNumericConditionTest, method testToPMML.
/**
 * Verifies {@link TreeNodeNumericCondition#toPMMLPredicate()} for both settings of the missing-value flag: with the
 * flag off a single {@link PMMLSimplePredicate} is expected; with the flag on an OR-{@link PMMLCompoundPredicate}
 * holding the comparison followed by an isMissing predicate is expected.
 *
 * @throws Exception if any step of the test fails unexpectedly
 */
@Test
public void testToPMML() throws Exception {
    final TestDataGenerator generator = new TestDataGenerator(new TreeEnsembleLearnerConfiguration(false));
    final TreeNumericColumnData column = generator.createNumericAttributeColumn("1,2,3,4,4,5,6,7", "testCol", 0);

    // Missing-value flag off: expect exactly one "<= 3" comparison.
    final PMMLPredicate rejecting =
        new TreeNodeNumericCondition(column.getMetaData(), 3, NumericOperator.LessThanOrEqual, false)
            .toPMMLPredicate();
    assertThat(rejecting, instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate lessOrEqual = (PMMLSimplePredicate)rejecting;
    assertEquals("Wrong attribute", column.getMetaData().getAttributeName(), lessOrEqual.getSplitAttribute());
    assertEquals("Wrong operator", PMMLOperator.LESS_OR_EQUAL, lessOrEqual.getOperator());
    assertEquals("Wrong threshold", Double.toString(3), lessOrEqual.getThreshold());

    // Missing-value flag on: expect OR("> 4.5", isMissing) in exactly that order.
    final PMMLPredicate accepting =
        new TreeNodeNumericCondition(column.getMetaData(), 4.5, NumericOperator.LargerThan, true)
            .toPMMLPredicate();
    assertThat(accepting, instanceOf(PMMLCompoundPredicate.class));
    final PMMLCompoundPredicate disjunction = (PMMLCompoundPredicate)accepting;
    assertEquals("Wrong boolean operator in compound.", PMMLBooleanOperator.OR, disjunction.getBooleanOperator());
    final List<PMMLPredicate> parts = disjunction.getPredicates();
    assertEquals("Wrong number of predicates in compound.", 2, parts.size());

    assertThat(parts.get(0), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate greaterThan = (PMMLSimplePredicate)parts.get(0);
    assertEquals("Wrong attribute", column.getMetaData().getAttributeName(), greaterThan.getSplitAttribute());
    assertEquals("Wrong operator", PMMLOperator.GREATER_THAN, greaterThan.getOperator());
    assertEquals("Wrong threshold", Double.toString(4.5), greaterThan.getThreshold());

    assertThat(parts.get(1), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate missingCheck = (PMMLSimplePredicate)parts.get(1);
    assertEquals("Should be isMissing", PMMLOperator.IS_MISSING, missingCheck.getOperator());
}
Use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in the knime-core project by KNIME.
Class TreeNodeNominalBinaryCondition, method toPMMLPredicate.
/**
 * {@inheritDoc}
 * <p>
 * Produces a {@link PMMLSimpleSetPredicate} (string array) over this condition's values, using the set operator of
 * {@code m_setLogic}. If this condition accepts missing values, that set predicate is returned wrapped in an OR
 * {@link PMMLCompoundPredicate} together with an isMissing predicate on the same attribute; otherwise the set
 * predicate is returned as is.
 */
@Override
public PMMLPredicate toPMMLPredicate() {
    final String attribute = getAttributeName();
    final PMMLSimpleSetPredicate valueSet = new PMMLSimpleSetPredicate(attribute, m_setLogic.getPmmlSetOperator());
    valueSet.setValues(Arrays.asList(getValues()));
    valueSet.setArrayType(PMMLArrayType.STRING);

    if (acceptsMissings()) {
        // The condition holds if the set test holds OR the attribute value is missing.
        final PMMLSimplePredicate isMissing = new PMMLSimplePredicate();
        isMissing.setSplitAttribute(attribute);
        isMissing.setOperator(PMMLOperator.IS_MISSING);

        final PMMLCompoundPredicate either = new PMMLCompoundPredicate(PMMLBooleanOperator.OR);
        either.addPredicate(valueSet);
        either.addPredicate(isMissing);
        return either;
    }
    return valueSet;
}
Use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in the knime-core project by KNIME.
Class RuleEngine2PortsNodeModel, method computeRearrangerWithPMML.
/**
 * Parses every rule row read from {@code rules} into a PMML RuleSetModel and builds a {@link ColumnRearranger} that
 * applies the resulting PMML rule set to tables with the given {@code spec}.
 * <p>
 * Line breaks inside a rule text are collapsed to single spaces and comment rows are skipped. When
 * {@code outcomeIdx} is non-negative, the outcome cell is appended to the rule text as {@code " => <outcome>"};
 * otherwise the rule text itself must contain the {@code =>} part. As a side effect, a copy of the generated PMML is
 * stored in {@code m_copy}.
 *
 * @param spec the spec of the table the rules are applied to
 * @param rules the rule rows
 * @param flowVars the flow variables available to the rule parser
 * @param ruleIdx column index of the rule text in {@code rules}
 * @param outcomeIdx column index of the outcome in {@code rules}, or negative if the outcome is part of the rule text
 * @param confidenceIdx column index of the rule confidence in {@code rules}, or negative if there is none
 * @param weightIdx column index of the rule weight in {@code rules}, or negative if there is none
 * @param validationIdx passed on unchanged to {@code PMMLRuleSetPredictorNodeModel.createRearranger}
 *            (presumably the validation column index - TODO confirm)
 * @param outputColumnName name of the output column
 * @return the rearranger applying the rule set, paired with the generated PMML port object
 * @throws InterruptedException propagated from polling {@code rules}
 * @throws InvalidSettingsException if a rule is missing, is not a string, cannot be parsed, or has a non-constant
 *             outcome
 */
private Pair<ColumnRearranger, PortObject> computeRearrangerWithPMML(final DataTableSpec spec, final RowInput rules, final Map<String, FlowVariable> flowVars, final int ruleIdx, final int outcomeIdx, final int confidenceIdx, final int weightIdx, final int validationIdx, final String outputColumnName) throws InterruptedException, InvalidSettingsException {
    final PMMLDocument doc = PMMLDocument.Factory.newInstance();
    final PMML pmmlObj = doc.addNewPMML();
    final RuleSetModel ruleSetModel = pmmlObj.addNewRuleSetModel();
    final RuleSet ruleSet = ruleSetModel.addNewRuleSet();
    final List<DataType> outcomeTypes = new ArrayList<>();
    final PMMLRuleParser parser = new PMMLRuleParser(spec, flowVars);
    int lineNo = 0;
    DataRow ruleRow;
    while ((ruleRow = rules.poll()) != null) {
        ++lineNo;
        final DataCell rule = ruleRow.getCell(ruleIdx);
        CheckUtils.checkSetting(!rule.isMissing(), "Missing rule in row: " + ruleRow.getKey());
        if (!(rule instanceof StringValue)) {
            throw new InvalidSettingsException(
                "Wrong type (" + rule.getType() + ") of rule: " + rule + "\nin row: " + ruleRow.getKey());
        }
        // Collapse line breaks to single spaces before parsing.
        String r = ((StringValue)rule).getStringValue().replaceAll("[\r\n]+", " ");
        if (RuleSupport.isComment(r)) {
            continue;
        }
        if (outcomeIdx >= 0) {
            r += " => " + m_settings.asStringFailForMissing(ruleRow.getCell(outcomeIdx));
        }
        final ParseState state = new ParseState(r);
        try {
            final PMMLPredicate condition = parser.parseBooleanExpression(state);
            final SimpleRule simpleRule = ruleSet.addNewSimpleRule();
            setCondition(simpleRule, condition);
            state.skipWS();
            state.consumeText("=>");
            state.skipWS();
            final Expression outcome = parser.parseOutcomeOperand(state, null);
            simpleRule.setScore(outcome.toString());
            applyRuleConfidence(simpleRule, ruleRow, confidenceIdx);
            applyRuleWeight(simpleRule, ruleRow, weightIdx);
            CheckUtils.checkSetting(outcome.isConstant(), "Outcome is not constant in line " + lineNo + " (" + ruleRow.getKey() + ") for rule: " + rule);
            outcomeTypes.add(outcome.getOutputType());
        } catch (ParseException e) {
            final ParseException error = Util.addContext(e, r, lineNo);
            throw new InvalidSettingsException("Wrong rule in line: " + ruleRow.getKey() + "\n" + error.getMessage(), error);
        }
    }

    // This rearranger is only used below to compute table specs; its cell factory returns no cells.
    final ColumnRearranger dummy = new ColumnRearranger(spec);
    if (!m_settings.isReplaceColumn()) {
        final DataType outputType = RuleEngineNodeModel.computeOutputType(outcomeTypes,
            computeOutcomeType(rules.getDataTableSpec()), true, m_settings.isDisallowLongOutputForCompatibility());
        dummy.append(new SingleCellFactory(new DataColumnSpecCreator(outputColumnName, outputType).createSpec()) {
            @Override
            public DataCell getCell(final DataRow row) {
                return null;
            }
        });
    }
    final DataTableSpec dummySpec = dummy.createSpec();

    final PMMLPortObject pmml = createPMMLPortObject(doc, ruleSetModel, ruleSet, parser, dummySpec);
    m_copy = copy(pmml);
    String predictionConfidenceColumn = m_settings.getPredictionConfidenceColumn();
    if (predictionConfidenceColumn == null || predictionConfidenceColumn.isEmpty()) {
        predictionConfidenceColumn = RuleEngine2PortsSettings.DEFAULT_PREDICTION_CONFIDENCE_COLUMN;
    }
    final ColumnRearranger rearranger = PMMLRuleSetPredictorNodeModel.createRearranger(pmml, spec,
        m_settings.isReplaceColumn(), outputColumnName, m_settings.isComputeConfidence(),
        DataTableSpec.getUniqueColumnName(dummySpec, predictionConfidenceColumn), validationIdx);
    // Typed as PortObject so the returned Pair matches the declared return type.
    final PortObject pmmlPort = pmml;
    return Pair.create(rearranger, pmmlPort);
}

/**
 * Copies the confidence of a rule row onto the PMML rule when a confidence column is configured
 * ({@code confidenceIdx >= 0}). Missing or non-numeric confidence cells are ignored; the rule's confidence is then
 * not set by this method.
 */
private static void applyRuleConfidence(final SimpleRule simpleRule, final DataRow ruleRow, final int confidenceIdx) {
    if (confidenceIdx < 0) {
        return;
    }
    final DataCell confidenceCell = ruleRow.getCell(confidenceIdx);
    if (!confidenceCell.isMissing() && confidenceCell instanceof DoubleValue) {
        simpleRule.setConfidence(((DoubleValue)confidenceCell).getDoubleValue());
    }
}

/**
 * Copies the weight of a rule row onto the PMML rule when a weight column is configured ({@code weightIdx >= 0}).
 * If the weight cell is missing or non-numeric, the configured default weight is used, provided one is set in the
 * settings.
 */
private void applyRuleWeight(final SimpleRule simpleRule, final DataRow ruleRow, final int weightIdx) {
    if (weightIdx < 0) {
        return;
    }
    final DataCell weightCell = ruleRow.getCell(weightIdx);
    if (!weightCell.isMissing() && weightCell instanceof DoubleValue) {
        simpleRule.setWeight(((DoubleValue)weightCell).getDoubleValue());
    } else if (m_settings.isHasDefaultWeight()) {
        simpleRule.setWeight(m_settings.getDefaultWeight());
    }
}
Use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in the knime-core project by KNIME.
Class PMMLExpressionFactory, method in.
/**
 * {@inheritDoc}
 * <p>
 * Translates {@code left IN right} into a {@link PMMLSimpleSetPredicate} with {@link PMMLSetOperator#IS_IN}. Both
 * operands are validated before the referenced column is recorded in {@code m_usedColumns}, so a rejected
 * expression does not leave a stale entry in the used-column bookkeeping. The array type of the set (INT, REAL or
 * STRING) is chosen from the common output type of the constants in {@code right}.
 *
 * @throws UnsupportedOperationException if {@code left} is not a column reference or {@code right} is not constant
 */
@Override
public PMMLPredicate in(final Expression left, final Expression right) {
    if (left.getTreeType() != ASTType.ColRef) {
        throw new UnsupportedOperationException("PMML 4.1 supports only columns before IN.");
    }
    if (!right.isConstant()) {
        throw new UnsupportedOperationException("PMML 4.1 supports only constants in arguments.");
    }
    final String column = expressionToString(left);
    m_usedColumns.add(column);
    final PMMLSimpleSetPredicate setIn = new PMMLSimpleSetPredicate(column, PMMLSetOperator.IS_IN);

    final List<Expression> children = right.getChildren();
    final List<String> values = new ArrayList<String>(children.size());
    final List<DataType> types = new ArrayList<DataType>(children.size());
    for (final Expression child : children) {
        values.add(expressionToString(child));
        types.add(child.getOutputType());
    }
    final DataType outputType = RuleEngineNodeModel.computeOutputType(types, false);
    if (outputType.isCompatible(IntValue.class)) {
        setIn.setArrayType(PMMLArrayType.INT);
    } else if (outputType.isCompatible(DoubleValue.class)) {
        setIn.setArrayType(PMMLArrayType.REAL);
    } else {
        setIn.setArrayType(PMMLArrayType.STRING);
    }
    setIn.setValues(values);
    return setIn;
}
Use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in the knime-core project by KNIME.
Class PMMLExpressionFactory, method createConnective.
/**
 * Creates a "predicate" from logical connectives.
 * <p>
 * The given predicates are copied into a new list before being handed to the compound predicate, so later changes
 * to {@code boolExpressions} do not affect the returned predicate.
 *
 * @param boolExpressions The expressions to combine.
 * @param op The operator.
 * @return The {@link PMMLCompoundPredicate} combining the arguments with {@code op}.
 */
private PMMLPredicate createConnective(final List<PMMLPredicate> boolExpressions, final PMMLBooleanOperator op) {
    final PMMLCompoundPredicate ret = new PMMLCompoundPredicate(op);
    ret.setPredicates(new LinkedList<>(boolExpressions));
    return ret;
}
Aggregations