Search in sources :

Example 6 with CategoricalPredictor

Use of org.dmg.pmml.regression.CategoricalPredictor in project drools by kiegroup.

In the class KiePMMLRegressionTableFactory, method getCategoricalPredictorsExpressions.

/**
 * Create the <b>CategoricalPredictor</b>s lambda <code>Expression</code>s map.
 * <p>
 * The given predictors are grouped by the value of their field. For every group a distinct
 * variable name is generated (from <code>variableName</code> plus a progressive counter), the
 * group is handed to <code>populateWithGroupedCategoricalPredictorMap</code> together with
 * <code>body</code>, and the <code>Expression</code> obtained from
 * <code>getCategoricalPredictorExpression</code> for that variable name is stored under the
 * group's field name.
 *
 * @param categoricalPredictors the <code>CategoricalPredictor</code>s to group by field name
 * @param body the <code>BlockStmt</code> passed, once per group, to
 * <code>populateWithGroupedCategoricalPredictorMap</code>
 * @param variableName base name (sanitized and suffixed with <code>Map</code>) from which the
 * per-group variable names are derived
 * @return a map whose keys are the field names and whose values are the <code>Expression</code>s
 * built for the corresponding group
 */
static Map<String, Expression> getCategoricalPredictorsExpressions(final List<CategoricalPredictor> categoricalPredictors, final BlockStmt body, final String variableName) {
    final Map<String, List<CategoricalPredictor>> groupedCollectors = categoricalPredictors.stream().collect(groupingBy(categoricalPredictor -> categoricalPredictor.getField().getValue()));
    final String categoricalPredictorMapNameBase = getSanitizedVariableName(String.format("%sMap", variableName));
    final Map<String, Expression> toReturn = new HashMap<>();
    // Explicit loop instead of a stream: every iteration has side effects (it feeds "body" and
    // advances the counter), which do not belong inside a stream's intermediate "map" stage.
    int counter = 0;
    for (Map.Entry<String, List<CategoricalPredictor>> entry : groupedCollectors.entrySet()) {
        final String categoricalPredictorMapName = String.format(VARIABLE_NAME_TEMPLATE, categoricalPredictorMapNameBase, counter++);
        populateWithGroupedCategoricalPredictorMap(entry.getValue(), body, categoricalPredictorMapName);
        toReturn.put(entry.getKey(), getCategoricalPredictorExpression(categoricalPredictorMapName));
    }
    return toReturn;
}
Also used : Arrays(java.util.Arrays) ClassOrInterfaceType(com.github.javaparser.ast.type.ClassOrInterfaceType) CommonCodegenUtils.createPopulatedHashMap(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.createPopulatedHashMap) KiePMMLModelUtils.getSanitizedVariableName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedVariableName) LoggerFactory(org.slf4j.LoggerFactory) MISSING_VARIABLE_IN_BODY(org.kie.pmml.commons.Constants.MISSING_VARIABLE_IN_BODY) MAIN_CLASS_NOT_FOUND(org.kie.pmml.compiler.commons.utils.JavaParserUtils.MAIN_CLASS_NOT_FOUND) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ObjectCreationExpr(com.github.javaparser.ast.expr.ObjectCreationExpr) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) CompilationUnit(com.github.javaparser.ast.CompilationUnit) NodeList(com.github.javaparser.ast.NodeList) UnknownType(com.github.javaparser.ast.type.UnknownType) CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames) RegressionModel(org.dmg.pmml.regression.RegressionModel) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) JavaParserUtils.getFullClassName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFullClassName) MISSING_BODY_TEMPLATE(org.kie.pmml.commons.Constants.MISSING_BODY_TEMPLATE) UUID(java.util.UUID) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) StringLiteralExpr(com.github.javaparser.ast.expr.StringLiteralExpr) VariableDeclarationExpr(com.github.javaparser.ast.expr.VariableDeclarationExpr) Objects(java.util.Objects) List(java.util.List) CommonCodegenUtils.getExpressionForObject(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getExpressionForObject) 
SerializableFunction(org.kie.pmml.api.iinterfaces.SerializableFunction) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) TO_RETURN(org.kie.pmml.commons.Constants.TO_RETURN) VARIABLE_NAME_TEMPLATE(org.kie.pmml.commons.Constants.VARIABLE_NAME_TEMPLATE) LambdaExpr(com.github.javaparser.ast.expr.LambdaExpr) ExpressionStmt(com.github.javaparser.ast.stmt.ExpressionStmt) JavaParserUtils.getFromFileName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFromFileName) Parameter(com.github.javaparser.ast.body.Parameter) Collectors.groupingBy(java.util.stream.Collectors.groupingBy) HashMap(java.util.HashMap) CastExpr(com.github.javaparser.ast.expr.CastExpr) AtomicReference(java.util.concurrent.atomic.AtomicReference) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) LinkedHashMap(java.util.LinkedHashMap) VariableDeclarator(com.github.javaparser.ast.body.VariableDeclarator) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) CommonCodegenUtils.addMapPopulationExpressions(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.addMapPopulationExpressions) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) AbstractKiePMMLTable(org.kie.pmml.models.regression.model.AbstractKiePMMLTable) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Logger(org.slf4j.Logger) CommonCodegenUtils.getChainedMethodCallExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getChainedMethodCallExprFrom) NameExpr(com.github.javaparser.ast.expr.NameExpr) MethodReferenceExpr(com.github.javaparser.ast.expr.MethodReferenceExpr) CommonCodegenUtils.getVariableDeclarator(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getVariableDeclarator) 
KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) AbstractMap(java.util.AbstractMap) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) MISSING_VARIABLE_INITIALIZER_TEMPLATE(org.kie.pmml.commons.Constants.MISSING_VARIABLE_INITIALIZER_TEMPLATE) Collections(java.util.Collections) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) AbstractMap(java.util.AbstractMap) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) NodeList(com.github.javaparser.ast.NodeList) List(java.util.List)

