Search in sources:

Example 1 with SimpleSetPredicate

Use of org.dmg.pmml.SimpleSetPredicate in project shifu by ShifuML.

From class TreeNodePmmlElementCreator, method convert.

/**
 * Recursively converts a Shifu decision-tree node (and its subtree) into a PMML tree node.
 *
 * <p>The predicate of the produced node is derived from {@code split}, the split of the parent
 * node that routes records into {@code node}: a {@link SimplePredicate} for numerical columns
 * and a {@link SimpleSetPredicate} for categorical columns. If the column is neither, the
 * predicate is left {@code null}.
 *
 * @param node   the Shifu node to convert
 * @param isLeft {@code true} if {@code node} is the left child of its parent, {@code false} if it
 *               is the right child
 * @param split  the parent's split that leads to {@code node}
 * @return the PMML node, with its converted children attached when {@code node} is an inner node
 */
public org.dmg.pmml.tree.Node convert(Node node, boolean isLeft, Split split) {
    org.dmg.pmml.tree.Node pmmlNode = new org.dmg.pmml.tree.Node();
    pmmlNode.setId(String.valueOf(node.getId()));
    if (node.getPredict() != null) {
        pmmlNode.setScore(String.valueOf(treeModel.isClassification() ? node.getPredict().getClassValue() : node.getPredict().getPredict()));
    }
    pmmlNode.setDefaultChild(null);
    Predicate predicate = null;
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    if (columnConfig.isNumerical()) {
        SimplePredicate p = new SimplePredicate();
        p.setValue(String.valueOf(split.getThreshold()));
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        // Left child: value < threshold; right child: value >= threshold.
        p.setOperator(isLeft ? SimplePredicate.Operator.LESS_THAN : SimplePredicate.Operator.GREATER_OR_EQUAL);
        predicate = p;
    } else if (columnConfig.isCategorical()) {
        SimpleSetPredicate p = new SimpleSetPredicate();
        Set<Short> childCategories = split.getLeftOrRightCategories();
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum());
        StringBuilder arrayStr = new StringBuilder();
        for (Short sh : childCategories) {
            if (arrayStr.length() > 0) {
                arrayStr.append(' ');
            }
            int index = sh;
            if (index < 0 || index >= valueList.size()) {
                // Unknown category index: emit an explicit empty-string token.
                arrayStr.append("\"\"");
            } else {
                arrayStr.append(toPmmlStringArrayToken(valueList.get(index)));
            }
        }
        p.setArray(new Array(Array.Type.STRING, arrayStr.toString()));
        // The category set belongs to the left child when split.isLeft() is true, and to the
        // right child otherwise. This child therefore matches "isIn" exactly when it is on the
        // same side as the set, and "isNotIn" otherwise.
        p.setBooleanOperator(isLeft == split.isLeft() ? SimpleSetPredicate.BooleanOperator.IS_IN : SimpleSetPredicate.BooleanOperator.IS_NOT_IN);
        predicate = p;
    }
    pmmlNode.setPredicate(predicate);
    if (node.getSplit() == null || node.isRealLeaf()) {
        return pmmlNode;
    }
    List<org.dmg.pmml.tree.Node> childList = pmmlNode.getNodes();
    org.dmg.pmml.tree.Node left = convert(node.getLeft(), true, node.getSplit());
    org.dmg.pmml.tree.Node right = convert(node.getRight(), false, node.getSplit());
    childList.add(left);
    childList.add(right);
    return pmmlNode;
}

/**
 * Formats one value as a token of a whitespace-separated PMML string {@code Array}.
 *
 * <p>A value is wrapped in double quotes when it is empty, contains whitespace, or contains a
 * double quote; embedded double quotes are escaped as {@code \"}. Unquoted empty values would
 * vanish from the array, and an unquoted {@code "} would be read as the start of a quoted token.
 *
 * @param value the raw category value
 * @return the token to place in the array content
 */
