Search in sources:

Example 11 with Attribute

Use of org.dmg.pmml.scorecard.Attribute in project drools by kiegroup.

The class KiePMMLScorecardModelCharacteristicASTFactoryTest, method declareRuleFromAttributeWithSimplePredicateLastCharacteristic.

/**
 * Declares rules for a single simple-predicate attribute flagged as belonging to the last characteristic,
 * then checks that exactly one rule was produced and that it matches the expected shape
 * (AND operator, constraint text {@code "value <= 5.0"}).
 */
@Test
public void declareRuleFromAttributeWithSimplePredicateLastCharacteristic() {
    final Attribute simplePredicateAttribute = getSimplePredicateAttribute();
    final String path = "parent_path";
    final int index = 2;
    final String status = "status_to_set";
    final String reasonCode = "REASON_CODE";
    final double baselineScore = 12;
    final boolean lastCharacteristic = true;
    // The factory appends the rules it declares into this list.
    final List<KiePMMLDroolsRule> collected = new ArrayList<>();

    getKiePMMLScorecardModelCharacteristicASTFactory()
            .declareRuleFromAttribute(simplePredicateAttribute, path, index, collected, status,
                                      reasonCode, baselineScore, lastCharacteristic);

    assertEquals(1, collected.size());
    final KiePMMLDroolsRule onlyRule = collected.get(0);
    commonValidateRule(onlyRule, simplePredicateAttribute, status, path, index, lastCharacteristic,
                       1, null, BOOLEAN_OPERATOR.AND, "value <= 5.0", 1);
}
Also used: Attribute(org.dmg.pmml.scorecard.Attribute) ArrayList(java.util.ArrayList) KiePMMLDroolsRule(org.kie.pmml.models.drools.ast.KiePMMLDroolsRule) Test(org.junit.Test)

Example 12 with Attribute

Use of org.dmg.pmml.scorecard.Attribute in project drools by kiegroup.

The class KiePMMLScorecardModelCharacteristicASTFactoryTest, method getCompoundPredicateAttribute.

/**
 * Builds a scorecard {@link Attribute} fixture: partial score 30.0, guarded by the compound
 * predicate returned by {@code getCompoundPredicate()}.
 *
 * @return a newly created attribute
 */
private Attribute getCompoundPredicateAttribute() {
    final Attribute attribute = new Attribute();
    attribute.setPartialScore(30.0);
    attribute.setPredicate(getCompoundPredicate());
    return attribute;
}
Also used: Attribute(org.dmg.pmml.scorecard.Attribute)

Example 13 with Attribute

Use of org.dmg.pmml.scorecard.Attribute in project drools by kiegroup.

The class KiePMMLAttributeFactoryTest, method getAttributeVariableDeclarationWithComplexPartialScore.

/**
 * Generates the variable declaration for an attribute that carries a reason code, a compound predicate over
 * four STRING values and a complex partial score. It then checks that:
 * <ul>
 *   <li>the generated block equals the {@code TEST_01_SOURCE} template, formatted with the variable name and
 *       the quoted values;</li>
 *   <li>the generated block compiles against the listed imports.</li>
 * </ul>
 */
