Search in sources:

Example 1 with RegressionCompilationDTO

use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in project drools by kiegroup.

In the class KiePMMLRegressionModelFactoryTest, method setStaticGetter.

/**
 * Checks that {@code KiePMMLRegressionModelFactory.setStaticGetter} rewrites the
 * {@code GET_MODEL} method of the model template so that it is node-equal to the
 * method parsed from the file referenced by {@code TEST_01_SOURCE}.
 *
 * @throws IOException declared for loading the expected source
 *                     (presumably raised by {@code getFileContent} - confirm against the helper)
 */
@Test
public void setStaticGetter() throws IOException {
    final String nestedTable = "NestedTable";
    // Operate on a clone: the call under test mutates the declaration it receives.
    final ClassOrInterfaceDeclaration modelTemplate = MODEL_TEMPLATE.clone();
    final CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    // The DTO is built with an empty regression-table list; only the model-level data is supplied.
    final RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, new ArrayList<>(), regressionModel.getNormalizationMethod());
    KiePMMLRegressionModelFactory.setStaticGetter(compilationDTO, modelTemplate, nestedTable);
    // Compare the populated method against the expected source.
    final MethodDeclaration retrieved = modelTemplate.getMethodsByName(GET_MODEL).get(0);
    final String text = getFileContent(TEST_01_SOURCE);
    final MethodDeclaration expected = JavaParserUtils.parseMethod(text);
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));
}
Also used : ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) HashMap(java.util.HashMap) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) NameExpr(com.github.javaparser.ast.expr.NameExpr) StringLiteralExpr(com.github.javaparser.ast.expr.StringLiteralExpr) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) PMMLModelTestUtils.getRegressionModel(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRegressionModel) RegressionModel(org.dmg.pmml.regression.RegressionModel) KiePMMLRegressionModel(org.kie.pmml.models.regression.model.KiePMMLRegressionModel) Expression(com.github.javaparser.ast.expr.Expression) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) MINING_FUNCTION(org.kie.pmml.api.enums.MINING_FUNCTION) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) Test(org.junit.Test)

Example 2 with RegressionCompilationDTO

use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in project drools by kiegroup.

In the class KiePMMLRegressionModelFactory, method getRegressionTablesMap.

// not-public code-generation
/**
 * Produces the table sources for the model carried by the given DTO.
 * <p>
 * When {@code compilationDTO.isRegression()} is true, only the first regression table of the
 * model is used, together with the model's own normalization method, and the result comes from
 * {@code KiePMMLRegressionTableFactory}. Otherwise every regression table is used with
 * {@code NormalizationMethod.NONE} and the result comes from
 * {@code KiePMMLClassificationTableFactory}.
 *
 * @param compilationDTO DTO wrapping the regression model being compiled
 * @return the map returned by the selected table factory
 */
static Map<String, KiePMMLTableSourceCategory> getRegressionTablesMap(final RegressionCompilationDTO compilationDTO) {
    if (!compilationDTO.isRegression()) {
        // Classification: hand over all tables, with normalization forced to NONE.
        final RegressionCompilationDTO classificationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                compilationDTO,
                compilationDTO.getModel().getRegressionTables(),
                RegressionModel.NormalizationMethod.NONE);
        return KiePMMLClassificationTableFactory.getClassificationTableBuilders(classificationDTO);
    }
    // Regression: only the first table is considered.
    final RegressionTable firstTable = compilationDTO.getModel().getRegressionTables().get(0);
    final RegressionCompilationDTO singleTableDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
            compilationDTO,
            Collections.singletonList(firstTable),
            compilationDTO.getModel().getNormalizationMethod());
    return KiePMMLRegressionTableFactory.getRegressionTableBuilders(singleTableDTO);
}
Also used : KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) RegressionTable(org.dmg.pmml.regression.RegressionTable)

Example 3 with RegressionCompilationDTO

use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in project drools by kiegroup.

In the class KiePMMLClassificationTableFactoryTest, method getClassificationTableBuilders.

/**
 * Builds a two-table categorical regression model (CAUCHIT normalization), feeds it to
 * {@code KiePMMLClassificationTableFactory.getClassificationTableBuilders} and verifies that
 * three sources are produced, that each passes {@code commonValidateKiePMMLRegressionTable},
 * and that all of them pass {@code commonValidateCompilation}.
 */
