Search in sources:

Example 6 with CompilationDTO

use of org.kie.pmml.compiler.api.dto.CompilationDTO in project drools by kiegroup.

Class KiePMMLModelFactoryUtilsTest, method populateGetCreatedOutputFieldsMethod.

@Test
public void populateGetCreatedOutputFieldsMethod() throws IOException {
    // Build the compilation context from the shared PMML fixtures of this test class.
    final CompilationDTO dto = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmmlModel, model, new HasClassLoaderMock());
    // Generate the output-fields getter into the shared class declaration.
    KiePMMLModelFactoryUtils.populateGetCreatedOutputFieldsMethod(classOrInterfaceDeclaration, dto.getKieOutputFields());
    final MethodDeclaration generated = classOrInterfaceDeclaration.getMethodsByName(GET_CREATED_OUTPUTFIELDS).get(0);
    // The whole generated method must match the reference source, node by node.
    final MethodDeclaration reference = JavaParserUtils.parseMethod(getFileContent(TEST_13_SOURCE));
    assertTrue(JavaParserUtils.equalsNode(reference, generated));
}
Also used : CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) CompilationDTO(org.kie.pmml.compiler.api.dto.CompilationDTO) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) Test(org.junit.Test)

Example 7 with CompilationDTO

use of org.kie.pmml.compiler.api.dto.CompilationDTO in project drools by kiegroup.

Class KiePMMLModelFactoryUtilsTest, method addGetCreatedKiePMMLMiningFieldsMethod.

@Test
public void addGetCreatedKiePMMLMiningFieldsMethod() throws IOException {
    // Arrange: compilation context plus an empty class that will receive the generated method.
    final CompilationDTO dto = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmmlModel, model, new HasClassLoaderMock());
    final ClassOrInterfaceDeclaration target = new ClassOrInterfaceDeclaration();
    // Act: add the mining-fields getter built from the mining schema and the model fields.
    KiePMMLModelFactoryUtils.addGetCreatedKiePMMLMiningFieldsMethod(target, dto.getMiningSchema().getMiningFields(), dto.getFields());
    final MethodDeclaration generated = target.getMethodsByName(GET_CREATED_KIEPMMLMININGFIELDS).get(0);
    // Assert: only the method body is compared against the reference block.
    final BlockStmt referenceBody = JavaParserUtils.parseBlock(getFileContent(TEST_12_SOURCE));
    assertTrue(JavaParserUtils.equalsNode(referenceBody, generated.getBody().get()));
}
Also used : CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) CompilationDTO(org.kie.pmml.compiler.api.dto.CompilationDTO) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) Test(org.junit.Test)

Example 8 with CompilationDTO

use of org.kie.pmml.compiler.api.dto.CompilationDTO in project drools by kiegroup.

Class KiePMMLModelFactoryUtilsTest, method commonPopulateGetCreatedKiePMMLOutputFieldsMethod.

@Test
public void commonPopulateGetCreatedKiePMMLOutputFieldsMethod() throws IOException {
    // Compilation context built from the shared PMML fixtures.
    final CompilationDTO dto = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmmlModel, model, new HasClassLoaderMock());
    // Start from a blank method declaration and let the utility populate it from the model's output fields.
    final MethodDeclaration populated = new MethodDeclaration();
    KiePMMLModelFactoryUtils.commonPopulateGetCreatedKiePMMLOutputFieldsMethod(populated, dto.getOutput().getOutputFields());
    // The populated declaration must equal the reference method source.
    final MethodDeclaration reference = JavaParserUtils.parseMethod(getFileContent(TEST_05_SOURCE));
    assertTrue(JavaParserUtils.equalsNode(reference, populated));
}
Also used : CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) CompilationDTO(org.kie.pmml.compiler.api.dto.CompilationDTO) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) Test(org.junit.Test)

Example 9 with CompilationDTO

use of org.kie.pmml.compiler.api.dto.CompilationDTO in project drools by kiegroup.

Class KiePMMLClusteringModelFactory, method setConstructor.

