Search in sources :

Example 1 with KiePMMLClassificationTable

use of org.kie.pmml.models.regression.model.KiePMMLClassificationTable in project drools by kiegroup.

the class KiePMMLClassificationTableFactoryTest method getClassificationTable.

/**
 * Verifies that {@code KiePMMLClassificationTableFactory.getClassificationTable} builds a
 * {@code KiePMMLClassificationTable} reflecting the source {@code RegressionModel}:
 * one category table per regression table, the model's normalization method,
 * a categorical op type, the binary flag and the target field name.
 */
@Test
public void getClassificationTable() {
    // Two regression tables -> the resulting classification table is expected to be binary.
    RegressionTable regressionTableProf = getRegressionTable(3.5, "professional");
    RegressionTable regressionTableCler = getRegressionTable(27.4, "clerical");
    OutputField outputFieldCat = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    OutputField outputFieldNum = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    OutputField outputFieldPrev = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);
    // Categorical target field declared in the data dictionary.
    String targetField = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetField));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);
    // Regression model under test.
    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTableProf, regressionTableCler);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    Output output = new Output();
    output.addOutputFields(outputFieldCat, outputFieldNum, outputFieldPrev);
    regressionModel.setOutput(output);
    // Mining schema marks the data field as the model target.
    MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(targetMiningField);
    regressionModel.setMiningSchema(miningSchema);
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);
    final CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());
    KiePMMLClassificationTable retrieved = KiePMMLClassificationTableFactory.getClassificationTable(compilationDTO);
    assertNotNull(retrieved);
    // Every regression table must be present in the category map, keyed by its target category.
    assertEquals(regressionModel.getRegressionTables().size(), retrieved.getCategoryTableMap().size());
    regressionModel.getRegressionTables().forEach(regressionTable -> assertTrue(retrieved.getCategoryTableMap().containsKey(regressionTable.getTargetCategory().toString())));
    assertEquals(regressionModel.getNormalizationMethod().value(), retrieved.getRegressionNormalizationMethod().getName());
    assertEquals(OP_TYPE.CATEGORICAL, retrieved.getOpType());
    // The binary flag is checked once; the original asserted it twice in a row.
    boolean isBinary = regressionModel.getRegressionTables().size() == 2;
    assertEquals(isBinary, retrieved.isBinary());
    assertEquals(targetMiningField.getName().getValue(), retrieved.getTargetField());
}
Also used : MiningField(org.dmg.pmml.MiningField) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) Output(org.dmg.pmml.Output) OutputField(org.dmg.pmml.OutputField) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) KiePMMLClassificationTable(org.kie.pmml.models.regression.model.KiePMMLClassificationTable) Test(org.junit.Test)

Example 2 with KiePMMLClassificationTable

use of org.kie.pmml.models.regression.model.KiePMMLClassificationTable in project drools by kiegroup.

the class KiePMMLRegressionModelFactoryTest method getKiePMMLRegressionModelClasses.

/**
 * Builds a {@code KiePMMLRegressionModel} through
 * {@code KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses} and checks its
 * name, mining function, target field and that the nested table is a
 * {@code KiePMMLClassificationTable}.
 */
@Test
public void getKiePMMLRegressionModelClasses() throws IOException, IllegalAccessException, InstantiationException {
    // Wrap the fixture model in the common DTO, then specialise it for regression.
    final CompilationDTO<RegressionModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTO(commonDTO);
    final KiePMMLRegressionModel builtModel = KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses(regressionDTO);
    assertNotNull(builtModel);

    // Model-level metadata must mirror the source PMML model.
    assertEquals(regressionModel.getModelName(), builtModel.getName());
    final MINING_FUNCTION expectedFunction = MINING_FUNCTION.byName(regressionModel.getMiningFunction().value());
    assertEquals(expectedFunction, builtModel.getMiningFunction());
    assertEquals(miningFields.get(0).getName().getValue(), builtModel.getTargetField());

    // The nested table is expected to be the classification flavour.
    final AbstractKiePMMLTable nestedTable = builtModel.getRegressionTable();
    assertNotNull(nestedTable);
    assertTrue(nestedTable instanceof KiePMMLClassificationTable);
    final KiePMMLClassificationTable classificationTable = (KiePMMLClassificationTable) nestedTable;
    evaluateCategoricalRegressionTable(classificationTable);
}
Also used : KiePMMLRegressionModel(org.kie.pmml.models.regression.model.KiePMMLRegressionModel) AbstractKiePMMLTable(org.kie.pmml.models.regression.model.AbstractKiePMMLTable) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) KiePMMLClassificationTable(org.kie.pmml.models.regression.model.KiePMMLClassificationTable) PMMLModelTestUtils.getRegressionModel(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRegressionModel) RegressionModel(org.dmg.pmml.regression.RegressionModel) KiePMMLRegressionModel(org.kie.pmml.models.regression.model.KiePMMLRegressionModel) Test(org.junit.Test)

