Use of org.dmg.pmml.CompoundPredicate in project drools by kiegroup.
From class KiePMMLCompoundPredicateASTFactoryTest, method declareRuleFromCompoundPredicateAndOrXorNotFinalLeaf.
@Test
public void declareRuleFromCompoundPredicateAndOrXorNotFinalLeaf() {
    // For every boolean operator except SURROGATE, build a compound predicate over the shared
    // simple predicates and check the single rule declared for a NON-final leaf.
    final Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    final List<SimplePredicate> simplePredicates = getSimplePredicates(fieldTypeMap);
    final String parentPath = "_will play";
    final String currentRule = "_will play_will play";
    final String result = "RESULT";
    for (CompoundPredicate.BooleanOperator booleanOperator : CompoundPredicate.BooleanOperator.values()) {
        // SURROGATE is not exercised by this test.
        if (booleanOperator == CompoundPredicate.BooleanOperator.SURROGATE) {
            continue;
        }
        final CompoundPredicate compound = new CompoundPredicate();
        compound.setBooleanOperator(booleanOperator);
        for (SimplePredicate simplePredicate : simplePredicates) {
            compound.addPredicates(simplePredicate);
        }

        final List<KiePMMLDroolsRule> generatedRules = new ArrayList<>();
        final PredicateASTFactoryData factoryData = getPredicateASTFactoryData(compound, Collections.emptyList(), generatedRules, parentPath, currentRule, fieldTypeMap);
        KiePMMLCompoundPredicateASTFactory.factory(factoryData).declareRuleFromCompoundPredicate(result, false);

        // Exactly one rule is expected; as a non-final leaf it sets its own name as the status.
        assertEquals(1, generatedRules.size());
        final KiePMMLDroolsRule rule = generatedRules.get(0);
        assertNotNull(rule);
        assertEquals(currentRule, rule.getName());
        assertEquals(currentRule, rule.getStatusToSet());
        assertEquals(String.format(STATUS_PATTERN, parentPath), rule.getStatusConstraint());

        switch (compound.getBooleanOperator()) {
            case AND:
                assertNotNull(rule.getAndConstraints());
                break;
            case OR:
                assertNotNull(rule.getOrConstraints());
                break;
            case XOR:
                assertNotNull(rule.getXorConstraints());
                break;
            default:
                // NOTE(review): NOT has no operator-specific constraint assertion here.
                break;
        }
    }
}
Use of org.dmg.pmml.CompoundPredicate in project drools by kiegroup.
From class KiePMMLCompoundPredicateASTFactoryTest, method declareRuleFromCompoundPredicateAndOrXorFinalLeaf.
@Test
public void declareRuleFromCompoundPredicateAndOrXorFinalLeaf() {
    // For every boolean operator except SURROGATE, build a compound predicate over the shared
    // simple predicates and check the single rule declared for a FINAL leaf.
    final Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    final List<SimplePredicate> simplePredicates = getSimplePredicates(fieldTypeMap);
    final String parentPath = "_will play";
    final String currentRule = "_will play_will play";
    final String result = "RESULT";
    for (CompoundPredicate.BooleanOperator booleanOperator : CompoundPredicate.BooleanOperator.values()) {
        // SURROGATE is not exercised by this test.
        if (booleanOperator == CompoundPredicate.BooleanOperator.SURROGATE) {
            continue;
        }
        final CompoundPredicate compound = new CompoundPredicate();
        compound.setBooleanOperator(booleanOperator);
        for (SimplePredicate simplePredicate : simplePredicates) {
            compound.addPredicates(simplePredicate);
        }

        final List<KiePMMLDroolsRule> generatedRules = new ArrayList<>();
        final PredicateASTFactoryData factoryData = getPredicateASTFactoryData(compound, Collections.emptyList(), generatedRules, parentPath, currentRule, fieldTypeMap);
        KiePMMLCompoundPredicateASTFactory.factory(factoryData).declareRuleFromCompoundPredicate(result, true);

        // Exactly one rule is expected; as a final leaf it sets DONE and carries the result.
        assertEquals(1, generatedRules.size());
        final KiePMMLDroolsRule rule = generatedRules.get(0);
        assertNotNull(rule);
        assertEquals(currentRule, rule.getName());
        assertEquals(DONE, rule.getStatusToSet());
        assertEquals(String.format(STATUS_PATTERN, parentPath), rule.getStatusConstraint());
        assertEquals(result, rule.getResult());
        assertEquals(ResultCode.OK, rule.getResultCode());

        switch (compound.getBooleanOperator()) {
            case AND:
                assertNotNull(rule.getAndConstraints());
                break;
            case OR:
                assertNotNull(rule.getOrConstraints());
                break;
            case XOR:
                assertNotNull(rule.getXorConstraints());
                break;
            default:
                // NOTE(review): NOT has no operator-specific constraint assertion here.
                break;
        }
    }
}
Use of org.dmg.pmml.CompoundPredicate in project jpmml-r by jpmml.
From class RPartConverter, method makeDefault.
/**
 * Marks {@code node} as the default (catch-all) child by appending {@link True#INSTANCE}
 * as the last alternative of a SURROGATE compound predicate.
 *
 * <p>If the node's predicate is already a SURROGATE compound predicate, it is extended in place.
 * Any other predicate is first wrapped in a new SURROGATE compound predicate, which replaces the
 * node's predicate. This includes AND/OR/XOR/NOT compounds, where appending {@code True}
 * directly would change their logic; for example, OR would become always true.
 *
 * <p>Under PMML SURROGATE semantics, the trailing {@code True} is reached only when every earlier
 * predicate evaluates to UNKNOWN.
 *
 * @param node the child node to turn into the default branch; must carry a predicate
 */
