Usage of org.dmg.pmml.SimpleSetPredicate in project shifu by ShifuML.
Class TreeNodePmmlElementCreator, method convert.
/**
 * Converts a Shifu tree node (and, recursively, its children) into a PMML tree node.
 *
 * <p>The predicate of the returned node is derived from {@code split}, the parent's split that
 * routes records to {@code node}:
 * <ul>
 *   <li>numerical column: a {@code SimplePredicate} on the split threshold, {@code lessThan} for
 *       the left child and {@code greaterOrEqual} for the right child;</li>
 *   <li>categorical column: a {@code SimpleSetPredicate} over the split's category set, using
 *       {@code isIn} when this child's side agrees with {@code split.isLeft()} and
 *       {@code isNotIn} otherwise;</li>
 *   <li>any other column type: the predicate is set to {@code null}.</li>
 * </ul>
 *
 * @param node   the Shifu tree node to convert
 * @param isLeft whether {@code node} is the left child of its parent
 * @param split  the parent's split leading to {@code node}
 * @return the PMML node; its children are converted too unless {@code node} has no split or
 *         {@code node.isRealLeaf()} is true
 */
public org.dmg.pmml.tree.Node convert(Node node, boolean isLeft, Split split) {
    org.dmg.pmml.tree.Node pmmlNode = new org.dmg.pmml.tree.Node();
    pmmlNode.setId(String.valueOf(node.getId()));
    if (node.getPredict() != null) {
        // NOTE(review): if getClassValue() and getPredict() have different numeric types, the
        // conditional expression promotes both to a common type before String.valueOf — confirm
        // the classification score renders as intended.
        pmmlNode.setScore(String.valueOf(treeModel.isClassification()
                ? node.getPredict().getClassValue()
                : node.getPredict().getPredict()));
    }
    pmmlNode.setDefaultChild(null);

    Predicate predicate = null;
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    if (columnConfig.isNumerical()) {
        SimplePredicate p = new SimplePredicate();
        p.setValue(String.valueOf(split.getThreshold()));
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        p.setOperator(isLeft ? SimplePredicate.Operator.LESS_THAN : SimplePredicate.Operator.GREATER_OR_EQUAL);
        predicate = p;
    } else if (columnConfig.isCategorical()) {
        SimpleSetPredicate p = new SimpleSetPredicate();
        Set<Short> childCategories = split.getLeftOrRightCategories();
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum());
        p.setArray(new Array(Array.Type.STRING, toPmmlStringArrayContent(childCategories, valueList)));
        // The category set is matched with isIn exactly when this child's side equals split.isLeft().
        boolean matchesCategorySide = (isLeft == split.isLeft());
        p.setBooleanOperator(matchesCategorySide
                ? SimpleSetPredicate.BooleanOperator.IS_IN
                : SimpleSetPredicate.BooleanOperator.IS_NOT_IN);
        predicate = p;
    }
    pmmlNode.setPredicate(predicate);

    if (node.getSplit() == null || node.isRealLeaf()) {
        return pmmlNode;
    }
    // Children are appended in left, right order.
    List<org.dmg.pmml.tree.Node> childList = pmmlNode.getNodes();
    childList.add(convert(node.getLeft(), true, node.getSplit()));
    childList.add(convert(node.getRight(), false, node.getSplit()));
    return pmmlNode;
}

/**
 * Builds the whitespace-separated content of a PMML string {@code Array} for the given category
 * indexes. An index at or beyond {@code valueList.size()} is emitted as the empty string
 * {@code ""}; every other index is replaced by its value, encoded by
 * {@link #quotePmmlArrayValue(String)}.
 *
 * @param categoryIndexes indexes into {@code valueList}, in set iteration order
 * @param valueList       category values of the column
 * @return the array content, or an empty string when {@code categoryIndexes} is empty
 */
private static String toPmmlStringArrayContent(Set<Short> categoryIndexes, List<String> valueList) {
    StringBuilder arrayStr = new StringBuilder();
    for (Short index : categoryIndexes) {
        if (arrayStr.length() > 0) {
            arrayStr.append(' ');
        }
        arrayStr.append(index >= valueList.size() ? "\"\"" : quotePmmlArrayValue(valueList.get(index)));
    }
    return arrayStr.toString();
}

/**
 * Encodes a single value for a PMML string {@code Array}. A value that is empty or contains
 * whitespace or a double quote is enclosed in double quotes, with each embedded double quote
 * escaped as {@code \"}; any other value is returned unchanged.
 *
 * <p>PMML splits array content on whitespace and only defines {@code \"} escaping inside a quoted
 * token, so an unquoted {@code ab\"c} or an unquoted empty value would not round-trip.
 *
 * @param value the raw category value
 * @return the value as a single PMML array token
 */