@Test
public void getAttributeVariableDeclarationWithComplexPartialScore() throws IOException {
    final String variableName = "variableName";
    final Attribute attribute = new Attribute();
    attribute.setReasonCode(REASON_CODE);
    final Array.Type arrayType = Array.Type.STRING;
    final List<String> values = getStringObjects(arrayType, 4);
    final CompoundPredicate compoundPredicate = getCompoundPredicate(values, arrayType);
    attribute.setPredicate(compoundPredicate);
    attribute.setComplexPartialScore(getComplexPartialScore());
    // Each value wrapped in double quotes, comma separated, to fill the expected-source template.
    final String quotedValues = values.stream()
            .map(value -> "\"" + value + "\"")
            .collect(Collectors.joining(","));

    // Declare a DOUBLE data field for every simple / simple-set predicate nested in the compound predicate.
    final DataDictionary dataDictionary = new DataDictionary();
    for (Predicate nested : compoundPredicate.getPredicates()) {
        final DataField dataField = new DataField();
        if (nested instanceof SimplePredicate) {
            dataField.setName(((SimplePredicate) nested).getField());
        } else if (nested instanceof SimpleSetPredicate) {
            dataField.setName(((SimpleSetPredicate) nested).getField());
        } else {
            // Other predicate kinds contribute no data field.
            continue;
        }
        dataField.setDataType(DataType.DOUBLE);
        dataDictionary.addDataFields(dataField);
    }

    final BlockStmt retrieved = KiePMMLAttributeFactory.getAttributeVariableDeclaration(variableName, attribute,
            getFieldsFromDataDictionary(dataDictionary));
    final String template = getFileContent(TEST_01_SOURCE);
    final Statement expected = JavaParserUtils.parseBlock(String.format(template, variableName, quotedValues));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    // Classes the generated code must resolve in order to compile.
    final List<Class<?>> imports = Arrays.asList(KiePMMLAttribute.class, KiePMMLComplexPartialScore.class,
            KiePMMLCompoundPredicate.class, KiePMMLConstant.class, KiePMMLSimplePredicate.class,
            KiePMMLSimpleSetPredicate.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Also used: KiePMMLConstant(org.kie.pmml.commons.model.expressions.KiePMMLConstant) Arrays(java.util.Arrays) Predicate(org.dmg.pmml.Predicate) PMMLModelTestUtils.getSimplePredicate(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getSimplePredicate) KiePMMLAttribute(org.kie.pmml.models.scorecard.model.KiePMMLAttribute) KiePMMLComplexPartialScore(org.kie.pmml.models.scorecard.model.KiePMMLComplexPartialScore) ComplexPartialScore(org.dmg.pmml.scorecard.ComplexPartialScore) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicate) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) KiePMMLSimplePredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate) DataType(org.dmg.pmml.DataType) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) KiePMMLCompoundPredicate(org.kie.pmml.commons.model.predicates.KiePMMLCompoundPredicate) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) Statement(com.github.javaparser.ast.stmt.Statement) Attribute(org.dmg.pmml.scorecard.Attribute) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) Collectors(java.util.stream.Collectors) Array(org.dmg.pmml.Array) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) List(java.util.List) SimplePredicate(org.dmg.pmml.SimplePredicate) PMMLModelTestUtils.getStringObjects(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getStringObjects) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate) Constant(org.dmg.pmml.Constant) Collections(java.util.Collections)

Example 14 with Attribute

Use of org.dmg.pmml.scorecard.Attribute in project drools by kiegroup.

The class KiePMMLCharacteristicFactory, method getCharacteristicVariableDeclaration.

/**
 * Generates the statements that declare a characteristic variable named {@code variableName}.
 * <p>
 * The returned block contains, in order:
 * <ol>
 *   <li>for every attribute of {@code characteristic}, the statements produced by
 *       {@code getAttributeVariableDeclaration}. Each attribute variable is named via
 *       {@code VARIABLE_NAME_TEMPLATE} from {@code variableName} and the attribute's position.</li>
 *   <li>the statements of a cloned {@code GETKIEPMMLCHARACTERISTIC} template body. In that body the
 *       {@code CHARACTERISTIC} variable is renamed to {@code variableName}, and its builder chain receives:
 *       the name, the list of attribute variables, the baseline score and the reason code.</li>
 * </ol>
 *
 * @param variableName   name of the generated characteristic variable
 * @param characteristic source PMML characteristic
 * @param fields         fields forwarded to each attribute declaration
 * @return a new block holding all generated statements
 * @throws KiePMMLException if the template method has no body, lacks the {@code CHARACTERISTIC} variable,
 *                          or that variable has no initializer
 */
