Search in sources:

Example 1 with PredictorTerm

Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.

Class KiePMMLRegressionTableFactory, method getPredictorTermBody.

/**
 * Builds the body of an <code>evaluatePredictor</code> method for the given <b>PredictorTerm</b>.
 * <p>
 * The body comes from the <code>evaluatePredictor</code> method of the template class
 * <code>KIE_PMML_EVALUATE_METHOD_TEMPLATE</code>, read from
 * <code>KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA</code>. Inside that body, two variables are populated:
 * <ul>
 *     <li><code>fieldRefs</code> is initialized with <code>Arrays.asList(...)</code> of the term's field names</li>
 *     <li>the <code>COEFFICIENT</code> variable is initialized with the term's coefficient as a double literal</li>
 * </ul>
 *
 * @param predictorTerm the <b>PredictorTerm</b> whose field references and coefficient are injected
 * @return the populated <code>evaluatePredictor</code> body
 * @throws KiePMMLInternalException if the template class, method, body or expected variables are missing,
 * or if any other error occurs while populating the body
 */
static BlockStmt getPredictorTermBody(final PredictorTerm predictorTerm) {
    try {
        // NOTE(review): templateEvaluate and cloneEvaluate are not declared here; presumably class-level
        // fields, so this method writes shared state -- confirm that is intended / thread-safe.
        templateEvaluate = getFromFileName(KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA);
        // All the modifications below are applied to the clone, not to templateEvaluate.
        cloneEvaluate = templateEvaluate.clone();
        ClassOrInterfaceDeclaration evaluateTemplateClass = cloneEvaluate.getClassByName(KIE_PMML_EVALUATE_METHOD_TEMPLATE).orElseThrow(() -> new RuntimeException(MAIN_CLASS_NOT_FOUND));
        MethodDeclaration methodTemplate = evaluateTemplateClass.getMethodsByName("evaluatePredictor").get(0);
        final BlockStmt body = methodTemplate.getBody().orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_BODY_TEMPLATE, methodTemplate.getName())));
        // fieldRefs = Arrays.asList("field1", "field2", ...)
        VariableDeclarator variableDeclarator = getVariableDeclarator(body, "fieldRefs").orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_VARIABLE_IN_BODY, "fieldRefs", body)));
        final List<Expression> nodeList = predictorTerm.getFieldRefs().stream().map(fieldRef -> new StringLiteralExpr(fieldRef.getField().getValue())).collect(Collectors.toList());
        NodeList<Expression> expressions = NodeList.nodeList(nodeList);
        MethodCallExpr methodCallExpr = new MethodCallExpr(new NameExpr("Arrays"), "asList", expressions);
        variableDeclarator.setInitializer(methodCallExpr);
        // COEFFICIENT = <coefficient as double literal>
        variableDeclarator = getVariableDeclarator(body, COEFFICIENT).orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_VARIABLE_IN_BODY, COEFFICIENT, body)));
        variableDeclarator.setInitializer(String.valueOf(predictorTerm.getCoefficient().doubleValue()));
        // body is already the method's block statement, so return it directly instead of resolving it again.
        return body;
    } catch (Exception e) {
        throw new KiePMMLInternalException(String.format("Failed to add PredictorTerm %s", predictorTerm), e);
    }
}
Also used : Arrays(java.util.Arrays) ClassOrInterfaceType(com.github.javaparser.ast.type.ClassOrInterfaceType) CommonCodegenUtils.createPopulatedHashMap(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.createPopulatedHashMap) KiePMMLModelUtils.getSanitizedVariableName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedVariableName) LoggerFactory(org.slf4j.LoggerFactory) MISSING_VARIABLE_IN_BODY(org.kie.pmml.commons.Constants.MISSING_VARIABLE_IN_BODY) MAIN_CLASS_NOT_FOUND(org.kie.pmml.compiler.commons.utils.JavaParserUtils.MAIN_CLASS_NOT_FOUND) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ObjectCreationExpr(com.github.javaparser.ast.expr.ObjectCreationExpr) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) CompilationUnit(com.github.javaparser.ast.CompilationUnit) NodeList(com.github.javaparser.ast.NodeList) UnknownType(com.github.javaparser.ast.type.UnknownType) CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames) RegressionModel(org.dmg.pmml.regression.RegressionModel) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) JavaParserUtils.getFullClassName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFullClassName) MISSING_BODY_TEMPLATE(org.kie.pmml.commons.Constants.MISSING_BODY_TEMPLATE) UUID(java.util.UUID) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) StringLiteralExpr(com.github.javaparser.ast.expr.StringLiteralExpr) VariableDeclarationExpr(com.github.javaparser.ast.expr.VariableDeclarationExpr) Objects(java.util.Objects) List(java.util.List) CommonCodegenUtils.getExpressionForObject(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getExpressionForObject) 
SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) TO_RETURN(org.kie.pmml.commons.Constants.TO_RETURN) VARIABLE_NAME_TEMPLATE(org.kie.pmml.commons.Constants.VARIABLE_NAME_TEMPLATE) LambdaExpr(com.github.javaparser.ast.expr.LambdaExpr) ExpressionStmt(com.github.javaparser.ast.stmt.ExpressionStmt) JavaParserUtils.getFromFileName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFromFileName) Parameter(com.github.javaparser.ast.body.Parameter) Collectors.groupingBy(java.util.stream.Collectors.groupingBy) HashMap(java.util.HashMap) CastExpr(com.github.javaparser.ast.expr.CastExpr) AtomicReference(java.util.concurrent.atomic.AtomicReference) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) LinkedHashMap(java.util.LinkedHashMap) VariableDeclarator(com.github.javaparser.ast.body.VariableDeclarator) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) CommonCodegenUtils.addMapPopulationExpressions(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.addMapPopulationExpressions) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) AbstractKiePMMLTable(org.kie.pmml.models.regression.model.AbstractKiePMMLTable) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Logger(org.slf4j.Logger) CommonCodegenUtils.getChainedMethodCallExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getChainedMethodCallExprFrom) NameExpr(com.github.javaparser.ast.expr.NameExpr) MethodReferenceExpr(com.github.javaparser.ast.expr.MethodReferenceExpr) CommonCodegenUtils.getVariableDeclarator(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getVariableDeclarator) 
KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) AbstractMap(java.util.AbstractMap) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) MISSING_VARIABLE_INITIALIZER_TEMPLATE(org.kie.pmml.commons.Constants.MISSING_VARIABLE_INITIALIZER_TEMPLATE) Collections(java.util.Collections) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) StringLiteralExpr(com.github.javaparser.ast.expr.StringLiteralExpr) NameExpr(com.github.javaparser.ast.expr.NameExpr) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) VariableDeclarator(com.github.javaparser.ast.body.VariableDeclarator) CommonCodegenUtils.getVariableDeclarator(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getVariableDeclarator) Expression(com.github.javaparser.ast.expr.Expression) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr)