private void makeDefault(Node node) {
    Predicate predicate = node.requirePredicate();
    CompoundPredicate compoundPredicate;
    if (predicate instanceof CompoundPredicate
            && ((CompoundPredicate) predicate).getBooleanOperator() == CompoundPredicate.BooleanOperator.SURROGATE) {
        // Already a surrogate chain: just extend it.
        compoundPredicate = (CompoundPredicate) predicate;
    } else {
        // Wrap so that True acts purely as the final fallback, not as an operand of AND/OR/XOR/NOT.
        compoundPredicate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null).addPredicates(predicate);
        node.setPredicate(compoundPredicate);
    }
    compoundPredicate.addPredicates(True.INSTANCE);
}
Use of org.dmg.pmml.CompoundPredicate in project jpmml-r by jpmml.
From class RPartConverter, method encodeNode.
/**
 * Recursively encodes the rpart node identified by {@code rowName} (and its subtree) as a PMML node.
 *
 * <p>If the node has no split variable ({@code INDEX_LEAF}), a leaf node is produced. Otherwise
 * both children (row names {@code 2 * rowName} and {@code 2 * rowName + 1}) are encoded
 * recursively. When {@code useSurrogate > 0} and the split has surrogates, each child's predicate
 * becomes a SURROGATE compound predicate: the primary split predicate comes first, followed by one
 * predicate per surrogate split. When {@code useSurrogate == 2}, the child selected by
 * {@code Double.compare} on {@code n} is made the default via {@code makeDefault} and placed last.
 * No child is made the default on a tie.
 *
 * @param predicate the predicate guarding this node, as built by the caller
 * @param rowName the rpart row name of this node
 * @return the node after being passed through {@code scoreEncoder.encode}
 */