Example 7 with CategoricalPredictor

Use of org.dmg.pmml.regression.CategoricalPredictor in project drools by kiegroup.

In the class KiePMMLRegressionModelFactoryTest, method setup.

/**
 * Prepares the static fixtures of this test class: three regression tables (each holding three
 * categorical predictors, three numeric predictors and three predictor terms built with random
 * values), the data/mining fields derived from the predictor field names, the resulting
 * dictionaries, mining schema and regression model, the parsed template class and the
 * <code>PMML</code> instance wrapping all of them.
 */
@BeforeClass
public static void setup() {
    final Random rnd = new Random();
    // A Set, because the same predictor names are generated again for every table
    final Set<String> collectedFieldNames = new HashSet<>();
    regressionTables = IntStream.range(0, 3).mapToObj(tableIndex -> {
        final List<CategoricalPredictor> catPredictors = new ArrayList<>();
        final List<NumericPredictor> numPredictors = new ArrayList<>();
        final List<PredictorTerm> terms = new ArrayList<>();
        for (int predictorIndex = 0; predictorIndex < 3; predictorIndex++) {
            final String catName = "CatPred-" + predictorIndex;
            final String numName = "NumPred-" + predictorIndex;
            catPredictors.add(getCategoricalPredictor(catName, rnd.nextDouble(), rnd.nextDouble()));
            numPredictors.add(getNumericPredictor(numName, rnd.nextInt(), rnd.nextDouble()));
            terms.add(getPredictorTerm("PredTerm-" + predictorIndex, rnd.nextDouble(), Arrays.asList(catName, numName)));
            collectedFieldNames.add(catName);
            collectedFieldNames.add(numName);
        }
        final double intercept = tableIntercept + rnd.nextDouble();
        final String targetCategory = tableTargetCategory + "-" + tableIndex;
        return getRegressionTable(catPredictors, numPredictors, terms, intercept, targetCategory);
    }).collect(Collectors.toList());

    // One data field and one (initially ACTIVE) mining field for every distinct predictor name
    dataFields = new ArrayList<>();
    miningFields = new ArrayList<>();
    for (String name : collectedFieldNames) {
        dataFields.add(getDataField(name, OpType.CATEGORICAL, DataType.STRING));
        miningFields.add(getMiningField(name, MiningField.UsageType.ACTIVE));
    }
    // The first mining field is promoted to be the target
    targetMiningField = miningFields.get(0);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);

    dataDictionary = getDataDictionary(dataFields);
    transformationDictionary = new TransformationDictionary();
    miningSchema = getMiningSchema(miningFields);
    regressionModel = getRegressionModel(modelName, MiningFunction.REGRESSION, miningSchema, regressionTables);

    COMPILATION_UNIT = getFromFileName(KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA);
    MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName(KIE_PMML_REGRESSION_MODEL_TEMPLATE).get();

    pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.setTransformationDictionary(transformationDictionary);
    pmml.addModels(regressionModel);
}
Also used : PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils.getPredictorTerm(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getPredictorTerm) TransformationDictionary(org.dmg.pmml.TransformationDictionary) ArrayList(java.util.ArrayList) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) PMMLModelTestUtils.getNumericPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getNumericPredictor) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PMMLModelTestUtils.getCategoricalPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getCategoricalPredictor) Random(java.util.Random) PMML(org.dmg.pmml.PMML) HashSet(java.util.HashSet) BeforeClass(org.junit.BeforeClass)

