Search in sources:

Example 26 with SimplePredicate

Use of org.dmg.pmml.SimplePredicate in project drools by kiegroup.

From the class KiePMMLSimplePredicateFactoryTest, method getSimplePredicateVariableDeclaration.

/**
 * Verifies that {@code KiePMMLSimplePredicateFactory.getSimplePredicateVariableDeclaration}
 * turns an EQUAL {@link SimplePredicate} on a DOUBLE field into the block described by the
 * {@code TEST_01_SOURCE} template, and that the generated block compiles with the
 * expected imports.
 */
@Test
public void getSimplePredicateVariableDeclaration() throws IOException {
    String variableName = "variableName";
    final SimplePredicate simplePredicate = new SimplePredicate();
    simplePredicate.setField(FieldName.create("CUSTOM_FIELD"));
    simplePredicate.setValue("235.435");
    simplePredicate.setOperator(SimplePredicate.Operator.EQUAL);
    // Use name() explicitly (as the compound-predicate test does) so the expected
    // fully-qualified enum constant does not depend on OPERATOR.toString().
    String operatorString = OPERATOR.class.getName() + "." + OPERATOR.byName(simplePredicate.getOperator().value()).name();
    // The predicate's field must be declared in the data dictionary for the factory to resolve its type
    DataField dataField = new DataField();
    dataField.setName(simplePredicate.getField());
    dataField.setDataType(DataType.DOUBLE);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);
    BlockStmt retrieved = KiePMMLSimplePredicateFactory.getSimplePredicateVariableDeclaration(variableName, simplePredicate, getFieldsFromDataDictionary(dataDictionary));
    // The template is a format string whose placeholders are filled with the predicate's parts
    String text = getFileContent(TEST_01_SOURCE);
    Statement expected = JavaParserUtils.parseBlock(String.format(text, variableName, simplePredicate.getField().getValue(), operatorString, simplePredicate.getValue()));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));
    List<Class<?>> imports = Arrays.asList(KiePMMLSimplePredicate.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Also used : DataField(org.dmg.pmml.DataField) Statement(com.github.javaparser.ast.stmt.Statement) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) DataDictionary(org.dmg.pmml.DataDictionary) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) KiePMMLSimplePredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate) SimplePredicate(org.dmg.pmml.SimplePredicate) Test(org.junit.Test)

Example 27 with SimplePredicate

Use of org.dmg.pmml.SimplePredicate in project drools by kiegroup.

From the class KiePMMLASTFactoryUtilsTest, method getConstraintEntryFromSimplePredicates.

/**
 * Verifies that two LESS_THAN simple predicates on the same STRING field, combined with
 * OR, yield a {@code KiePMMLFieldOperatorValue} named after the field whose constraint
 * string joins both comparisons with {@code ||}.
 */