Example 2 with PredictorTerm

Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.

Class KiePMMLRegressionModelFactoryTest, method evaluateRegressionTable.

/**
 * Checks a generated table against the original PMML table.
 * <p>
 * The intercepts must be equal. Every numeric predictor, categorical predictor and predictor term
 * of the original must have an entry, keyed by its name, in the corresponding function map.
 *
 * @param regressionTable the generated table under test
 * @param originalRegressionTable the PMML table it was generated from
 */
private void evaluateRegressionTable(KiePMMLRegressionTable regressionTable, RegressionTable originalRegressionTable) {
    assertEquals(originalRegressionTable.getIntercept(), regressionTable.getIntercept());

    // Numeric predictors -> numeric function map
    final Map<String, SerializableFunction<Double, Double>> numericFunctions = regressionTable.getNumericFunctionMap();
    for (NumericPredictor predictor : originalRegressionTable.getNumericPredictors()) {
        final String key = predictor.getName().getValue();
        assertTrue(numericFunctions.containsKey(key));
    }

    // Categorical predictors -> categorical function map
    final Map<String, SerializableFunction<String, Double>> categoricalFunctions = regressionTable.getCategoricalFunctionMap();
    for (CategoricalPredictor predictor : originalRegressionTable.getCategoricalPredictors()) {
        final String key = predictor.getName().getValue();
        assertTrue(categoricalFunctions.containsKey(key));
    }

    // Predictor terms -> predictor-terms function map
    final Map<String, SerializableFunction<Map<String, Object>, Double>> termFunctions = regressionTable.getPredictorTermsFunctionMap();
    for (PredictorTerm term : originalRegressionTable.getPredictorTerms()) {
        final String key = term.getName().getValue();
        assertTrue(termFunctions.containsKey(key));
    }
}
Also used : CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PMMLModelTestUtils.getCategoricalPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getCategoricalPredictor) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils.getPredictorTerm(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getPredictorTerm) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) PMMLModelTestUtils.getNumericPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getNumericPredictor)