/**
 * Fills the default constructor of the generated clustering model class.
 * <p>
 * After the common initialisation performed by {@code KiePMMLModelFactoryUtils.init}, statements are
 * appended to the constructor body in this order:
 * <ol>
 *   <li>the {@code modelClass} assignment;</li>
 *   <li>one {@code clusters.add(...)} call per cluster;</li>
 *   <li>one {@code clusteringFields.add(...)} call per clustering field;</li>
 *   <li>the {@code comparisonMeasure} assignment;</li>
 *   <li>the {@code missingValueWeights} assignment, only when the PMML model defines it.</li>
 * </ol>
 *
 * @param compilationDTO the compilation context wrapping the source {@code ClusteringModel}
 * @param modelTemplate  the class declaration whose default constructor is populated
 * @throws KiePMMLInternalException if {@code modelTemplate} has no default constructor
 */
static void setConstructor(final CompilationDTO<ClusteringModel> compilationDTO, final ClassOrInterfaceDeclaration modelTemplate) {
    KiePMMLModelFactoryUtils.init(compilationDTO, modelTemplate);
    final ConstructorDeclaration defaultConstructor = modelTemplate.getDefaultConstructor()
            .orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, modelTemplate.getName())));
    final BlockStmt constructorBody = defaultConstructor.getBody();
    final ClusteringModel pmmlClusteringModel = compilationDTO.getModel();

    constructorBody.addStatement(assignExprFrom("modelClass", modelClassFrom(pmmlClusteringModel.getModelClass())));
    // Statement order in the generated constructor follows the order of the PMML lists.
    pmmlClusteringModel.getClusters()
            .forEach(cluster -> constructorBody.addStatement(methodCallExprFrom("clusters", "add", clusterCreationExprFrom(cluster))));
    pmmlClusteringModel.getClusteringFields()
            .forEach(field -> constructorBody.addStatement(methodCallExprFrom("clusteringFields", "add", clusteringFieldCreationExprFrom(field))));
    constructorBody.addStatement(assignExprFrom("comparisonMeasure", comparisonMeasureCreationExprFrom(pmmlClusteringModel.getComparisonMeasure())));

    // MissingValueWeights is optional: emit its assignment only when present.
    if (pmmlClusteringModel.getMissingValueWeights() == null) {
        return;
    }
    constructorBody.addStatement(assignExprFrom("missingValueWeights", missingValueWeightsCreationExprFrom(pmmlClusteringModel.getMissingValueWeights())));
}
Also used : KiePMMLCluster(org.kie.pmml.models.clustering.model.KiePMMLCluster) Arrays(java.util.Arrays) ClassOrInterfaceType(com.github.javaparser.ast.type.ClassOrInterfaceType) CommonCodegenUtils.methodCallExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.methodCallExprFrom) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CommonCodegenUtils.assignExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.assignExprFrom) CommonCodegenUtils.literalExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.literalExprFrom) MAIN_CLASS_NOT_FOUND(org.kie.pmml.compiler.commons.utils.JavaParserUtils.MAIN_CLASS_NOT_FOUND) ConstructorDeclaration(com.github.javaparser.ast.body.ConstructorDeclaration) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) MissingValueWeights(org.dmg.pmml.clustering.MissingValueWeights) KiePMMLComparisonMeasure(org.kie.pmml.models.clustering.model.KiePMMLComparisonMeasure) DoubleLiteralExpr(com.github.javaparser.ast.expr.DoubleLiteralExpr) ObjectCreationExpr(com.github.javaparser.ast.expr.ObjectCreationExpr) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) ComparisonMeasure(org.dmg.pmml.ComparisonMeasure) CompilationUnit(com.github.javaparser.ast.CompilationUnit) KiePMMLClusteringModel(org.kie.pmml.models.clustering.model.KiePMMLClusteringModel) KiePMMLMissingValueWeights(org.kie.pmml.models.clustering.model.KiePMMLMissingValueWeights) NodeList(com.github.javaparser.ast.NodeList) CompilationDTO(org.kie.pmml.compiler.api.dto.CompilationDTO) KiePMMLClusteringField(org.kie.pmml.models.clustering.model.KiePMMLClusteringField) ClusteringField(org.dmg.pmml.clustering.ClusteringField) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Logger(org.slf4j.Logger) BooleanLiteralExpr(com.github.javaparser.ast.expr.BooleanLiteralExpr) 
KiePMMLModelFactoryUtils(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLModelFactoryUtils) JavaParserUtils.getFullClassName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFullClassName) KiePMMLClusteringConversionUtils.aggregateFunctionFrom(org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils.aggregateFunctionFrom) KiePMMLClusteringConversionUtils.compareFunctionFrom(org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils.compareFunctionFrom) Array(org.dmg.pmml.Array) KiePMMLClusteringConversionUtils.modelClassFrom(org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils.modelClassFrom) Cluster(org.dmg.pmml.clustering.Cluster) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) MISSING_DEFAULT_CONSTRUCTOR(org.kie.pmml.commons.Constants.MISSING_DEFAULT_CONSTRUCTOR) KiePMMLClusteringConversionUtils.comparisonMeasureKindFrom(org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils.comparisonMeasureKindFrom) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) ClusteringModel(org.dmg.pmml.clustering.ClusteringModel) ConstructorDeclaration(com.github.javaparser.ast.body.ConstructorDeclaration) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) KiePMMLClusteringModel(org.kie.pmml.models.clustering.model.KiePMMLClusteringModel) ClusteringModel(org.dmg.pmml.clustering.ClusteringModel)