private Node encodeNode(Predicate predicate, int rowName, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, ScoreEncoder scoreEncoder, Schema schema) {
    int offset = getIndex(rowNames, rowName);
    Integer nodeId = Integer.valueOf(rowName);
    List<? extends Feature> features = schema.getFeatures();

    int splitVar = getFeatureIndex(var, offset, features);
    if (splitVar == RPartConverter.INDEX_LEAF) {
        Node leaf = new CountingLeafNode(null, predicate).setId(nodeId);
        return scoreEncoder.encode(leaf, offset);
    }

    // Child row names: left = 2k, right = 2k + 1.
    int leftRowName = rowName * 2;
    int rightRowName = leftRowName + 1;

    // Sign of the comparison between the two children's entries in n; only consulted when useSurrogate == 2.
    // NOTE(review): n presumably holds per-node observation counts - confirm against the caller.
    int majorityDir = 0;
    if (this.useSurrogate == 2) {
        int leftCountOffset = getIndex(rowNames, leftRowName);
        int rightCountOffset = getIndex(rowNames, rightRowName);
        majorityDir = Double.compare(n.getValue(leftCountOffset), n.getValue(rightCountOffset));
    }

    Feature splitFeature = features.get(splitVar - 1);
    int[] nodeSplitInfo = splitInfo[offset];
    int splitOffset = nodeSplitInfo[0];
    int splitNumCompete = nodeSplitInfo[1];
    int splitNumSurrogate = nodeSplitInfo[2];

    List<Predicate> primaryPredicates = encodePredicates(splitFeature, splitOffset, splits, csplit);
    Predicate leftPredicate = primaryPredicates.get(0);
    Predicate rightPredicate = primaryPredicates.get(1);

    boolean hasSurrogates = this.useSurrogate > 0 && splitNumSurrogate > 0;
    if (hasSurrogates) {
        CompoundPredicate leftChain = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null).addPredicates(leftPredicate);
        CompoundPredicate rightChain = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null).addPredicates(rightPredicate);
        RStringVector splitRowNames = splits.dimnames(0);
        // Surrogate split rows come right after the primary split row and its competitor rows.
        int firstSurrogateOffset = splitOffset + 1 + splitNumCompete;
        for (int k = 0; k < splitNumSurrogate; k++) {
            int surrogateOffset = firstSurrogateOffset + k;
            Feature surrogateFeature = getFeature(splitRowNames.getValue(surrogateOffset));
            List<Predicate> surrogatePredicates = encodePredicates(surrogateFeature, surrogateOffset, splits, csplit);
            leftChain.addPredicates(surrogatePredicates.get(0));
            rightChain.addPredicates(surrogatePredicates.get(1));
        }
        leftPredicate = leftChain;
        rightPredicate = rightChain;
    }

    Node leftChild = encodeNode(leftPredicate, leftRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);
    Node rightChild = encodeNode(rightPredicate, rightRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

    // The default child (if any) is always emitted second.
    Node firstChild = leftChild;
    Node secondChild = rightChild;
    if (this.useSurrogate == 2) {
        if (majorityDir > 0) {
            makeDefault(leftChild);
            firstChild = rightChild;
            secondChild = leftChild;
        } else if (majorityDir < 0) {
            makeDefault(rightChild);
        }
    }

    Node branch = new CountingBranchNode(null, predicate).setId(nodeId).addNodes(firstChild, secondChild);
    return scoreEncoder.encode(branch, offset);
}
Use of org.dmg.pmml.CompoundPredicate in project drools by kiegroup.
From class KiePMMLCompoundPredicateFactoryTest, method getCompoundPredicateVariableDeclaration.
@Test
public void getCompoundPredicateVariableDeclaration() throws IOException {
    // Compound predicate under test: AND(simple1, simple2, IS_IN over four string values).
    final String variableName = "variableName";
    final SimplePredicate firstSimple = getSimplePredicate(PARAM_1, value1, operator1);
    final SimplePredicate secondSimple = getSimplePredicate(PARAM_2, value2, operator2);
    final Array.Type arrayType = Array.Type.STRING;
    final List<String> setValues = getStringObjects(arrayType, 4);
    final SimpleSetPredicate setPredicate = getSimpleSetPredicate(setValues, arrayType, SimpleSetPredicate.BooleanOperator.IS_IN);
    final CompoundPredicate compoundPredicate = new CompoundPredicate();
    compoundPredicate.setBooleanOperator(CompoundPredicate.BooleanOperator.AND);
    compoundPredicate.addPredicates(firstSimple, secondSimple, setPredicate);

    // One DOUBLE data field per referenced field name, in predicate order.
    final DataField firstField = new DataField();
    firstField.setName(firstSimple.getField());
    firstField.setDataType(DataType.DOUBLE);
    final DataField secondField = new DataField();
    secondField.setName(secondSimple.getField());
    secondField.setDataType(DataType.DOUBLE);
    final DataField setField = new DataField();
    setField.setName(setPredicate.getField());
    setField.setDataType(DataType.DOUBLE);
    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(firstField, secondField, setField);

    // Pieces substituted into the expected-source template.
    final String expectedOperator = BOOLEAN_OPERATOR.class.getName() + "." + BOOLEAN_OPERATOR.byName(compoundPredicate.getBooleanOperator().value()).name();
    final String expectedValues = setValues.stream().map(setValue -> "\"" + setValue + "\"").collect(Collectors.joining(","));

    final List<Field<?>> fields = getFieldsFromDataDictionary(dataDictionary);
    final BlockStmt retrieved = KiePMMLCompoundPredicateFactory.getCompoundPredicateVariableDeclaration(variableName, compoundPredicate, fields);
    final String template = getFileContent(TEST_01_SOURCE);
    final Statement expected = JavaParserUtils.parseBlock(String.format(template, variableName, expectedValues, expectedOperator));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    // The generated block must also compile with these imports.
    final List<Class<?>> requiredImports = Arrays.asList(KiePMMLCompoundPredicate.class, KiePMMLSimplePredicate.class, KiePMMLSimpleSetPredicate.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, requiredImports);
}
Aggregations