static BlockStmt getCharacteristicVariableDeclaration(final String variableName, final Characteristic characteristic, final List<Field<?>> fields) {
    final MethodDeclaration templateMethod = CHARACTERISTIC_TEMPLATE.getMethodsByName(GETKIEPMMLCHARACTERISTIC).get(0).clone();
    final BlockStmt templateBody = templateMethod.getBody()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, templateMethod)));
    final VariableDeclarator characteristicDeclarator = getVariableDeclarator(templateBody, CHARACTERISTIC)
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, CHARACTERISTIC, templateBody)));
    characteristicDeclarator.setName(variableName);

    final BlockStmt result = new BlockStmt();
    // One NameExpr per attribute variable; these become the Arrays.asList(...) builder argument.
    final NodeList<Expression> attributeReferences = new NodeList<>();
    final List<Attribute> attributes = characteristic.getAttributes();
    for (int index = 0; index < attributes.size(); index++) {
        final String attributeVariableName = String.format(VARIABLE_NAME_TEMPLATE, variableName, index);
        getAttributeVariableDeclaration(attributeVariableName, attributes.get(index), fields)
                .getStatements()
                .forEach(result::addStatement);
        attributeReferences.add(new NameExpr(attributeVariableName));
    }

    final MethodCallExpr initializer = characteristicDeclarator.getInitializer()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, CHARACTERISTIC, templateBody)))
            .asMethodCallExpr();
    // Builder argument positions 0 (name) and 2 (attribute list) come from the template's builder signature.
    final MethodCallExpr builderCall = getChainedMethodCallExprFrom("builder", initializer);
    builderCall.setArgument(0, new StringLiteralExpr(variableName));
    builderCall.setArgument(2, getArraysAsListInvocationMethodCall(attributeReferences));
    getChainedMethodCallExprFrom("withBaselineScore", initializer)
            .setArgument(0, getExpressionForObject(characteristic.getBaselineScore()));
    getChainedMethodCallExprFrom("withReasonCode", initializer)
            .setArgument(0, getExpressionForObject(characteristic.getReasonCode()));

    // The template statements go after the attribute declarations they reference.
    templateBody.getStatements().forEach(result::addStatement);
    return result;
}
Also used: Attribute(org.dmg.pmml.scorecard.Attribute) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) NodeList(com.github.javaparser.ast.NodeList) NameExpr(com.github.javaparser.ast.expr.NameExpr) StringLiteralExpr(com.github.javaparser.ast.expr.StringLiteralExpr) VariableDeclarator(com.github.javaparser.ast.body.VariableDeclarator) CommonCodegenUtils.getVariableDeclarator(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getVariableDeclarator) Expression(com.github.javaparser.ast.expr.Expression) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr)

Example 15 with Attribute

Use of org.dmg.pmml.scorecard.Attribute in project jpmml-r by jpmml.

The class ScorecardConverter, method encodeModel.

/**
 * Encodes the wrapped R GLM object as a PMML regression {@link Scorecard}.
 * <p>
 * Every schema feature must be a {@link BinaryFeature}. Features are grouped by {@code feature.getName()} into
 * one {@link Characteristic} each, kept in first-seen order.
 * <ul>
 *   <li>Each feature contributes an {@code EQUAL} attribute whose partial score is {@code -coefficient * factor}.</li>
 *   <li>Each characteristic then receives a final always-true attribute scoring 0.</li>
 *   <li>{@code factor} is {@code pdo / ln(2)}, with {@code pdo}, {@code odds} and {@code base_points} read
 *       from the {@code sc.conf} element.</li>
 * </ul>
 *
 * @param schema the model schema; its feature count must match the non-intercept coefficients
 * @return the encoded scorecard, with reason codes disabled
 * @throws IllegalArgumentException if a feature is not a {@link BinaryFeature}
 */