Example 8 with CategoricalPredictor

Use of org.dmg.pmml.regression.CategoricalPredictor in project drools by kiegroup.

In the class KiePMMLRegressionTableFactoryTest, method populateWithGroupedCategoricalPredictorMap.

/**
 * Verifies that <code>KiePMMLRegressionTableFactory.populateWithGroupedCategoricalPredictorMap</code>
 * fills the given <code>BlockStmt</code> so that it equals the block parsed from the
 * <code>TEST_04_SOURCE</code> content, once that content is formatted with the map name followed
 * by the value and coefficient of each of the three predictors.
 */
@Test
public void populateWithGroupedCategoricalPredictorMap() throws IOException {
    final String mapName = "categoricalPredictorMapName";
    final List<CategoricalPredictor> predictors = new ArrayList<>();
    for (int index = 0; index < 3; index++) {
        predictors.add(PMMLModelTestUtils.getCategoricalPredictor("predictorName-" + index, index, 1.23 * index));
    }
    final BlockStmt actual = new BlockStmt();
    KiePMMLRegressionTableFactory.populateWithGroupedCategoricalPredictorMap(predictors, actual, mapName);
    final String template = getFileContent(TEST_04_SOURCE);
    // Format arguments: the map name first, then (value, coefficient) for each predictor in order
    final List<Object> formatArgs = new ArrayList<>();
    formatArgs.add(mapName);
    for (CategoricalPredictor predictor : predictors) {
        formatArgs.add(predictor.getValue());
        formatArgs.add(predictor.getCoefficient());
    }
    final BlockStmt expected = JavaParserUtils.parseBlock(String.format(template, formatArgs.toArray()));
    assertTrue(JavaParserUtils.equalsNode(expected, actual));
}
Also used : CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Aggregations

CategoricalPredictor (org.dmg.pmml.regression.CategoricalPredictor)8 NumericPredictor (org.dmg.pmml.regression.NumericPredictor)6 PredictorTerm (org.dmg.pmml.regression.PredictorTerm)6 BlockStmt (com.github.javaparser.ast.stmt.BlockStmt)4 ArrayList (java.util.ArrayList)4 RegressionTable (org.dmg.pmml.regression.RegressionTable)4 CompilationUnit (com.github.javaparser.ast.CompilationUnit)3 ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration)3 MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration)3 CastExpr (com.github.javaparser.ast.expr.CastExpr)3 Expression (com.github.javaparser.ast.expr.Expression)3 LambdaExpr (com.github.javaparser.ast.expr.LambdaExpr)3 MethodCallExpr (com.github.javaparser.ast.expr.MethodCallExpr)3 MethodReferenceExpr (com.github.javaparser.ast.expr.MethodReferenceExpr)3 NullLiteralExpr (com.github.javaparser.ast.expr.NullLiteralExpr)3 ExpressionStmt (com.github.javaparser.ast.stmt.ExpressionStmt)3 Arrays (java.util.Arrays)3 Collections (java.util.Collections)3 HashMap (java.util.HashMap)3 List (java.util.List)3