Search in sources:

Example 1 with CategoricalPredictor

Use of org.dmg.pmml.regression.CategoricalPredictor in the drools project by kiegroup.

In the class KiePMMLRegressionModelFactoryTest, method evaluateRegressionTable.

/**
 * Checks a generated {@link KiePMMLRegressionTable} against the PMML {@link RegressionTable}
 * it was built from.
 * <p>
 * The intercepts must be equal, and every numeric predictor, categorical predictor and
 * predictor term of the original table must have an entry, keyed by its name, in the
 * corresponding function map of the generated table.
 *
 * @param regressionTable         the generated table under verification
 * @param originalRegressionTable the PMML table used as the reference
 */
private void evaluateRegressionTable(KiePMMLRegressionTable regressionTable, RegressionTable originalRegressionTable) {
    assertEquals(originalRegressionTable.getIntercept(), regressionTable.getIntercept());

    // Numeric predictors: each name must be a key of the numeric function map.
    final Map<String, SerializableFunction<Double, Double>> numericFunctions = regressionTable.getNumericFunctionMap();
    originalRegressionTable.getNumericPredictors().forEach((NumericPredictor numeric) ->
            assertTrue(numericFunctions.containsKey(numeric.getName().getValue())));

    // Categorical predictors: each name must be a key of the categorical function map.
    final Map<String, SerializableFunction<String, Double>> categoricalFunctions = regressionTable.getCategoricalFunctionMap();
    originalRegressionTable.getCategoricalPredictors().forEach((CategoricalPredictor categorical) ->
            assertTrue(categoricalFunctions.containsKey(categorical.getName().getValue())));

    // Predictor terms: each name must be a key of the predictor-terms function map.
    final Map<String, SerializableFunction<Map<String, Object>, Double>> termFunctions = regressionTable.getPredictorTermsFunctionMap();
    originalRegressionTable.getPredictorTerms().forEach((PredictorTerm term) ->
            assertTrue(termFunctions.containsKey(term.getName().getValue())));
}
Also used : CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PMMLModelTestUtils.getCategoricalPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getCategoricalPredictor) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils.getPredictorTerm(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getPredictorTerm) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) PMMLModelTestUtils.getNumericPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getNumericPredictor)

Example 2 with CategoricalPredictor

Use of org.dmg.pmml.regression.CategoricalPredictor in the drools project by kiegroup.

In the class KiePMMLRegressionTableFactoryTest, method getCategoricalPredictorsExpressions.

/**
 * Exercises {@code KiePMMLRegressionTableFactory.getCategoricalPredictorsExpressions}.
 * <p>
 * Builds nine categorical predictors (three predictor names, each with the values 0..2 and
 * coefficient {@code 1.23 * value}), asserts that the returned map has three entries, and then
 * verifies the statements added to the body for each group of predictors sharing a field name.
 */
@Test
public void getCategoricalPredictorsExpressions() {
    final String variableName = "variableName";
    // Flatten the (name index, value index) grid directly instead of collecting nested lists
    // and concatenating them pairwise; element order is unchanged.
    final List<CategoricalPredictor> categoricalPredictors = IntStream.range(0, 3)
            .boxed()
            .flatMap(index -> IntStream.range(0, 3).mapToObj(i -> {
                String predictorName = "predictorName-" + index;
                double coefficient = 1.23 * i;
                return PMMLModelTestUtils.getCategoricalPredictor(predictorName, i, coefficient);
            }))
            .collect(Collectors.toList());
    final BlockStmt body = new BlockStmt();
    Map<String, Expression> retrieved = KiePMMLRegressionTableFactory.getCategoricalPredictorsExpressions(categoricalPredictors, body, variableName);
    assertEquals(3, retrieved.size());
    // Group by field name so that each group can be checked with its own positional index.
    final Map<String, List<CategoricalPredictor>> groupedCollectors = categoricalPredictors.stream().collect(groupingBy(categoricalPredictor -> categoricalPredictor.getField().getValue()));
    groupedCollectors.values().forEach(groupedPredictors -> commonEvaluateCategoryPredictors(body, groupedPredictors, variableName));
}
Also used : RESULT_FEATURE(org.kie.pmml.api.enums.RESULT_FEATURE) Arrays(java.util.Arrays) KiePMMLModelUtils.getSanitizedVariableName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedVariableName) MiningSchema(org.dmg.pmml.MiningSchema) KiePMMLOutputField(org.kie.pmml.commons.model.KiePMMLOutputField) FieldName(org.dmg.pmml.FieldName) OpType(org.dmg.pmml.OpType) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) GETKIEPMML_TABLE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.GETKIEPMML_TABLE) TestCase.assertNotNull(junit.framework.TestCase.assertNotNull) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) Assert.fail(org.junit.Assert.fail) CompilationUnit(com.github.javaparser.ast.CompilationUnit) PMML(org.dmg.pmml.PMML) RegressionModel(org.dmg.pmml.regression.RegressionModel) KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) SUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.SUPPORTED_NORMALIZATION_METHODS) List(java.util.List) CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) IntStream(java.util.stream.IntStream) 
LambdaExpr(com.github.javaparser.ast.expr.LambdaExpr) ExpressionStmt(com.github.javaparser.ast.stmt.ExpressionStmt) BeforeClass(org.junit.BeforeClass) JavaParserUtils.getFromFileName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFromFileName) Collectors.groupingBy(java.util.stream.Collectors.groupingBy) HashMap(java.util.HashMap) CastExpr(com.github.javaparser.ast.expr.CastExpr) AtomicReference(java.util.concurrent.atomic.AtomicReference) KIE_PMML_REGRESSION_TABLE_TEMPLATE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE) ArrayList(java.util.ArrayList) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils) MiningField(org.dmg.pmml.MiningField) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) CodegenTestUtils.commonValidateCompilation(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilation) KiePMMLModelUtils.getGeneratedClassName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getGeneratedClassName) MethodReferenceExpr(com.github.javaparser.ast.expr.MethodReferenceExpr) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) Assert.assertNull(org.junit.Assert.assertNull) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) 
UNSUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.UNSUPPORTED_NORMALIZATION_METHODS) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) Expression(com.github.javaparser.ast.expr.Expression) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) List(java.util.List) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Example 3 with CategoricalPredictor