@Test
public void getClassificationTableBuilders() {
    // --- fixture: regression tables and output fields ---
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");
    final OutputField catOutput = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    final OutputField numOutput = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    final OutputField prevOutput = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);
    final Output modelOutput = new Output();
    modelOutput.addOutputFields(catOutput, numOutput, prevOutput);

    // --- fixture: categorical target declared in the data dictionary and mining schema ---
    final String targetField = "targetField";
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create(targetField));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);

    // --- fixture: the model itself and its enclosing PMML document ---
    final RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(professionalTable, clericalTable);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    regressionModel.setOutput(modelOutput);
    regressionModel.setMiningSchema(schema);
    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(regressionModel);

    // --- exercise ---
    final CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());
    final Map<String, KiePMMLTableSourceCategory> retrieved = KiePMMLClassificationTableFactory.getClassificationTableBuilders(compilationDTO);

    // --- verify ---
    assertNotNull(retrieved);
    // Presumably one source per regression table plus one for the classification table itself.
    assertEquals(3, retrieved.size());
    for (KiePMMLTableSourceCategory tableSource : retrieved.values()) {
        commonValidateKiePMMLRegressionTable(tableSource.getSource());
    }
    final Map<String, String> sources = retrieved.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSource()));
    commonValidateCompilation(sources);
}
Also used : GETKIEPMML_TABLE(org.kie.pmml.models.regression.compiler.factories.KiePMMLClassificationTableFactory.GETKIEPMML_TABLE) BeforeClass(org.junit.BeforeClass) OutputField(org.dmg.pmml.OutputField) JavaParserUtils.getFromFileName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFromFileName) KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE(org.kie.pmml.models.regression.compiler.factories.KiePMMLClassificationTableFactory.KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE) ResultFeature(org.dmg.pmml.ResultFeature) MiningSchema(org.dmg.pmml.MiningSchema) OP_TYPE(org.kie.pmml.api.enums.OP_TYPE) Output(org.dmg.pmml.Output) LinkedHashMap(java.util.LinkedHashMap) FieldName(org.dmg.pmml.FieldName) OpType(org.dmg.pmml.OpType) TestCase.assertNotNull(junit.framework.TestCase.assertNotNull) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) Assert.fail(org.junit.Assert.fail) CompilationUnit(com.github.javaparser.ast.CompilationUnit) MiningField(org.dmg.pmml.MiningField) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) SUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLClassificationTableFactory.SUPPORTED_NORMALIZATION_METHODS) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) PMML(org.dmg.pmml.PMML) RegressionModel(org.dmg.pmml.regression.RegressionModel) PACKAGE_NAME(org.kie.pmml.commons.Constants.PACKAGE_NAME) Assert.assertTrue(org.junit.Assert.assertTrue) KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA(org.kie.pmml.models.regression.compiler.factories.KiePMMLClassificationTableFactory.KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) CodegenTestUtils.commonValidateCompilation(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilation) 
KiePMMLClassificationTable(org.kie.pmml.models.regression.model.KiePMMLClassificationTable) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) KiePMMLModelUtils.getGeneratedClassName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getGeneratedClassName) MethodReferenceExpr(com.github.javaparser.ast.expr.MethodReferenceExpr) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) UNSUPPORTED_NORMALIZATION_METHODS(org.kie.pmml.models.regression.compiler.factories.KiePMMLClassificationTableFactory.UNSUPPORTED_NORMALIZATION_METHODS) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) Assert.assertEquals(org.junit.Assert.assertEquals) MiningField(org.dmg.pmml.MiningField) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) Output(org.dmg.pmml.Output) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) OutputField(org.dmg.pmml.OutputField) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map) Test(org.junit.Test)

Example 4 with RegressionCompilationDTO

use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in project drools by kiegroup.

In the class KiePMMLClassificationTableFactoryTest, method getClassificationTable.

/**
 * Builds a two-table categorical regression model (CAUCHIT normalization) and verifies the
 * {@code KiePMMLClassificationTable} returned by
 * {@code KiePMMLClassificationTableFactory.getClassificationTable}: one category-table entry per
 * regression table (keyed by target category), matching normalization method name, categorical
 * op type, binary flag and target field.
 */
@Test
public void getClassificationTable() {
    // Fixture: two category tables and three output fields.
    RegressionTable regressionTableProf = getRegressionTable(3.5, "professional");
    RegressionTable regressionTableCler = getRegressionTable(27.4, "clerical");
    OutputField outputFieldCat = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    OutputField outputFieldNum = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    OutputField outputFieldPrev = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);
    // Fixture: categorical target field registered in the data dictionary.
    String targetField = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetField));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);
    // Fixture: the regression model with output and mining schema.
    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTableProf, regressionTableCler);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    Output output = new Output();
    output.addOutputFields(outputFieldCat, outputFieldNum, outputFieldPrev);
    regressionModel.setOutput(output);
    MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(targetMiningField);
    regressionModel.setMiningSchema(miningSchema);
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);
    // Exercise.
    final CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());
    KiePMMLClassificationTable retrieved = KiePMMLClassificationTableFactory.getClassificationTable(compilationDTO);
    // Verify.
    assertNotNull(retrieved);
    assertEquals(regressionModel.getRegressionTables().size(), retrieved.getCategoryTableMap().size());
    regressionModel.getRegressionTables().forEach(regressionTable -> assertTrue(retrieved.getCategoryTableMap().containsKey(regressionTable.getTargetCategory().toString())));
    assertEquals(regressionModel.getNormalizationMethod().value(), retrieved.getRegressionNormalizationMethod().getName());
    assertEquals(OP_TYPE.CATEGORICAL, retrieved.getOpType());
    // The expected binary flag is derived from the table count: exactly two tables means binary.
    boolean isBinary = regressionModel.getRegressionTables().size() == 2;
    assertEquals(isBinary, retrieved.isBinary());
    assertEquals(targetMiningField.getName().getValue(), retrieved.getTargetField());
}
Also used : MiningField(org.dmg.pmml.MiningField) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) Output(org.dmg.pmml.Output) OutputField(org.dmg.pmml.OutputField) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) KiePMMLClassificationTable(org.kie.pmml.models.regression.model.KiePMMLClassificationTable) Test(org.junit.Test)