Example 3 with PredictorTerm

Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.

Class KiePMMLRegressionTableFactoryTest, method getPredictorTermFunctions.

/**
 * Verifies that one entry is produced per PredictorTerm, keyed by the term's name.
 */
@Test
public void getPredictorTermFunctions() {
    // Three terms: "predictorName-i", coefficient 1.23 * i, single field "fieldRef-i".
    final List<PredictorTerm> predictorTerms = IntStream.range(0, 3)
            .mapToObj(i -> PMMLModelTestUtils.getPredictorTerm("predictorName-" + i,
                                                               1.23 * i,
                                                               Collections.singletonList("fieldRef-" + i)))
            .collect(Collectors.toList());
    final Map<String, Expression> retrieved = KiePMMLRegressionTableFactory.getPredictorTermFunctions(predictorTerms);
    assertEquals(predictorTerms.size(), retrieved.size());
    for (PredictorTerm term : predictorTerms) {
        assertTrue(retrieved.containsKey(term.getName().getValue()));
    }
}
Also used : RESULT_FEATURE(org.kie.pmml.api.enums.RESULT_FEATURE) Arrays(java.util.Arrays) KiePMMLModelUtils.getSanitizedVariableName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedVariableName) MiningSchema(org.dmg.pmml.MiningSchema) KiePMMLOutputField(org.kie.pmml.commons.model.KiePMMLOutputField) FieldName(org.dmg.pmml.FieldName) OpType(org.dmg.pmml.OpType) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) GETKIEPMML_TABLE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.GETKIEPMML_TABLE) TestCase.assertNotNull(junit.framework.TestCase.assertNotNull) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) Assert.fail(org.junit.Assert.fail) CompilationUnit(com.github.javaparser.ast.CompilationUnit) PMML(org.dmg.pmml.PMML) RegressionModel(org.dmg.pmml.regression.RegressionModel) KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) SUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.SUPPORTED_NORMALIZATION_METHODS) List(java.util.List) CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) IntStream(java.util.stream.IntStream) 
LambdaExpr(com.github.javaparser.ast.expr.LambdaExpr) ExpressionStmt(com.github.javaparser.ast.stmt.ExpressionStmt) BeforeClass(org.junit.BeforeClass) JavaParserUtils.getFromFileName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFromFileName) Collectors.groupingBy(java.util.stream.Collectors.groupingBy) HashMap(java.util.HashMap) CastExpr(com.github.javaparser.ast.expr.CastExpr) AtomicReference(java.util.concurrent.atomic.AtomicReference) KIE_PMML_REGRESSION_TABLE_TEMPLATE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE) ArrayList(java.util.ArrayList) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils) MiningField(org.dmg.pmml.MiningField) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) CodegenTestUtils.commonValidateCompilation(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilation) KiePMMLModelUtils.getGeneratedClassName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getGeneratedClassName) MethodReferenceExpr(com.github.javaparser.ast.expr.MethodReferenceExpr) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) Assert.assertNull(org.junit.Assert.assertNull) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) 
UNSUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.UNSUPPORTED_NORMALIZATION_METHODS) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) Expression(com.github.javaparser.ast.expr.Expression) Test(org.junit.Test)