Example 3 with KiePMMLClassificationTable

use of org.kie.pmml.models.regression.model.KiePMMLClassificationTable in project drools by kiegroup.

the class KiePMMLRegressionModelFactory method getKiePMMLRegressionModelSourcesMap.

// Source code generation
/**
 * Generates the Java sources for a regression model and its tables.
 * <p>
 * The returned map is keyed by fully qualified class name and holds the source of every
 * table returned by {@code getRegressionTablesMap} plus the source of the model class
 * itself, built from the regression model template.
 *
 * @param compilationDTO the compilation context (model, fields, package and class names)
 * @return a mutable map from fully qualified class name to Java source
 * @throws IOException declared by the signature; NOTE(review): presumably raised while
 *         loading the template - confirm against the helpers
 * @throws KiePMMLException if the template does not contain the expected main class, or if
 *         several tables are generated and none of them is a {@code KiePMMLClassificationTable}
 */
public static Map<String, String> getKiePMMLRegressionModelSourcesMap(final RegressionCompilationDTO compilationDTO) throws IOException {
    logger.trace("getKiePMMLRegressionModelSourcesMap {} {} {}", compilationDTO.getFields(), compilationDTO.getModel(), compilationDTO.getPackageName());
    String className = compilationDTO.getSimpleClassName();
    CompilationUnit cloneCU = JavaParserUtils.getKiePMMLModelCompilationUnit(className, compilationDTO.getPackageName(), KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA, KIE_PMML_REGRESSION_MODEL_TEMPLATE);
    ClassOrInterfaceDeclaration modelTemplate = cloneCU.getClassByName(className).orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + className));
    Map<String, KiePMMLTableSourceCategory> tablesSourceMap = getRegressionTablesMap(compilationDTO);
    // A single table is used directly; otherwise the model must point at the classification table.
    String nestedTable = tablesSourceMap.size() == 1 ? tablesSourceMap.keySet().iterator().next() : tablesSourceMap.keySet().stream().filter(tableName -> tableName.startsWith(compilationDTO.getPackageName() + ".KiePMMLClassificationTable")).findFirst().orElseThrow(() -> new KiePMMLException("Failed to find expected " + "KiePMMLClassificationTable"));
    setStaticGetter(compilationDTO, modelTemplate, nestedTable);
    // The two-argument Collectors.toMap gives no guarantee that the returned map is mutable,
    // yet an entry is added below: request a HashMap explicitly. Keys come from a map's entry
    // set, so the merge function is never expected to run.
    Map<String, String> toReturn = tablesSourceMap.entrySet().stream().collect(Collectors.toMap(
            Map.Entry::getKey,
            entry -> entry.getValue().getSource(),
            (first, second) -> {
                throw new IllegalStateException("Duplicate table source key");
            },
            HashMap::new));
    toReturn.put(getFullClassName(cloneCU), cloneCU.toString());
    return toReturn;
}
Also used : CompilationUnit(com.github.javaparser.ast.CompilationUnit) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) MISSING_VARIABLE_IN_BODY(org.kie.pmml.commons.Constants.MISSING_VARIABLE_IN_BODY) VariableDeclarator(com.github.javaparser.ast.body.VariableDeclarator) MAIN_CLASS_NOT_FOUND(org.kie.pmml.compiler.commons.utils.JavaParserUtils.MAIN_CLASS_NOT_FOUND) GETKIEPMML_TABLE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.GETKIEPMML_TABLE) Map(java.util.Map) CompilationUnit(com.github.javaparser.ast.CompilationUnit) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) AbstractKiePMMLTable(org.kie.pmml.models.regression.model.AbstractKiePMMLTable) CompilationDTO(org.kie.pmml.compiler.api.dto.CompilationDTO) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Logger(org.slf4j.Logger) RegressionModel(org.dmg.pmml.regression.RegressionModel) CommonCodegenUtils.getChainedMethodCallExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getChainedMethodCallExprFrom) KiePMMLModelFactoryUtils(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLModelFactoryUtils) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) JavaParserUtils.getFullClassName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFullClassName) IOException(java.io.IOException) NameExpr(com.github.javaparser.ast.expr.NameExpr) KiePMMLClassificationTable(org.kie.pmml.models.regression.model.KiePMMLClassificationTable) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) CommonCodegenUtils.getVariableDeclarator(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getVariableDeclarator) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) List(java.util.List) KiePMMLRegressionModel(org.kie.pmml.models.regression.model.KiePMMLRegressionModel) 
CommonCodegenUtils.getMethodDeclarationBlockStmt(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getMethodDeclarationBlockStmt) MISSING_VARIABLE_INITIALIZER_TEMPLATE(org.kie.pmml.commons.Constants.MISSING_VARIABLE_INITIALIZER_TEMPLATE) GET_MODEL(org.kie.pmml.commons.Constants.GET_MODEL) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) Collections(java.util.Collections) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) TO_RETURN(org.kie.pmml.commons.Constants.TO_RETURN) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) HashMap(java.util.HashMap) Map(java.util.Map)