Use of org.dmg.pmml.regression.CategoricalPredictor in the drools project by kiegroup.

In the class KiePMMLRegressionTableFactoryTest, method commonEvaluateCategoryPredictors.

/**
 * Asserts that {@code toVerify} contains, for every given predictor, a method-call expression
 * statement whose text is {@code <map>_<i>.put("<value>", <coefficient>);}, where
 * {@code <map>} is the sanitized form of {@code variableName + "Map"} and {@code <i>} is the
 * predictor's position in the list.
 *
 * @param toVerify              the block whose statements are inspected
 * @param categoricalPredictors the predictors expected to be represented, in positional order
 * @param variableName          the base name from which the map variable names are derived
 */
private void commonEvaluateCategoryPredictors(final BlockStmt toVerify, final List<CategoricalPredictor> categoricalPredictors, final String variableName) {
    // The sanitized prefix does not depend on the predictor, so compute it once.
    final String mapVariablePrefix = getSanitizedVariableName(String.format("%sMap", variableName)) + "_";
    for (int i = 0; i < categoricalPredictors.size(); i++) {
        final CategoricalPredictor categoricalPredictor = categoricalPredictors.get(i);
        final String expectedVariableName = mapVariablePrefix + i;
        // Build the expected statement text once per predictor rather than once per statement.
        final String expected = String.format("%s.put(\"%s\", %s);", expectedVariableName, categoricalPredictor.getValue(), categoricalPredictor.getCoefficient());
        final boolean found = toVerify.getStatements().stream().anyMatch(statement ->
                statement instanceof ExpressionStmt
                        && ((ExpressionStmt) statement).getExpression() instanceof MethodCallExpr
                        && statement.toString().equals(expected));
        // Name the missing statement so a failure is diagnosable without a debugger.
        assertTrue("Expected statement not found: " + expected, found);
    }
}
Also used : RESULT_FEATURE(org.kie.pmml.api.enums.RESULT_FEATURE) Arrays(java.util.Arrays) KiePMMLModelUtils.getSanitizedVariableName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedVariableName) MiningSchema(org.dmg.pmml.MiningSchema) KiePMMLOutputField(org.kie.pmml.commons.model.KiePMMLOutputField) FieldName(org.dmg.pmml.FieldName) OpType(org.dmg.pmml.OpType) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) GETKIEPMML_TABLE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.GETKIEPMML_TABLE) TestCase.assertNotNull(junit.framework.TestCase.assertNotNull) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) Assert.fail(org.junit.Assert.fail) CompilationUnit(com.github.javaparser.ast.CompilationUnit) PMML(org.dmg.pmml.PMML) RegressionModel(org.dmg.pmml.regression.RegressionModel) KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) SUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.SUPPORTED_NORMALIZATION_METHODS) List(java.util.List) CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) IntStream(java.util.stream.IntStream) 
LambdaExpr(com.github.javaparser.ast.expr.LambdaExpr) ExpressionStmt(com.github.javaparser.ast.stmt.ExpressionStmt) BeforeClass(org.junit.BeforeClass) JavaParserUtils.getFromFileName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFromFileName) Collectors.groupingBy(java.util.stream.Collectors.groupingBy) HashMap(java.util.HashMap) CastExpr(com.github.javaparser.ast.expr.CastExpr) AtomicReference(java.util.concurrent.atomic.AtomicReference) KIE_PMML_REGRESSION_TABLE_TEMPLATE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE) ArrayList(java.util.ArrayList) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils) MiningField(org.dmg.pmml.MiningField) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) CodegenTestUtils.commonValidateCompilation(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilation) KiePMMLModelUtils.getGeneratedClassName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getGeneratedClassName) MethodReferenceExpr(com.github.javaparser.ast.expr.MethodReferenceExpr) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) Assert.assertNull(org.junit.Assert.assertNull) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) 
UNSUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.UNSUPPORTED_NORMALIZATION_METHODS) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) ExpressionStmt(com.github.javaparser.ast.stmt.ExpressionStmt) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr)