@Override
public Scorecard encodeModel(Schema schema) {
    RGenericVector glm = getObject();
    RDoubleVector coefficients = glm.getDoubleElement("coefficients");
    // NOTE(review): `family` is not used below; the lookup is kept in case it acts as a presence check.
    RGenericVector family = glm.getGenericElement("family");
    RGenericVector scConf = DecorationUtil.getGenericElement(glm, "sc.conf");
    Double intercept = coefficients.getElement(LMConverter.INTERCEPT, false);
    List<? extends Feature> features = schema.getFeatures();
    SchemaUtil.checkSize(coefficients.size() - (intercept != null ? 1 : 0), features);
    RNumberVector<?> odds = scConf.getNumericElement("odds");
    RNumberVector<?> basePoints = scConf.getNumericElement("base_points");
    RNumberVector<?> pdo = scConf.getNumericElement("pdo");
    double factor = (pdo.asScalar()).doubleValue() / Math.log(2);

    // LinkedHashMap keeps characteristics in the order their feature name first appears.
    Map<String, Characteristic> fieldCharacteristics = new LinkedHashMap<>();
    for (Feature feature : features) {
        String name = feature.getName();
        if (!(feature instanceof BinaryFeature)) {
            throw new IllegalArgumentException("Expected a binary feature for field " + name + ", got " + feature.getClass().getName());
        }
        Double coefficient = getFeatureCoefficient(feature, coefficients);
        Characteristic characteristic = fieldCharacteristics.computeIfAbsent(name,
                key -> new Characteristic().setName("score(" + FeatureUtil.getName(feature) + ")"));
        BinaryFeature binaryFeature = (BinaryFeature) feature;
        SimplePredicate simplePredicate = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.EQUAL, binaryFeature.getValue());
        Attribute attribute = new Attribute(simplePredicate).setPartialScore(formatScore(-1d * coefficient * factor));
        characteristic.addAttributes(attribute);
    }

    Characteristics characteristics = new Characteristics();
    for (Characteristic characteristic : fieldCharacteristics.values()) {
        // Fallback attribute: always-true predicate with partial score 0, appended after the feature-specific ones.
        characteristic.addAttributes(new Attribute(True.INSTANCE).setPartialScore(0d));
        characteristics.addCharacteristics(characteristic);
    }

    // Initial score = base_points - ln(odds) * factor - intercept * factor (the intercept term only when present).
    Scorecard scorecard = new Scorecard(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), characteristics)
            .setInitialScore(formatScore((basePoints.asScalar()).doubleValue() - Math.log((odds.asScalar()).doubleValue()) * factor - (intercept != null ? intercept * factor : 0)))
            .setUseReasonCodes(false);
    return scorecard;
}
Also used: Attribute(org.dmg.pmml.scorecard.Attribute) Characteristic(org.dmg.pmml.scorecard.Characteristic) BinaryFeature(org.jpmml.converter.BinaryFeature) Feature(org.jpmml.converter.Feature) SimplePredicate(org.dmg.pmml.SimplePredicate) LinkedHashMap(java.util.LinkedHashMap) Characteristics(org.dmg.pmml.scorecard.Characteristics) Scorecard(org.dmg.pmml.scorecard.Scorecard) Map(java.util.Map)

Aggregations

Usage counts: Attribute (org.dmg.pmml.scorecard.Attribute): 15, Test (org.junit.Test): 9, ArrayList (java.util.ArrayList): 6, KiePMMLDroolsRule (org.kie.pmml.models.drools.ast.KiePMMLDroolsRule): 6, SimplePredicate (org.dmg.pmml.SimplePredicate): 4, BlockStmt (com.github.javaparser.ast.stmt.BlockStmt): 3, Array (org.dmg.pmml.Array): 3, CompoundPredicate (org.dmg.pmml.CompoundPredicate): 3, SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate): 3, Characteristic (org.dmg.pmml.scorecard.Characteristic): 3, Statement (com.github.javaparser.ast.stmt.Statement): 2, IOException (java.io.IOException): 2, Arrays (java.util.Arrays): 2, Collections (java.util.Collections): 2, HashMap (java.util.HashMap): 2, List (java.util.List): 2, Collectors (java.util.stream.Collectors): 2, Constant (org.dmg.pmml.Constant): 2, DataDictionary (org.dmg.pmml.DataDictionary): 2, DataField (org.dmg.pmml.DataField): 2