Search in sources :

Example 6 with PredictorTerm

use of org.dmg.pmml.regression.PredictorTerm in project drools by kiegroup.

the class PMMLModelTestUtils method getPredictorTerm.

/**
 * Creates a {@link PredictorTerm} carrying the given name and coefficient, plus one
 * {@link FieldRef} for every entry of {@code fieldRefNames}.
 *
 * @param name          value used to create the term's {@code FieldName}
 * @param coefficient   coefficient assigned to the term
 * @param fieldRefNames names converted, in iteration order, through
 *                      {@code PMMLModelTestUtils.getFieldRef}
 * @return the populated predictor term
 */
public static PredictorTerm getPredictorTerm(String name, double coefficient, List<String> fieldRefNames) {
    final PredictorTerm predictorTerm = new PredictorTerm();
    predictorTerm.setName(FieldName.create(name));
    predictorTerm.setCoefficient(coefficient);
    // Convert the names only after name/coefficient are set, keeping the original evaluation order.
    final FieldRef[] fieldRefs = fieldRefNames.stream()
            .map(PMMLModelTestUtils::getFieldRef)
            .toArray(FieldRef[]::new);
    predictorTerm.addFieldRefs(fieldRefs);
    return predictorTerm;
}
Also used : PredictorTerm(org.dmg.pmml.regression.PredictorTerm) FieldRef(org.dmg.pmml.FieldRef)

Example 7 with PredictorTerm

use of org.dmg.pmml.regression.PredictorTerm in project drools by kiegroup.

the class KiePMMLRegressionModelFactoryTest method setup.

/**
 * Builds the static PMML fixtures shared by the tests of this class.
 * <p>
 * Three regression tables are created; each one receives three categorical predictors, three
 * numeric predictors and three predictor terms populated with random values. Every predictor
 * field name is then registered both as a data field and as an active mining field, and the
 * first mining field is switched to the target usage type. Finally the regression model, the
 * template compilation unit / class declaration and the {@link PMML} container are initialised.
 */
@BeforeClass
public static void setup() {
    final Random random = new Random();
    final Set<String> fieldNames = new HashSet<>();
    regressionTables = IntStream.range(0, 3).mapToObj(tableIndex -> {
        final List<CategoricalPredictor> categoricalPredictors = new ArrayList<>();
        final List<NumericPredictor> numericPredictors = new ArrayList<>();
        final List<PredictorTerm> predictorTerms = new ArrayList<>();
        for (int predictorIndex = 0; predictorIndex < 3; predictorIndex++) {
            final String catFieldName = "CatPred-" + predictorIndex;
            final String numFieldName = "NumPred-" + predictorIndex;
            categoricalPredictors.add(
                    getCategoricalPredictor(catFieldName, random.nextDouble(), random.nextDouble()));
            numericPredictors.add(
                    getNumericPredictor(numFieldName, random.nextInt(), random.nextDouble()));
            predictorTerms.add(
                    getPredictorTerm("PredTerm-" + predictorIndex,
                                     random.nextDouble(),
                                     Arrays.asList(catFieldName, numFieldName)));
            // The same names recur for every table; the set keeps a single copy of each.
            fieldNames.add(catFieldName);
            fieldNames.add(numFieldName);
        }
        return getRegressionTable(categoricalPredictors,
                                  numericPredictors,
                                  predictorTerms,
                                  tableIntercept + random.nextDouble(),
                                  tableTargetCategory + "-" + tableIndex);
    }).collect(Collectors.toList());

    dataFields = new ArrayList<>();
    miningFields = new ArrayList<>();
    for (String fieldName : fieldNames) {
        dataFields.add(getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));
        miningFields.add(getMiningField(fieldName, MiningField.UsageType.ACTIVE));
    }
    // fieldNames is a HashSet, so which field comes first (and thus becomes the target)
    // is not fixed by this code.
    targetMiningField = miningFields.get(0);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);

    dataDictionary = getDataDictionary(dataFields);
    transformationDictionary = new TransformationDictionary();
    miningSchema = getMiningSchema(miningFields);
    regressionModel = getRegressionModel(modelName, MiningFunction.REGRESSION, miningSchema, regressionTables);

    COMPILATION_UNIT = getFromFileName(KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA);
    MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName(KIE_PMML_REGRESSION_MODEL_TEMPLATE).get();

    pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.setTransformationDictionary(transformationDictionary);
    pmml.addModels(regressionModel);
}
Also used : PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils.getPredictorTerm(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getPredictorTerm) TransformationDictionary(org.dmg.pmml.TransformationDictionary) ArrayList(java.util.ArrayList) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) PMMLModelTestUtils.getNumericPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getNumericPredictor) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PMMLModelTestUtils.getCategoricalPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getCategoricalPredictor) Random(java.util.Random) PMML(org.dmg.pmml.PMML) HashSet(java.util.HashSet) BeforeClass(org.junit.BeforeClass)

Example 8 with PredictorTerm

use of org.dmg.pmml.regression.PredictorTerm in project drools by kiegroup.

the class KiePMMLRegressionTableFactoryTest method getPredictorTermsMap.

/**
 * Checks that {@code KiePMMLRegressionTableFactory.getPredictorTermsMap} returns one entry per
 * input {@link PredictorTerm}, keyed by the value of the term's name.
 */