private static String toPmmlStringArrayToken(String value) {
    boolean needsQuotes = value.isEmpty();
    for (int i = 0; i < value.length() && !needsQuotes; i++) {
        char c = value.charAt(i);
        needsQuotes = c == '"' || Character.isWhitespace(c);
    }
    if (!needsQuotes) {
        return value;
    }
    return "\"" + value.replace("\"", "\\\"") + "\"";
}
Also used : Set(java.util.Set) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) Node(ml.shifu.shifu.core.dtrain.dt.Node) SimplePredicate(org.dmg.pmml.SimplePredicate) Predicate(org.dmg.pmml.Predicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) Array(org.dmg.pmml.Array) List(java.util.List) FieldName(org.dmg.pmml.FieldName)

Example 2 with SimpleSetPredicate

Use of org.dmg.pmml.SimpleSetPredicate in project drools by kiegroup.

From class PMMLModelTestUtils, method getSimpleSetPredicate.

/**
 * Creates a {@link SimpleSetPredicate} test fixture.
 *
 * @param predicateName   name of the field the predicate tests
 * @param arrayType       type of the value array
 * @param values          values placed in the array (built via {@code getArray})
 * @param booleanOperator whether the field must be in or not in the array
 * @return the populated predicate
 */
public static SimpleSetPredicate getSimpleSetPredicate(final String predicateName, final Array.Type arrayType, final List<String> values, final SimpleSetPredicate.BooleanOperator booleanOperator) {
    final SimpleSetPredicate predicate = new SimpleSetPredicate();
    predicate.setField(FieldName.create(predicateName));
    predicate.setBooleanOperator(booleanOperator);
    final Array valueArray = getArray(arrayType, values);
    predicate.setArray(valueArray);
    return predicate;
}
Also used : Array(org.dmg.pmml.Array) FieldName(org.dmg.pmml.FieldName) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate)

Example 3 with SimpleSetPredicate

Use of org.dmg.pmml.SimpleSetPredicate in project drools by kiegroup.

From class KiePMMLSimpleSetPredicateFactoryTest, method getSimpleSetPredicate.

/**
 * Creates a {@link SimpleSetPredicate} fixture on the {@code SIMPLE_SET_PREDICATE_NAME} field.
 *
 * @param values    values placed in the predicate's array (built via {@code getArray})
 * @param arrayType type of the value array
 * @param inNotIn   whether the field must be in or not in the array
 * @return the populated predicate
 */
public static SimpleSetPredicate getSimpleSetPredicate(List<String> values, final Array.Type arrayType, final SimpleSetPredicate.BooleanOperator inNotIn) {
    final SimpleSetPredicate predicate = new SimpleSetPredicate();
    predicate.setField(FieldName.create(SIMPLE_SET_PREDICATE_NAME));
    predicate.setBooleanOperator(inNotIn);
    predicate.setArray(getArray(arrayType, values));
    return predicate;
}
Also used : Array(org.dmg.pmml.Array) PMMLModelTestUtils.getArray(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getArray) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate)

Example 4 with SimpleSetPredicate

Use of org.dmg.pmml.SimpleSetPredicate in project drools by kiegroup.

From class KiePMMLSimpleSetPredicateFactoryTest, method getSimpleSetPredicateVariableDeclaration.

/**
 * Verifies that the generated variable-declaration block for a string "isIn" set predicate
 * matches the expected source template and compiles with the required imports.
 */