Example 5 with RegressionCompilationDTO

use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in project drools by kiegroup.

In the class KiePMMLClassificationTableFactoryTest, method setStaticGetter.

/**
 * Checks that {@code KiePMMLClassificationTableFactory.setStaticGetter} populates a clone of
 * {@code STATIC_GETTER_METHOD} so that it is node-equal to the method parsed from the file
 * referenced by {@code TEST_02_SOURCE}.
 *
 * @throws IOException declared for loading the expected source
 *                     (presumably raised by {@code getFileContent} - confirm against the helper)
 */
@Test
public void setStaticGetter() throws IOException {
    String variableName = "variableName";
    // Fixture: two category tables and three output fields.
    RegressionTable regressionTableProf = getRegressionTable(3.5, "professional");
    RegressionTable regressionTableCler = getRegressionTable(27.4, "clerical");
    OutputField outputFieldCat = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    OutputField outputFieldNum = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    OutputField outputFieldPrev = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);
    // Fixture: categorical target field registered in the data dictionary.
    String targetField = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetField));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);
    // Fixture: the regression model with output and mining schema.
    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTableProf, regressionTableCler);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    Output output = new Output();
    output.addOutputFields(outputFieldCat, outputFieldNum, outputFieldPrev);
    regressionModel.setOutput(output);
    MiningField miningField = new MiningField();
    miningField.setUsageType(MiningField.UsageType.TARGET);
    miningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(miningField);
    regressionModel.setMiningSchema(miningSchema);
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);
    final CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());
    // Insertion-ordered map so the entries keep the order of the regression tables.
    final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap = new LinkedHashMap<>();
    regressionModel.getRegressionTables().forEach(regressionTable -> {
        // Locale.ROOT keeps the key independent of the JVM default locale: both categories contain
        // an 'i', which the no-arg toUpperCase() maps to a dotted capital under e.g. a Turkish locale.
        String key = "defpack." + regressionTable.getTargetCategory().toString().toUpperCase(java.util.Locale.ROOT);
        KiePMMLTableSourceCategory value = new KiePMMLTableSourceCategory("", regressionTable.getTargetCategory().toString());
        regressionTablesMap.put(key, value);
    });
    // Operate on a clone: the call under test mutates the method declaration it receives.
    final MethodDeclaration staticGetterMethod = STATIC_GETTER_METHOD.clone();
    KiePMMLClassificationTableFactory.setStaticGetter(compilationDTO, regressionTablesMap, staticGetterMethod, variableName);
    String text = getFileContent(TEST_02_SOURCE);
    MethodDeclaration expected = JavaParserUtils.parseMethod(text);
    assertTrue(JavaParserUtils.equalsNode(expected, staticGetterMethod));
}
Also used : MiningField(org.dmg.pmml.MiningField) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) LinkedHashMap(java.util.LinkedHashMap) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) Output(org.dmg.pmml.Output) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) OutputField(org.dmg.pmml.OutputField) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) Test(org.junit.Test)

Aggregations

RegressionCompilationDTO (org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO): 13; RegressionModel (org.dmg.pmml.regression.RegressionModel): 11; Test (org.junit.Test): 10; HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock): 10; DataDictionary (org.dmg.pmml.DataDictionary): 9; DataField (org.dmg.pmml.DataField): 9; MiningField (org.dmg.pmml.MiningField): 9; MiningSchema (org.dmg.pmml.MiningSchema): 9; PMML (org.dmg.pmml.PMML): 9; RegressionTable (org.dmg.pmml.regression.RegressionTable): 8; KiePMMLTableSourceCategory (org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory): 6; MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration): 4; HashMap (java.util.HashMap): 4; Map (java.util.Map): 4; Output (org.dmg.pmml.Output): 4; OutputField (org.dmg.pmml.OutputField): 4; KiePMMLClassificationTable (org.kie.pmml.models.regression.model.KiePMMLClassificationTable): 4; ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration): 3; LinkedHashMap (java.util.LinkedHashMap): 3; CompilationUnit (com.github.javaparser.ast.CompilationUnit): 2