@Test
public void getConstraintEntryFromSimplePredicates() {
    final Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    String fieldName = "FIELD_NAME";
    // Register the field type once, outside the stream: every predicate targets the same
    // field, so putting it from inside mapToObj was a redundant side effect in an
    // intermediate stream operation.
    fieldTypeMap.put(fieldName, new KiePMMLOriginalTypeGeneratedType(DataType.STRING.value(), getSanitizedClassName(fieldName.toUpperCase())));
    List<SimplePredicate> simplePredicates = IntStream.range(0, 2)
            .mapToObj(index -> PMMLModelTestUtils.getSimplePredicate(fieldName, "VALUE-" + index, SimplePredicate.Operator.LESS_THAN))
            .collect(Collectors.toList());
    final KiePMMLFieldOperatorValue retrieved = KiePMMLASTFactoryUtils.getConstraintEntryFromSimplePredicates(fieldName, BOOLEAN_OPERATOR.OR, simplePredicates, fieldTypeMap);
    assertEquals(fieldName, retrieved.getName());
    assertNotNull(retrieved.getConstraintsAsString());
    String expected = "value < \"VALUE-0\" || value < \"VALUE-1\"";
    assertEquals(expected, retrieved.getConstraintsAsString());
}
Also used : IntStream(java.util.stream.IntStream) Predicate(org.dmg.pmml.Predicate) BeforeClass(org.junit.BeforeClass) KiePMMLOriginalTypeGeneratedType(org.kie.pmml.models.drools.tuples.KiePMMLOriginalTypeGeneratedType) Collectors.groupingBy(java.util.stream.Collectors.groupingBy) KiePMMLFieldOperatorValue(org.kie.pmml.models.drools.ast.KiePMMLFieldOperatorValue) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) PMMLModelTestUtils(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils) KiePMMLOperatorValue(org.kie.pmml.models.drools.tuples.KiePMMLOperatorValue) Map(java.util.Map) CompoundPredicate(org.dmg.pmml.CompoundPredicate) Assert.assertNotNull(org.junit.Assert.assertNotNull) DataType(org.dmg.pmml.DataType) BOOLEAN_OPERATOR(org.kie.pmml.api.enums.BOOLEAN_OPERATOR) Assert.assertTrue(org.junit.Assert.assertTrue) Test(org.junit.Test) Collectors(java.util.stream.Collectors) OPERATOR(org.kie.pmml.api.enums.OPERATOR) Consumer(java.util.function.Consumer) List(java.util.List) SimplePredicate(org.dmg.pmml.SimplePredicate) Assert.assertFalse(org.junit.Assert.assertFalse) PMMLModelTestUtils.getRandomSimplePredicateOperator(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomSimplePredicateOperator) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) Assert.assertEquals(org.junit.Assert.assertEquals) KiePMMLModelUtils.getSanitizedClassName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedClassName) PMMLModelTestUtils.getRandomObject(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomObject) HashMap(java.util.HashMap) KiePMMLFieldOperatorValue(org.kie.pmml.models.drools.ast.KiePMMLFieldOperatorValue) SimplePredicate(org.dmg.pmml.SimplePredicate) KiePMMLOriginalTypeGeneratedType(org.kie.pmml.models.drools.tuples.KiePMMLOriginalTypeGeneratedType) Test(org.junit.Test)

Example 28 with SimplePredicate

Use of org.dmg.pmml.SimplePredicate in project drools by kiegroup.

From the class KiePMMLCompoundPredicateFactoryTest, method getCompoundPredicateVariableDeclaration.

/**
 * Verifies that an AND {@link CompoundPredicate} made of two simple predicates and one
 * string IS_IN set predicate is rendered into the block described by the
 * {@code TEST_01_SOURCE} template, and that the generated block compiles.
 */
@Test
public void getCompoundPredicateVariableDeclaration() throws IOException {
    final String variableName = "variableName";
    // Children: two simple comparisons plus an IS_IN set over four string values
    final SimplePredicate firstPredicate = getSimplePredicate(PARAM_1, value1, operator1);
    final SimplePredicate secondPredicate = getSimplePredicate(PARAM_2, value2, operator2);
    final Array.Type arrayType = Array.Type.STRING;
    final List<String> values = getStringObjects(arrayType, 4);
    final SimpleSetPredicate setPredicate = getSimpleSetPredicate(values, arrayType, SimpleSetPredicate.BooleanOperator.IS_IN);
    final CompoundPredicate compoundPredicate = new CompoundPredicate();
    compoundPredicate.setBooleanOperator(CompoundPredicate.BooleanOperator.AND);
    compoundPredicate.getPredicates().add(0, firstPredicate);
    compoundPredicate.getPredicates().add(1, secondPredicate);
    compoundPredicate.getPredicates().add(2, setPredicate);
    // Declare one DOUBLE data field per child predicate, in the same order as the children
    final DataField[] dataFields = {new DataField(), new DataField(), new DataField()};
    dataFields[0].setName(firstPredicate.getField());
    dataFields[1].setName(secondPredicate.getField());
    dataFields[2].setName(setPredicate.getField());
    for (DataField dataField : dataFields) {
        dataField.setDataType(DataType.DOUBLE);
    }
    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataFields);
    // Values substituted into the template: fully-qualified operator constant and quoted set members
    final String booleanOperatorString = BOOLEAN_OPERATOR.class.getName() + "." + BOOLEAN_OPERATOR.byName(compoundPredicate.getBooleanOperator().value()).name();
    final String valuesString = values.stream().map(v -> "\"" + v + "\"").collect(Collectors.joining(","));
    final List<Field<?>> fields = getFieldsFromDataDictionary(dataDictionary);
    final BlockStmt retrieved = KiePMMLCompoundPredicateFactory.getCompoundPredicateVariableDeclaration(variableName, compoundPredicate, fields);
    final String template = getFileContent(TEST_01_SOURCE);
    final Statement expected = JavaParserUtils.parseBlock(String.format(template, variableName, valuesString, booleanOperatorString));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));
    final List<Class<?>> requiredImports = Arrays.asList(KiePMMLCompoundPredicate.class, KiePMMLSimplePredicate.class, KiePMMLSimpleSetPredicate.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, requiredImports);
}
Also used : Arrays(java.util.Arrays) PMMLModelTestUtils.getSimplePredicate(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getSimplePredicate) Field(org.dmg.pmml.Field) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicate) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) KiePMMLSimplePredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate) DataType(org.dmg.pmml.DataType) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) BOOLEAN_OPERATOR(org.kie.pmml.api.enums.BOOLEAN_OPERATOR) KiePMMLCompoundPredicate(org.kie.pmml.commons.model.predicates.KiePMMLCompoundPredicate) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) Statement(com.github.javaparser.ast.stmt.Statement) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) Collectors(java.util.stream.Collectors) Array(org.dmg.pmml.Array) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) List(java.util.List) SimplePredicate(org.dmg.pmml.SimplePredicate) PMMLModelTestUtils.getStringObjects(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getStringObjects) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate) Collections(java.util.Collections) Statement(com.github.javaparser.ast.stmt.Statement) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) DataDictionary(org.dmg.pmml.DataDictionary) 
CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) PMMLModelTestUtils.getSimplePredicate(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getSimplePredicate) KiePMMLSimplePredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate) SimplePredicate(org.dmg.pmml.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate) Array(org.dmg.pmml.Array) Field(org.dmg.pmml.Field) DataField(org.dmg.pmml.DataField) DataField(org.dmg.pmml.DataField) CompoundPredicate(org.dmg.pmml.CompoundPredicate) KiePMMLCompoundPredicate(org.kie.pmml.commons.model.predicates.KiePMMLCompoundPredicate) Test(org.junit.Test)