Example 10 with CompilationDTO

use of org.kie.pmml.compiler.api.dto.CompilationDTO in project drools by kiegroup.

Class KiePMMLRegressionModelFactory, method getKiePMMLRegressionModelSourcesMap.

// Source code generation
/**
 * Generates the Java sources for a regression model: one source per regression table, plus the source of the
 * model class cloned from the {@code KIE_PMML_REGRESSION_MODEL_TEMPLATE} template.
 * <p>
 * When exactly one table exists, it is wired as the nested table of the model class. Otherwise the table whose
 * name starts with {@code <packageName>.KiePMMLClassificationTable} is used.
 *
 * @param compilationDTO the compilation context holding the PMML fields, the model and the target package
 * @return a mutable {@link HashMap} from generated class name to Java source. The model class is keyed by its
 *         fully-qualified name; table keys are as produced by {@code getRegressionTablesMap}
 *         (NOTE(review): presumably fully-qualified class names too — confirm)
 * @throws IOException declared checked exception (NOTE(review): presumably propagated from template/table
 *         source generation — confirm)
 * @throws KiePMMLException if the model class is missing from the cloned template, or if the tables map does not
 *         hold exactly one entry and none of its keys is a classification table in the target package
 */
public static Map<String, String> getKiePMMLRegressionModelSourcesMap(final RegressionCompilationDTO compilationDTO) throws IOException {
    logger.trace("getKiePMMLRegressionModelSourcesMap {} {} {}", compilationDTO.getFields(), compilationDTO.getModel(), compilationDTO.getPackageName());
    String className = compilationDTO.getSimpleClassName();
    CompilationUnit cloneCU = JavaParserUtils.getKiePMMLModelCompilationUnit(className, compilationDTO.getPackageName(), KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA, KIE_PMML_REGRESSION_MODEL_TEMPLATE);
    ClassOrInterfaceDeclaration modelTemplate = cloneCU.getClassByName(className).orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + className));
    Map<String, KiePMMLTableSourceCategory> tablesSourceMap = getRegressionTablesMap(compilationDTO);
    // A lone table is used directly; with several tables the classification table is the entry point.
    String nestedTable = tablesSourceMap.size() == 1 ?
            tablesSourceMap.keySet().iterator().next() :
            tablesSourceMap.keySet().stream()
                    .filter(tableName -> tableName.startsWith(compilationDTO.getPackageName() + ".KiePMMLClassificationTable"))
                    .findFirst()
                    .orElseThrow(() -> new KiePMMLException("Failed to find expected " + "KiePMMLClassificationTable"));
    setStaticGetter(compilationDTO, modelTemplate, nestedTable);
    // Collectors.toMap(keyMapper, valueMapper) makes no guarantee about the mutability of the returned map, yet an
    // entry is added below. HashMap::new is supplied explicitly so that put(...) is always legal. Keys come from a
    // Map and are therefore unique, so the merge function is never actually invoked.
    Map<String, String> toReturn = tablesSourceMap.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey,
                                      entry -> entry.getValue().getSource(),
                                      (first, second) -> first,
                                      HashMap::new));
    toReturn.put(getFullClassName(cloneCU), cloneCU.toString());
    return toReturn;
}
Also used : CompilationUnit(com.github.javaparser.ast.CompilationUnit) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) MISSING_VARIABLE_IN_BODY(org.kie.pmml.commons.Constants.MISSING_VARIABLE_IN_BODY) VariableDeclarator(com.github.javaparser.ast.body.VariableDeclarator) MAIN_CLASS_NOT_FOUND(org.kie.pmml.compiler.commons.utils.JavaParserUtils.MAIN_CLASS_NOT_FOUND) GETKIEPMML_TABLE(org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory.GETKIEPMML_TABLE) Map(java.util.Map) CompilationUnit(com.github.javaparser.ast.CompilationUnit) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) AbstractKiePMMLTable(org.kie.pmml.models.regression.model.AbstractKiePMMLTable) CompilationDTO(org.kie.pmml.compiler.api.dto.CompilationDTO) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Logger(org.slf4j.Logger) RegressionModel(org.dmg.pmml.regression.RegressionModel) CommonCodegenUtils.getChainedMethodCallExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getChainedMethodCallExprFrom) KiePMMLModelFactoryUtils(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLModelFactoryUtils) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr) JavaParserUtils.getFullClassName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFullClassName) IOException(java.io.IOException) NameExpr(com.github.javaparser.ast.expr.NameExpr) KiePMMLClassificationTable(org.kie.pmml.models.regression.model.KiePMMLClassificationTable) RegressionTable(org.dmg.pmml.regression.RegressionTable) Collectors(java.util.stream.Collectors) CommonCodegenUtils.getVariableDeclarator(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getVariableDeclarator) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) List(java.util.List) KiePMMLRegressionModel(org.kie.pmml.models.regression.model.KiePMMLRegressionModel) 
CommonCodegenUtils.getMethodDeclarationBlockStmt(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getMethodDeclarationBlockStmt) MISSING_VARIABLE_INITIALIZER_TEMPLATE(org.kie.pmml.commons.Constants.MISSING_VARIABLE_INITIALIZER_TEMPLATE) GET_MODEL(org.kie.pmml.commons.Constants.GET_MODEL) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) Collections(java.util.Collections) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) TO_RETURN(org.kie.pmml.commons.Constants.TO_RETURN) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) HashMap(java.util.HashMap) Map(java.util.Map)