Example 4 with CategoricalPredictor

Use of org.dmg.pmml.regression.CategoricalPredictor in the drools project by kiegroup.

In the class PMMLModelTestUtils, method getCategoricalPredictor.

/**
 * Creates a {@link CategoricalPredictor} whose field, value and coefficient are set from the
 * given arguments.
 *
 * @param name        the name used to create the predictor's {@link FieldName}
 * @param value       the value assigned to the predictor
 * @param coefficient the coefficient assigned to the predictor
 * @return the newly created predictor
 */
public static CategoricalPredictor getCategoricalPredictor(String name, double value, double coefficient) {
    final CategoricalPredictor predictor = new CategoricalPredictor();
    final FieldName field = FieldName.create(name);
    predictor.setField(field);
    predictor.setValue(value);
    predictor.setCoefficient(coefficient);
    return predictor;
}
Also used : CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor)

Example 5 with CategoricalPredictor

Use of org.dmg.pmml.regression.CategoricalPredictor in the drools project by kiegroup.

In the class PMMLModelTestUtils, method getRegressionTable.

/**
 * Creates a {@link RegressionTable} with the given intercept and target category and adds all
 * supplied categorical predictors, numeric predictors and predictor terms to it.
 *
 * @param categoricalPredictors the categorical predictors to add
 * @param numericPredictors     the numeric predictors to add
 * @param predictorTerms        the predictor terms to add
 * @param intercept             the intercept assigned to the table
 * @param targetCategory        the target category assigned to the table
 * @return the newly created regression table
 */
public static RegressionTable getRegressionTable(List<CategoricalPredictor> categoricalPredictors, List<NumericPredictor> numericPredictors, List<PredictorTerm> predictorTerms, double intercept, Object targetCategory) {
    final RegressionTable table = new RegressionTable();
    table.setIntercept(intercept);
    table.setTargetCategory(targetCategory);

    // The add* methods are called with arrays, so convert each list first.
    final CategoricalPredictor[] categoricalArray = categoricalPredictors.toArray(new CategoricalPredictor[0]);
    table.addCategoricalPredictors(categoricalArray);

    final NumericPredictor[] numericArray = numericPredictors.toArray(new NumericPredictor[0]);
    table.addNumericPredictors(numericArray);

    final PredictorTerm[] termArray = predictorTerms.toArray(new PredictorTerm[0]);
    table.addPredictorTerms(termArray);

    return table;
}
Also used : CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable)

Aggregations

CategoricalPredictor (org.dmg.pmml.regression.CategoricalPredictor): 8, NumericPredictor (org.dmg.pmml.regression.NumericPredictor): 6, PredictorTerm (org.dmg.pmml.regression.PredictorTerm): 6, BlockStmt (com.github.javaparser.ast.stmt.BlockStmt): 4, ArrayList (java.util.ArrayList): 4, RegressionTable (org.dmg.pmml.regression.RegressionTable): 4, CompilationUnit (com.github.javaparser.ast.CompilationUnit): 3, ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration): 3, MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration): 3, CastExpr (com.github.javaparser.ast.expr.CastExpr): 3, Expression (com.github.javaparser.ast.expr.Expression): 3, LambdaExpr (com.github.javaparser.ast.expr.LambdaExpr): 3, MethodCallExpr (com.github.javaparser.ast.expr.MethodCallExpr): 3, MethodReferenceExpr (com.github.javaparser.ast.expr.MethodReferenceExpr): 3, NullLiteralExpr (com.github.javaparser.ast.expr.NullLiteralExpr): 3, ExpressionStmt (com.github.javaparser.ast.stmt.ExpressionStmt): 3, Arrays (java.util.Arrays): 3, Collections (java.util.Collections): 3, HashMap (java.util.HashMap): 3, List (java.util.List): 3