Example 29 with SimplePredicate

Use of org.dmg.pmml.SimplePredicate in project drools by kiegroup.

From the class PMMLModelTestUtils, method getSimplePredicate.

/**
 * Builds a {@link SimplePredicate} that compares the field named {@code predicateName}
 * against {@code value} using {@code operator}.
 *
 * @param predicateName name of the field the predicate refers to
 * @param value the value to compare against
 * @param operator the comparison operator
 * @return a new, fully populated {@link SimplePredicate}
 */
public static SimplePredicate getSimplePredicate(final String predicateName, final Object value, final SimplePredicate.Operator operator) {
    final SimplePredicate simplePredicate = new SimplePredicate();
    simplePredicate.setField(FieldName.create(predicateName));
    simplePredicate.setOperator(operator);
    simplePredicate.setValue(value);
    return simplePredicate;
}
Also used : FieldName(org.dmg.pmml.FieldName) SimplePredicate(org.dmg.pmml.SimplePredicate)

Example 30 with SimplePredicate

Use of org.dmg.pmml.SimplePredicate in project jpmml-r by jpmml.

From the class ScorecardConverter, method encodeModel.

/**
 * Encodes the underlying R GLM object as a PMML {@link Scorecard}.
 *
 * <p>Reads the model coefficients and the {@code sc.conf} decoration ({@code odds},
 * {@code base_points}, {@code pdo}). Every feature must be a {@link BinaryFeature}; each
 * one becomes an EQUAL attribute scored as {@code -coefficient * factor}, where
 * {@code factor = pdo / ln(2)}. Each characteristic also receives a trailing
 * {@code True} attribute with a partial score of 0.
 *
 * @param schema the schema whose features correspond to the non-intercept coefficients
 * @return the encoded scorecard
 * @throws IllegalArgumentException if any feature is not a {@link BinaryFeature}
 */