@Test
public void getSimpleSetPredicateVariableDeclaration() throws IOException {
    String variableName = "variableName";
    Array.Type arrayType = Array.Type.STRING;
    List<String> values = getStringObjects(arrayType, 4);
    SimpleSetPredicate simpleSetPredicate = getSimpleSetPredicate(values, arrayType, SimpleSetPredicate.BooleanOperator.IS_IN);
    // Fully-qualified enum references expected in the generated code.
    String arrayTypeString = ARRAY_TYPE.class.getName() + "." + ARRAY_TYPE.byName(simpleSetPredicate.getArray().getType().value());
    String booleanOperatorString = IN_NOTIN.class.getName() + "." + IN_NOTIN.byName(simpleSetPredicate.getBooleanOperator().value());
    String valuesString = values.stream().map(valueString -> "\"" + valueString + "\"").collect(Collectors.joining(","));
    // The factory only receives the variable name and the predicate, so no DataDictionary
    // fixture is needed here.
    BlockStmt retrieved = KiePMMLSimpleSetPredicateFactory.getSimpleSetPredicateVariableDeclaration(variableName, simpleSetPredicate);
    String text = getFileContent(TEST_01_SOURCE);
    Statement expected = JavaParserUtils.parseBlock(String.format(text, variableName, simpleSetPredicate.getField().getValue(), arrayTypeString, booleanOperatorString, valuesString));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));
    List<Class<?>> imports = Arrays.asList(KiePMMLSimpleSetPredicate.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Also used : Array(org.dmg.pmml.Array) PMMLModelTestUtils.getArray(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getArray) ARRAY_TYPE(org.kie.pmml.api.enums.ARRAY_TYPE) Arrays(java.util.Arrays) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) DataType(org.dmg.pmml.DataType) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) Statement(com.github.javaparser.ast.stmt.Statement) Collectors(java.util.stream.Collectors) Array(org.dmg.pmml.Array) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) FieldName(org.dmg.pmml.FieldName) List(java.util.List) PMMLModelTestUtils.getStringObjects(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getStringObjects) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) IN_NOTIN(org.kie.pmml.api.enums.IN_NOTIN) PMMLModelTestUtils.getArray(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getArray) Collections(java.util.Collections) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) DataField(org.dmg.pmml.DataField) Statement(com.github.javaparser.ast.stmt.Statement) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) DataDictionary(org.dmg.pmml.DataDictionary) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) Test(org.junit.Test)

Example 5 with SimpleSetPredicate

Use of org.dmg.pmml.SimpleSetPredicate in project drools by kiegroup.

From class KiePMMLSimpleSetPredicateInstanceFactoryTest, method getKiePMMLSimpleSetPredicate.

/**
 * Converts a randomly generated PMML {@link SimpleSetPredicate} and checks the result with
 * {@code commonVerifyKiePMMLSimpleSetPredicate} against the source predicate.
 */
@Test
public void getKiePMMLSimpleSetPredicate() {
    final SimpleSetPredicate source = getRandomSimpleSetPredicate();
    commonVerifyKiePMMLSimpleSetPredicate(KiePMMLSimpleSetPredicateInstanceFactory.getKiePMMLSimpleSetPredicate(source), source);
}
Also used : InstanceFactoriesTestCommon.commonVerifyKiePMMLSimpleSetPredicate(org.kie.pmml.compiler.commons.factories.InstanceFactoriesTestCommon.commonVerifyKiePMMLSimpleSetPredicate) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) PMMLModelTestUtils.getRandomSimpleSetPredicate(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomSimpleSetPredicate) Test(org.junit.Test)

Aggregations

SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate): 21; List (java.util.List): 11; Test (org.junit.Test): 11; SimplePredicate (org.dmg.pmml.SimplePredicate): 10; Array (org.dmg.pmml.Array): 9; KiePMMLSimpleSetPredicate (org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate): 8; CompoundPredicate (org.dmg.pmml.CompoundPredicate): 7; Collectors (java.util.stream.Collectors): 6; ArrayList (java.util.ArrayList): 5; Arrays (java.util.Arrays): 5; Collections (java.util.Collections): 5; DataField (org.dmg.pmml.DataField): 5; FieldName (org.dmg.pmml.FieldName): 5; Predicate (org.dmg.pmml.Predicate): 5; KiePMMLCompoundPredicate (org.kie.pmml.commons.model.predicates.KiePMMLCompoundPredicate): 5; KiePMMLSimplePredicate (org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate): 5; PMMLModelTestUtils.getSimplePredicate (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getSimplePredicate): 5; KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate (org.kie.pmml.compiler.commons.codegenfactories.KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate): 5; KiePMMLDroolsRule (org.kie.pmml.models.drools.ast.KiePMMLDroolsRule): 5; BlockStmt (com.github.javaparser.ast.stmt.BlockStmt): 4