Example 4 with KiePMMLClassificationTable

use of org.kie.pmml.models.regression.model.KiePMMLClassificationTable in project drools by kiegroup.

the class KiePMMLRegressionModelFactory method getRegressionTables.

// not-public KiePMMLRegressionModel instantiation
/**
 * Builds the table(s) backing a regression model, keyed by table name.
 * <p>
 * For a classification model a single {@code KiePMMLClassificationTable} is created from all
 * the model's regression tables, with {@code NormalizationMethod.NONE} passed to the nested DTO.
 * For a regression model only the first regression table is used, together with the
 * model's own normalization method.
 *
 * @param compilationDTO the compilation context wrapping the PMML regression model
 * @return a new map from table name to table
 */
static Map<String, AbstractKiePMMLTable> getRegressionTables(final RegressionCompilationDTO compilationDTO) {
    final Map<String, AbstractKiePMMLTable> tablesByName = new HashMap<>();
    if (!compilationDTO.isRegression()) {
        // Classification: wrap every regression table into one classification table.
        final List<RegressionTable> allTables = compilationDTO.getModel().getRegressionTables();
        final RegressionCompilationDTO classificationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(compilationDTO, allTables, RegressionModel.NormalizationMethod.NONE);
        final KiePMMLClassificationTable classificationTable = KiePMMLClassificationTableFactory.getClassificationTable(classificationDTO);
        tablesByName.put(classificationTable.getName(), classificationTable);
        return tablesByName;
    }
    // Regression: only the first regression table is taken into account.
    final RegressionTable firstTable = compilationDTO.getModel().getRegressionTables().get(0);
    final List<RegressionTable> singleTable = Collections.singletonList(firstTable);
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(compilationDTO, singleTable, compilationDTO.getModel().getNormalizationMethod());
    tablesByName.putAll(KiePMMLRegressionTableFactory.getRegressionTables(regressionDTO));
    return tablesByName;
}
Also used : HashMap(java.util.HashMap) AbstractKiePMMLTable(org.kie.pmml.models.regression.model.AbstractKiePMMLTable) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) KiePMMLClassificationTable(org.kie.pmml.models.regression.model.KiePMMLClassificationTable) RegressionTable(org.dmg.pmml.regression.RegressionTable)

Aggregations

KiePMMLClassificationTable (org.kie.pmml.models.regression.model.KiePMMLClassificationTable)4 RegressionModel (org.dmg.pmml.regression.RegressionModel)3 RegressionTable (org.dmg.pmml.regression.RegressionTable)3 RegressionCompilationDTO (org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO)3 AbstractKiePMMLTable (org.kie.pmml.models.regression.model.AbstractKiePMMLTable)3 HashMap (java.util.HashMap)2 Test (org.junit.Test)2 HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock)2 KiePMMLRegressionModel (org.kie.pmml.models.regression.model.KiePMMLRegressionModel)2 CompilationUnit (com.github.javaparser.ast.CompilationUnit)1 ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration)1 VariableDeclarator (com.github.javaparser.ast.body.VariableDeclarator)1 MethodCallExpr (com.github.javaparser.ast.expr.MethodCallExpr)1 NameExpr (com.github.javaparser.ast.expr.NameExpr)1 BlockStmt (com.github.javaparser.ast.stmt.BlockStmt)1 IOException (java.io.IOException)1 Collections (java.util.Collections)1 List (java.util.List)1 Map (java.util.Map)1 Collectors (java.util.stream.Collectors)1