@Override
public Scorecard encodeModel(Schema schema) {
    RGenericVector glm = getObject();
    RDoubleVector coefficients = glm.getDoubleElement("coefficients");
    // NOTE(review): "family" is looked up but never used below; kept because the lookup
    // may serve as a presence check on the R object — confirm before removing.
    RGenericVector family = glm.getGenericElement("family");
    RGenericVector scConf = DecorationUtil.getGenericElement(glm, "sc.conf");
    Double intercept = coefficients.getElement(LMConverter.INTERCEPT, false);
    List<? extends Feature> features = schema.getFeatures();
    // One coefficient per feature, plus the optional intercept
    SchemaUtil.checkSize(coefficients.size() - (intercept != null ? 1 : 0), features);
    RNumberVector<?> odds = scConf.getNumericElement("odds");
    RNumberVector<?> basePoints = scConf.getNumericElement("base_points");
    RNumberVector<?> pdo = scConf.getNumericElement("pdo");
    // Points-to-double-the-odds scaling factor
    double factor = (pdo.asScalar()).doubleValue() / Math.log(2);
    // LinkedHashMap keeps characteristics in first-seen feature order
    Map<String, Characteristic> fieldCharacteristics = new LinkedHashMap<>();
    for (Feature feature : features) {
        String name = feature.getName();
        if (!(feature instanceof BinaryFeature)) {
            throw new IllegalArgumentException("Expected a binary feature, got " + feature.getClass().getName() + " (" + name + ")");
        }
        Double coefficient = getFeatureCoefficient(feature, coefficients);
        Characteristic characteristic = fieldCharacteristics.computeIfAbsent(name, key -> new Characteristic().setName("score(" + FeatureUtil.getName(feature) + ")"));
        BinaryFeature binaryFeature = (BinaryFeature) feature;
        SimplePredicate simplePredicate = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.EQUAL, binaryFeature.getValue());
        Attribute attribute = new Attribute(simplePredicate).setPartialScore(formatScore(-1d * coefficient * factor));
        characteristic.addAttributes(attribute);
    }
    Characteristics characteristics = new Characteristics();
    for (Characteristic characteristic : fieldCharacteristics.values()) {
        // Catch-all attribute appended after the binary attributes, scoring 0
        Attribute attribute = new Attribute(True.INSTANCE).setPartialScore(0d);
        characteristic.addAttributes(attribute);
        characteristics.addCharacteristics(characteristic);
    }
    // Offset = base_points - ln(odds) * factor, further shifted by the intercept when present
    double initialScore = (basePoints.asScalar()).doubleValue()
            - Math.log((odds.asScalar()).doubleValue()) * factor
            - (intercept != null ? intercept * factor : 0);
    Scorecard scorecard = new Scorecard(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), characteristics)
            .setInitialScore(formatScore(initialScore))
            .setUseReasonCodes(false);
    return scorecard;
}
Also used : Attribute(org.dmg.pmml.scorecard.Attribute) Characteristic(org.dmg.pmml.scorecard.Characteristic) BinaryFeature(org.jpmml.converter.BinaryFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) SimplePredicate(org.dmg.pmml.SimplePredicate) LinkedHashMap(java.util.LinkedHashMap) Characteristics(org.dmg.pmml.scorecard.Characteristics) Scorecard(org.dmg.pmml.scorecard.Scorecard) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map)

Aggregations

SimplePredicate (org.dmg.pmml.SimplePredicate)30 Test (org.junit.Test)17 CompoundPredicate (org.dmg.pmml.CompoundPredicate)15 ArrayList (java.util.ArrayList)11 KiePMMLDroolsRule (org.kie.pmml.models.drools.ast.KiePMMLDroolsRule)11 HashMap (java.util.HashMap)10 List (java.util.List)10 KiePMMLOriginalTypeGeneratedType (org.kie.pmml.models.drools.tuples.KiePMMLOriginalTypeGeneratedType)10 Collectors (java.util.stream.Collectors)9 SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate)9 Predicate (org.dmg.pmml.Predicate)8 KiePMMLSimplePredicate (org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate)7 KiePMMLFieldOperatorValue (org.kie.pmml.models.drools.ast.KiePMMLFieldOperatorValue)7 DataField (org.dmg.pmml.DataField)6 DataType (org.dmg.pmml.DataType)6 KiePMMLASTTestUtils.getPredicateASTFactoryData (org.kie.pmml.models.drools.utils.KiePMMLASTTestUtils.getPredicateASTFactoryData)6 Map (java.util.Map)5 Assert.assertTrue (org.junit.Assert.assertTrue)5 KiePMMLCompoundPredicate (org.kie.pmml.commons.model.predicates.KiePMMLCompoundPredicate)5 KiePMMLSimpleSetPredicate (org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate)5