Aggregations

CompilationDTO (org.kie.pmml.compiler.api.dto.CompilationDTO)10 Test (org.junit.Test)8 CommonCompilationDTO (org.kie.pmml.compiler.api.dto.CommonCompilationDTO)8 HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock)8 MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration)7 BlockStmt (com.github.javaparser.ast.stmt.BlockStmt)5 ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration)3 CompilationUnit (com.github.javaparser.ast.CompilationUnit)2 HashMap (java.util.HashMap)2 Map (java.util.Map)2 KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException)2 KiePMMLModelFactoryUtils (org.kie.pmml.compiler.commons.codegenfactories.KiePMMLModelFactoryUtils)2 JavaParserUtils (org.kie.pmml.compiler.commons.utils.JavaParserUtils)2 MAIN_CLASS_NOT_FOUND (org.kie.pmml.compiler.commons.utils.JavaParserUtils.MAIN_CLASS_NOT_FOUND)2 JavaParserUtils.getFullClassName (org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFullClassName)2 Logger (org.slf4j.Logger)2 LoggerFactory (org.slf4j.LoggerFactory)2 NodeList (com.github.javaparser.ast.NodeList)1 ConstructorDeclaration (com.github.javaparser.ast.body.ConstructorDeclaration)1 VariableDeclarator (com.github.javaparser.ast.body.VariableDeclarator)1