@Test
public void getPredictorTermsMap() {
    final List<PredictorTerm> predictorTerms = IntStream.range(0, 3)
            .mapToObj(index -> PMMLModelTestUtils.getPredictorTerm("predictorName-" + index,
                                                                   1.23 * index,
                                                                   Collections.singletonList("fieldRef-" + index)))
            .collect(Collectors.toList());
    final Map<String, SerializableFunction<Map<String, Object>, Double>> retrieved =
            KiePMMLRegressionTableFactory.getPredictorTermsMap(predictorTerms);
    // Same number of entries as input terms ...
    assertEquals(predictorTerms.size(), retrieved.size());
    // ... and every term name must be present as a key.
    for (PredictorTerm predictorTerm : predictorTerms) {
        assertTrue(retrieved.containsKey(predictorTerm.getName().getValue()));
    }
}
Also used : RESULT_FEATURE(org.kie.pmml.api.enums.RESULT_FEATURE) Arrays(java.util.Arrays) KiePMMLModelUtils.getSanitizedVariableName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedVariableName) MiningSchema(org.dmg.pmml.MiningSchema) KiePMMLOutputField(org.kie.pmml.commons.model.KiePMMLOutputField) FieldName(org.dmg.pmml.FieldName) OpType(org.dmg.pmml.OpType) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) GETKIEPMML_TABLE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.GETKIEPMML_TABLE) TestCase.assertNotNull(junit.framework.TestCase.assertNotNull) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) Assert.fail(org.junit.Assert.fail) CompilationUnit(com.github.javaparser.ast.CompilationUnit) PMML(org.dmg.pmml.PMML) RegressionModel(org.dmg.pmml.regression.RegressionModel) KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE_JAVA) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) SUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.SUPPORTED_NORMALIZATION_METHODS) List(java.util.List) CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) IntStream(java.util.stream.IntStream) 
LambdaExpr(com.github.javaparser.ast.expr.LambdaExpr) ExpressionStmt(com.github.javaparser.ast.stmt.ExpressionStmt) BeforeClass(org.junit.BeforeClass) JavaParserUtils.getFromFileName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFromFileName) Collectors.groupingBy(java.util.stream.Collectors.groupingBy) HashMap(java.util.HashMap) CastExpr(com.github.javaparser.ast.expr.CastExpr) AtomicReference(java.util.concurrent.atomic.AtomicReference) KIE_PMML_REGRESSION_TABLE_TEMPLATE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.KIE_PMML_REGRESSION_TABLE_TEMPLATE) ArrayList(java.util.ArrayList) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils) MiningField(org.dmg.pmml.MiningField) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) CodegenTestUtils.commonValidateCompilation(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilation) KiePMMLModelUtils.getGeneratedClassName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getGeneratedClassName) MethodReferenceExpr(com.github.javaparser.ast.expr.MethodReferenceExpr) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) Assert.assertNull(org.junit.Assert.assertNull) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) 
UNSUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.UNSUPPORTED_NORMALIZATION_METHODS) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) Test(org.junit.Test)

Example 9 with PredictorTerm

use of org.dmg.pmml.regression.PredictorTerm in project drools by kiegroup.

the class KiePMMLRegressionTableFactoryTest method getPredictorTermFunction.

/**
 * Checks that {@code KiePMMLRegressionTableFactory.getPredictorTermFunction} produces a lambda
 * expression equal to the one parsed from the {@code TEST_07_SOURCE} template, once that template
 * is formatted with the field reference name and the coefficient.
 *
 * @throws IOException declared for reading the expected-source file
 */
@Test
public void getPredictorTermFunction() throws IOException {
    final String fieldRef = "fieldRef";
    final double coefficient = 23.12;
    final PredictorTerm predictorTerm =
            PMMLModelTestUtils.getPredictorTerm("predictorName", coefficient, Collections.singletonList(fieldRef));
    final LambdaExpr retrieved = KiePMMLRegressionTableFactory.getPredictorTermFunction(predictorTerm);
    // Build the expected expression from the template file.
    final String template = getFileContent(TEST_07_SOURCE);
    final String expectedSource = String.format(template, fieldRef, coefficient);
    final Expression expected = JavaParserUtils.parseExpression(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));
}
Also used : PredictorTerm(org.dmg.pmml.regression.PredictorTerm) Expression(com.github.javaparser.ast.expr.Expression) LambdaExpr(com.github.javaparser.ast.expr.LambdaExpr) Test(org.junit.Test)

Aggregations

PredictorTerm (org.dmg.pmml.regression.PredictorTerm)9 CategoricalPredictor (org.dmg.pmml.regression.CategoricalPredictor)6 NumericPredictor (org.dmg.pmml.regression.NumericPredictor)6 Expression (com.github.javaparser.ast.expr.Expression)4 LambdaExpr (com.github.javaparser.ast.expr.LambdaExpr)4 HashMap (java.util.HashMap)4 Map (java.util.Map)4 RegressionTable (org.dmg.pmml.regression.RegressionTable)4 Test (org.junit.Test)4 CompilationUnit (com.github.javaparser.ast.CompilationUnit)3 ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration)3 MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration)3 CastExpr (com.github.javaparser.ast.expr.CastExpr)3 MethodCallExpr (com.github.javaparser.ast.expr.MethodCallExpr)3 MethodReferenceExpr (com.github.javaparser.ast.expr.MethodReferenceExpr)3 NullLiteralExpr (com.github.javaparser.ast.expr.NullLiteralExpr)3 BlockStmt (com.github.javaparser.ast.stmt.BlockStmt)3 ExpressionStmt (com.github.javaparser.ast.stmt.ExpressionStmt)3 Arrays (java.util.Arrays)3 Collections (java.util.Collections)3