Example 4 with PredictorTerm

Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.

Class KiePMMLRegressionTableFactoryTest, method getPredictorTermSerializableFunction.

/**
 * Verifies that a function is returned for a single-field PredictorTerm.
 */
@Test
public void getPredictorTermSerializableFunction() {
    final PredictorTerm predictorTerm = PMMLModelTestUtils.getPredictorTerm("predictorName",
                                                                            23.12,
                                                                            Collections.singletonList("fieldRef"));
    final SerializableFunction<Map<String, Object>, Double> retrieved =
            KiePMMLRegressionTableFactory.getPredictorTermSerializableFunction(predictorTerm);
    assertNotNull(retrieved);
}
Also used : PredictorTerm(org.dmg.pmml.regression.PredictorTerm) Map(java.util.Map) HashMap(java.util.HashMap) Test(org.junit.Test)

Example 5 with PredictorTerm

Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.

Class PMMLModelTestUtils, method getRegressionTable.

/**
 * Creates a new {@link RegressionTable} with the given intercept and target category,
 * and adds the given categorical predictors, numeric predictors and predictor terms.
 *
 * @param categoricalPredictors categorical predictors to add
 * @param numericPredictors numeric predictors to add
 * @param predictorTerms predictor terms to add
 * @param intercept the table intercept
 * @param targetCategory the table target category
 * @return the populated table
 */
public static RegressionTable getRegressionTable(List<CategoricalPredictor> categoricalPredictors, List<NumericPredictor> numericPredictors, List<PredictorTerm> predictorTerms, double intercept, Object targetCategory) {
    final RegressionTable regressionTable = new RegressionTable();
    regressionTable.setIntercept(intercept);
    regressionTable.setTargetCategory(targetCategory);

    final CategoricalPredictor[] categoricalArray = categoricalPredictors.toArray(new CategoricalPredictor[0]);
    regressionTable.addCategoricalPredictors(categoricalArray);

    final NumericPredictor[] numericArray = numericPredictors.toArray(new NumericPredictor[0]);
    regressionTable.addNumericPredictors(numericArray);

    final PredictorTerm[] termArray = predictorTerms.toArray(new PredictorTerm[0]);
    regressionTable.addPredictorTerms(termArray);

    return regressionTable;
}
Also used : CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable)

Aggregations

PredictorTerm (org.dmg.pmml.regression.PredictorTerm)9 CategoricalPredictor (org.dmg.pmml.regression.CategoricalPredictor)6 NumericPredictor (org.dmg.pmml.regression.NumericPredictor)6 Expression (com.github.javaparser.ast.expr.Expression)4 LambdaExpr (com.github.javaparser.ast.expr.LambdaExpr)4 HashMap (java.util.HashMap)4 Map (java.util.Map)4 RegressionTable (org.dmg.pmml.regression.RegressionTable)4 Test (org.junit.Test)4 CompilationUnit (com.github.javaparser.ast.CompilationUnit)3 ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration)3 MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration)3 CastExpr (com.github.javaparser.ast.expr.CastExpr)3 MethodCallExpr (com.github.javaparser.ast.expr.MethodCallExpr)3 MethodReferenceExpr (com.github.javaparser.ast.expr.MethodReferenceExpr)3 NullLiteralExpr (com.github.javaparser.ast.expr.NullLiteralExpr)3 BlockStmt (com.github.javaparser.ast.stmt.BlockStmt)3 ExpressionStmt (com.github.javaparser.ast.stmt.ExpressionStmt)3 Arrays (java.util.Arrays)3 Collections (java.util.Collections)3