Search in sources :

Example 1 with SerializableFunction

Use of org.kie.pmml.api.iinterfaces.SerializableFunction in the drools project by kiegroup.

In the class KiePMMLRegressionModelFactoryTest, method evaluateRegressionTable.

/**
 * Checks a generated regression table against the PMML table it was built from.
 * <p>
 * The intercepts must be equal, and the name of every numeric predictor,
 * categorical predictor and predictor term of the original table must be
 * present as a key in the corresponding function map of the generated table.
 *
 * @param regressionTable         the generated table under verification
 * @param originalRegressionTable the source table providing the expected values
 */
private void evaluateRegressionTable(KiePMMLRegressionTable regressionTable, RegressionTable originalRegressionTable) {
    assertEquals(originalRegressionTable.getIntercept(), regressionTable.getIntercept());

    // Numeric predictors: each name must be a key of the numeric function map.
    final Map<String, SerializableFunction<Double, Double>> numericFunctions =
            regressionTable.getNumericFunctionMap();
    originalRegressionTable.getNumericPredictors()
            .forEach(numeric -> assertTrue(numericFunctions.containsKey(numeric.getName().getValue())));

    // Categorical predictors: each name must be a key of the categorical function map.
    final Map<String, SerializableFunction<String, Double>> categoricalFunctions =
            regressionTable.getCategoricalFunctionMap();
    originalRegressionTable.getCategoricalPredictors()
            .forEach(categorical -> assertTrue(categoricalFunctions.containsKey(categorical.getName().getValue())));

    // Predictor terms: each name must be a key of the predictor-terms function map.
    final Map<String, SerializableFunction<Map<String, Object>, Double>> termFunctions =
            regressionTable.getPredictorTermsFunctionMap();
    originalRegressionTable.getPredictorTerms()
            .forEach(term -> assertTrue(termFunctions.containsKey(term.getName().getValue())));
}
Also used : CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PMMLModelTestUtils.getCategoricalPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getCategoricalPredictor) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils.getPredictorTerm(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getPredictorTerm) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) PMMLModelTestUtils.getNumericPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getNumericPredictor)

Example 2 with SerializableFunction

Use of org.kie.pmml.api.iinterfaces.SerializableFunction in the drools project by kiegroup.

In the class KiePMMLRegressionTableFactoryTest, method getPredictorTermsMap.

/**
 * Builds three predictor terms named {@code predictorName-0} to
 * {@code predictorName-2}, passes them to
 * {@code KiePMMLRegressionTableFactory.getPredictorTermsMap} and verifies that
 * the returned map has one entry per term and contains every term name as a key.
 */
@Test
public void getPredictorTermsMap() {
    // Fixture: term i has coefficient 1.23 * i and a single field reference "fieldRef-i".
    final List<PredictorTerm> predictorTerms = IntStream.range(0, 3)
            .mapToObj(i -> PMMLModelTestUtils.getPredictorTerm(
                    "predictorName-" + i,
                    1.23 * i,
                    Collections.singletonList("fieldRef-" + i)))
            .collect(Collectors.toList());

    Map<String, SerializableFunction<Map<String, Object>, Double>> retrieved =
            KiePMMLRegressionTableFactory.getPredictorTermsMap(predictorTerms);

    assertEquals(predictorTerms.size(), retrieved.size());
    for (PredictorTerm term : predictorTerms) {
        assertTrue(retrieved.containsKey(term.getName().getValue()));
    }
}
Also used : RESULT_FEATURE(org.kie.pmml.api.enums.RESULT_FEATURE) Arrays(java.util.Arrays) KiePMMLModelUtils.getSanitizedVariableName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedVariableName) MiningSchema(org.dmg.pmml.MiningSchema) KiePMMLOutputField(org.kie.pmml.commons.model.KiePMMLOutputField) FieldName(org.dmg.pmml.FieldName) OpType(org.dmg.pmml.OpType) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) GETKIEPMML_TABLE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.GETKIEPMML_TABLE) TestCase.assertNotNull(junit.framework.TestCase.assertNotNull) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) Assert.fail(org.junit.Assert.fail) CompilationUnit(com.github.javaparser.ast.CompilationUnit) PMML(org.dmg.pmml.PMML) RegressionModel(org.dmg.pmml.regression.RegressionModel) KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) SUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.SUPPORTED_NORMALIZATION_METHODS) List(java.util.List) CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) IntStream(java.util.stream.IntStream) 
LambdaExpr(com.github.javaparser.ast.expr.LambdaExpr) ExpressionStmt(com.github.javaparser.ast.stmt.ExpressionStmt) BeforeClass(org.junit.BeforeClass) JavaParserUtils.getFromFileName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFromFileName) Collectors.groupingBy(java.util.stream.Collectors.groupingBy) HashMap(java.util.HashMap) CastExpr(com.github.javaparser.ast.expr.CastExpr) AtomicReference(java.util.concurrent.atomic.AtomicReference) KIE_PMML_REGRESSION_TABLE_TEMPLATE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE) ArrayList(java.util.ArrayList) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils) MiningField(org.dmg.pmml.MiningField) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) CodegenTestUtils.commonValidateCompilation(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilation) KiePMMLModelUtils.getGeneratedClassName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getGeneratedClassName) MethodReferenceExpr(com.github.javaparser.ast.expr.MethodReferenceExpr) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) Assert.assertNull(org.junit.Assert.assertNull) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) 
UNSUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.UNSUPPORTED_NORMALIZATION_METHODS) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) Test(org.junit.Test)

Aggregations

CategoricalPredictor (org.dmg.pmml.regression.CategoricalPredictor)2 NumericPredictor (org.dmg.pmml.regression.NumericPredictor)2 PredictorTerm (org.dmg.pmml.regression.PredictorTerm)2 CompilationUnit (com.github.javaparser.ast.CompilationUnit)1 ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration)1 MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration)1 CastExpr (com.github.javaparser.ast.expr.CastExpr)1 Expression (com.github.javaparser.ast.expr.Expression)1 LambdaExpr (com.github.javaparser.ast.expr.LambdaExpr)1 MethodCallExpr (com.github.javaparser.ast.expr.MethodCallExpr)1 MethodReferenceExpr (com.github.javaparser.ast.expr.MethodReferenceExpr)1 NullLiteralExpr (com.github.javaparser.ast.expr.NullLiteralExpr)1 BlockStmt (com.github.javaparser.ast.stmt.BlockStmt)1 ExpressionStmt (com.github.javaparser.ast.stmt.ExpressionStmt)1 IOException (java.io.IOException)1 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 Collections (java.util.Collections)1 HashMap (java.util.HashMap)1 List (java.util.List)1