private static String quotePmmlArrayValue(String value) {
    boolean needsQuotes = value.isEmpty() || value.indexOf('"') >= 0;
    for (int i = 0; !needsQuotes && i < value.length(); i++) {
        needsQuotes = Character.isWhitespace(value.charAt(i));
    }
    if (!needsQuotes) {
        return value;
    }
    // Literal (non-regex) replacement: " -> \"
    return "\"" + value.replace("\"", "\\\"") + "\"";
}
Usage of org.dmg.pmml.SimpleSetPredicate in project drools by kiegroup.
Class PMMLModelTestUtils, method getSimpleSetPredicate.
/**
 * Creates a {@code SimpleSetPredicate} for use in tests.
 *
 * @param predicateName   name of the field the predicate refers to
 * @param arrayType       type of the predicate's value array
 * @param values          values placed in the array, built via {@code getArray}
 * @param booleanOperator set-membership operator of the predicate
 * @return the populated predicate
 */
public static SimpleSetPredicate getSimpleSetPredicate(final String predicateName, final Array.Type arrayType, final List<String> values, final SimpleSetPredicate.BooleanOperator booleanOperator) {
    final SimpleSetPredicate predicate = new SimpleSetPredicate();
    predicate.setField(FieldName.create(predicateName));
    predicate.setBooleanOperator(booleanOperator);
    predicate.setArray(getArray(arrayType, values));
    return predicate;
}
Usage of org.dmg.pmml.SimpleSetPredicate in project drools by kiegroup.
Class KiePMMLSimpleSetPredicateFactoryTest, method getSimpleSetPredicate.
/**
 * Creates a {@code SimpleSetPredicate} on the field named {@code SIMPLE_SET_PREDICATE_NAME}.
 *
 * @param values    values placed in the predicate's array, built via {@code getArray}
 * @param arrayType type of that array
 * @param inNotIn   set-membership operator of the predicate
 * @return the populated predicate
 */
public static SimpleSetPredicate getSimpleSetPredicate(List<String> values, final Array.Type arrayType, final SimpleSetPredicate.BooleanOperator inNotIn) {
    final Array valuesArray = getArray(arrayType, values);
    final FieldName fieldName = FieldName.create(SIMPLE_SET_PREDICATE_NAME);
    final SimpleSetPredicate predicate = new SimpleSetPredicate();
    predicate.setField(fieldName);
    predicate.setBooleanOperator(inNotIn);
    predicate.setArray(valuesArray);
    return predicate;
}
Usage of org.dmg.pmml.SimpleSetPredicate in project drools by kiegroup.
Class KiePMMLSimpleSetPredicateFactoryTest, method getSimpleSetPredicateVariableDeclaration.
/**
 * Verifies that {@code KiePMMLSimpleSetPredicateFactory.getSimpleSetPredicateVariableDeclaration}
 * generates the block described by the {@code TEST_01_SOURCE} template for a string
 * {@code IS_IN} predicate, and that the generated block compiles with the listed imports.
 */
@Test
public void getSimpleSetPredicateVariableDeclaration() throws IOException {
    String variableName = "variableName";
    Array.Type arrayType = Array.Type.STRING;
    List<String> values = getStringObjects(arrayType, 4);
    SimpleSetPredicate simpleSetPredicate = getSimpleSetPredicate(values, arrayType, SimpleSetPredicate.BooleanOperator.IS_IN);
    // Enum class name + constant name, formatted into the expected-source template below.
    String arrayTypeString = ARRAY_TYPE.class.getName() + "." + ARRAY_TYPE.byName(simpleSetPredicate.getArray().getType().value());
    String booleanOperatorString = IN_NOTIN.class.getName() + "." + IN_NOTIN.byName(simpleSetPredicate.getBooleanOperator().value());
    String valuesString = values.stream().map(valueString -> "\"" + valueString + "\"").collect(Collectors.joining(","));
    BlockStmt retrieved = KiePMMLSimpleSetPredicateFactory.getSimpleSetPredicateVariableDeclaration(variableName, simpleSetPredicate);
    String text = getFileContent(TEST_01_SOURCE);
    Statement expected = JavaParserUtils.parseBlock(String.format(text, variableName, simpleSetPredicate.getField().getValue(), arrayTypeString, booleanOperatorString, valuesString));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));
    List<Class<?>> imports = Arrays.asList(KiePMMLSimpleSetPredicate.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Usage of org.dmg.pmml.SimpleSetPredicate in project drools by kiegroup.
Class KiePMMLSimpleSetPredicateInstanceFactoryTest, method getKiePMMLSimpleSetPredicate.
/**
 * Converts a random PMML {@code SimpleSetPredicate} and checks the resulting
 * {@code KiePMMLSimpleSetPredicate} against its source.
 */
@Test
public void getKiePMMLSimpleSetPredicate() {
    final SimpleSetPredicate source = getRandomSimpleSetPredicate();
    final KiePMMLSimpleSetPredicate converted =
            KiePMMLSimpleSetPredicateInstanceFactory.getKiePMMLSimpleSetPredicate(source);
    commonVerifyKiePMMLSimpleSetPredicate(converted, source);
}
Aggregations