use of org.kie.pmml.api.enums.MINING_FUNCTION in project drools by kiegroup.
the class KiePMMLModelFactoryUtils method init.
/**
 * Populate the given <code>ClassOrInterfaceDeclaration</code> with the <b>common</b> code every generated
 * <code>KiePMMLModel</code> needs.
 * <p>
 * The default constructor of the template is handed to <code>setKiePMMLModelConstructor</code> together with
 * the model name, the generated class name and the kie mining/output/target fields; afterwards the
 * <code>pmmlMODEL</code>, <code>miningFunction</code>, <code>targetField</code> and
 * <code>kiePMMLMiningFields</code> assignments inside its body are given their values. The
 * <code>kiePMMLOutputFields</code> assignment is set only when the DTO exposes an output.
 *
 * @param compilationDTO the DTO the model name, class name, fields, enums and transformations are read from
 * @param modelTemplate the class declaration to populate; it must declare a default constructor
 * @throws KiePMMLInternalException if <code>modelTemplate</code> has no default constructor
 */
public static void init(final CompilationDTO<? extends Model> compilationDTO, final ClassOrInterfaceDeclaration modelTemplate) {
    final ConstructorDeclaration defaultConstructor = modelTemplate.getDefaultConstructor()
            .orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, modelTemplate.getName())));
    final String modelName = compilationDTO.getModelName();
    final String simpleClassName = compilationDTO.getSimpleClassName();
    final List<MiningField> kieMiningFields = compilationDTO.getKieMiningFields();
    final List<OutputField> kieOutputFields = compilationDTO.getKieOutputFields();
    final List<TargetField> kieTargetFields = compilationDTO.getKieTargetFields();

    // Enum values are rendered as "<runtime class name>.<constant name>"; a missing value becomes a null literal
    final MINING_FUNCTION function = compilationDTO.getMINING_FUNCTION();
    final Expression functionExpr = function == null
            ? new NullLiteralExpr()
            : new NameExpr(function.getClass().getName() + "." + function.name());
    final PMML_MODEL modelType = compilationDTO.getPMML_MODEL();
    final NameExpr modelTypeExpr = new NameExpr(modelType.getClass().getName() + "." + modelType.name());
    final String targetName = compilationDTO.getTargetFieldName();
    final Expression targetExpr = targetName == null
            ? new NullLiteralExpr()
            : new StringLiteralExpr(targetName);

    setKiePMMLModelConstructor(simpleClassName, defaultConstructor, modelName, kieMiningFields, kieOutputFields, kieTargetFields);
    addTransformationsInClassOrInterfaceDeclaration(modelTemplate,
                                                    compilationDTO.getTransformationDictionary(),
                                                    compilationDTO.getLocalTransformations());

    final BlockStmt constructorBody = defaultConstructor.getBody();
    CommonCodegenUtils.setAssignExpressionValue(constructorBody, "pmmlMODEL", modelTypeExpr);
    CommonCodegenUtils.setAssignExpressionValue(constructorBody, "miningFunction", functionExpr);
    CommonCodegenUtils.setAssignExpressionValue(constructorBody, "targetField", targetExpr);

    // kiePMMLMiningFields = this.<GET_CREATED_KIEPMMLMININGFIELDS>()
    addGetCreatedKiePMMLMiningFieldsMethod(modelTemplate,
                                           compilationDTO.getMiningSchema().getMiningFields(),
                                           compilationDTO.getFields());
    final MethodCallExpr miningFieldsCall = new MethodCallExpr();
    miningFieldsCall.setScope(new ThisExpr());
    miningFieldsCall.setName(GET_CREATED_KIEPMMLMININGFIELDS);
    CommonCodegenUtils.setAssignExpressionValue(constructorBody, "kiePMMLMiningFields", miningFieldsCall);

    if (compilationDTO.getOutput() == null) {
        // No output defined: the kiePMMLOutputFields assignment is left untouched
        return;
    }
    // kiePMMLOutputFields = this.<GET_CREATED_KIEPMMLOUTPUTFIELDS>()
    addGetCreatedKiePMMLOutputFieldsMethod(modelTemplate, compilationDTO.getOutput().getOutputFields());
    final MethodCallExpr outputFieldsCall = new MethodCallExpr();
    outputFieldsCall.setScope(new ThisExpr());
    outputFieldsCall.setName(GET_CREATED_KIEPMMLOUTPUTFIELDS);
    CommonCodegenUtils.setAssignExpressionValue(constructorBody, "kiePMMLOutputFields", outputFieldsCall);
}
use of org.kie.pmml.api.enums.MINING_FUNCTION in project drools by kiegroup.
the class KiePMMLModelFactoryUtils method initStaticGetter.
/**
 * Populate the given <code>ClassOrInterfaceDeclaration</code>'s <b>staticGetter</b> (the first method named
 * <code>GET_MODEL</code>) with the <b>common</b> parameters needed to instantiate a <code>KiePMMLModel</code>.
 * <p>
 * Inside the initializer of the <code>TO_RETURN</code> variable, the chained <code>builder</code> call
 * receives the model name (argument 0) and the mining function (argument 1), and the chained
 * <code>withTargetField</code> call receives the target field name (argument 0). A missing mining function or
 * target field name is rendered as a <code>null</code> literal. The <code>populateGetCreated*</code> helpers
 * are then invoked on the template; the output-fields and targets ones only when the DTO provides the
 * corresponding data.
 *
 * @param compilationDTO the DTO the model name, mining function, target field, fields and transformations
 * are read from
 * @param modelTemplate the class declaration whose static getter has to be populated
 * @throws KiePMMLException if the template has no <code>GET_MODEL</code> method, if that method has no body,
 * or if the <code>TO_RETURN</code> variable or its initializer cannot be found
 */
public static void initStaticGetter(final CompilationDTO<? extends Model> compilationDTO, final ClassOrInterfaceDeclaration modelTemplate) {
    // Fail with a descriptive KiePMMLException, like the lookups below, rather than with a bare
    // IndexOutOfBoundsException when the template does not declare the static getter
    final MethodDeclaration staticGetterMethod = modelTemplate.getMethodsByName(GET_MODEL).stream()
            .findFirst()
            .orElseThrow(() -> new KiePMMLException(String.format("Missing method '%s' in %s", GET_MODEL, modelTemplate.getName())));
    final BlockStmt staticGetterBody = staticGetterMethod.getBody()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, staticGetterMethod)));
    final VariableDeclarator variableDeclarator = getVariableDeclarator(staticGetterBody, TO_RETURN)
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, TO_RETURN, staticGetterBody)));
    final MethodCallExpr initializer = variableDeclarator.getInitializer()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, TO_RETURN, staticGetterBody)))
            .asMethodCallExpr();
    final MethodCallExpr builder = getChainedMethodCallExprFrom("builder", initializer);
    final String name = compilationDTO.getModelName();

    // The enum value is rendered as "<runtime class name>.<constant name>"; a missing one becomes a null literal
    final MINING_FUNCTION miningFunction = compilationDTO.getMINING_FUNCTION();
    final Expression miningFunctionExpression;
    if (miningFunction != null) {
        miningFunctionExpression = new NameExpr(miningFunction.getClass().getName() + "." + miningFunction.name());
    } else {
        miningFunctionExpression = new NullLiteralExpr();
    }
    builder.setArgument(0, new StringLiteralExpr(name));
    builder.setArgument(1, miningFunctionExpression);

    final String targetFieldName = compilationDTO.getTargetFieldName();
    final Expression targetFieldExpression;
    if (targetFieldName != null) {
        targetFieldExpression = new StringLiteralExpr(targetFieldName);
    } else {
        targetFieldExpression = new NullLiteralExpr();
    }
    getChainedMethodCallExprFrom("withTargetField", initializer).setArgument(0, targetFieldExpression);

    // Fill the bodies of the getCreated* methods of the template
    populateGetCreatedMiningFieldsMethod(modelTemplate, compilationDTO.getKieMiningFields());
    populateGetCreatedOutputFieldsMethod(modelTemplate, compilationDTO.getKieOutputFields());
    populateGetCreatedKiePMMLMiningFieldsMethod(modelTemplate,
                                                compilationDTO.getMiningSchema().getMiningFields(),
                                                compilationDTO.getFields());
    if (compilationDTO.getOutput() != null) {
        populateGetCreatedKiePMMLOutputFieldsMethod(modelTemplate, compilationDTO.getOutput().getOutputFields());
    }
    if (compilationDTO.getKieTargetFields() != null) {
        populateGetCreatedKiePMMLTargetsMethod(modelTemplate, compilationDTO.getKieTargetFields());
    }
    populateGetCreatedTransformationDictionaryMethod(modelTemplate, compilationDTO.getTransformationDictionary());
    populateGetCreatedLocalTransformationsMethod(modelTemplate, compilationDTO.getLocalTransformations());
}
use of org.kie.pmml.api.enums.MINING_FUNCTION in project drools by kiegroup.
the class KiePMMLDroolsModelFactoryUtils method getKiePMMLModelCompilationUnit.
/**
 * Build the <code>CompilationUnit</code> of a drools-based <code>KiePMMLModel</code> starting from a template.
 * <p>
 * The template is loaded through <code>JavaParserUtils.getKiePMMLModelCompilationUnit</code>; the default
 * constructor of the generated class is then passed to <code>setConstructor</code> and to
 * <code>addFieldTypeMapPopulation</code>.
 *
 * @param droolsCompilationDTO the DTO the class name, package name, model, target field, mining function and
 * field type map are read from
 * @param javaTemplate the name of the <b>file</b> to be used as template source
 * @param modelClassName the name of the class used in the provided template
 * @return the populated <code>CompilationUnit</code>
 * @throws KiePMMLException if the generated class cannot be found in the compilation unit
 * @throws KiePMMLInternalException if the generated class has no default constructor
 */
public static <T extends Model> CompilationUnit getKiePMMLModelCompilationUnit(final DroolsCompilationDTO<T> droolsCompilationDTO, final String javaTemplate, final String modelClassName) {
    logger.trace("getKiePMMLModelCompilationUnit {} {} {}", droolsCompilationDTO.getFields(), droolsCompilationDTO.getModel(), droolsCompilationDTO.getPackageName());
    final String generatedClassName = droolsCompilationDTO.getSimpleClassName();
    final String packageName = droolsCompilationDTO.getPackageName();
    final CompilationUnit toReturn = JavaParserUtils.getKiePMMLModelCompilationUnit(generatedClassName,
                                                                                    packageName,
                                                                                    javaTemplate,
                                                                                    modelClassName);
    final ClassOrInterfaceDeclaration generatedClass = toReturn.getClassByName(generatedClassName)
            .orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + generatedClassName));
    final ConstructorDeclaration defaultConstructor = generatedClass.getDefaultConstructor()
            .orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, generatedClass.getName())));
    setConstructor(droolsCompilationDTO.getModel(),
                   defaultConstructor,
                   generatedClass.getName(),
                   droolsCompilationDTO.getTargetFieldName(),
                   droolsCompilationDTO.getMINING_FUNCTION(),
                   packageName);
    addFieldTypeMapPopulation(defaultConstructor.getBody(), droolsCompilationDTO.getFieldTypeMap());
    return toReturn;
}
use of org.kie.pmml.api.enums.MINING_FUNCTION in project drools by kiegroup.
the class KiePMMLTreeModelFactoryTest method setConstructor.
// Verifies that KiePMMLTreeModelFactory.setConstructor fills the super-invocation arguments and the
// targetField / miningFunction / pmmlMODEL assignments of the default constructor.
@Test
public void setConstructor() {
    final String expectedTargetField = "whatIdo";
    final ClassOrInterfaceDeclaration clonedTemplate = classOrInterfaceDeclaration.clone();
    final CommonCompilationDTO<TreeModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   treeModel,
                                                                   new HasKnowledgeBuilderMock(new KnowledgeBuilderImpl()));
    final DroolsCompilationDTO<TreeModel> droolsDTO = DroolsCompilationDTO.fromCompilationDTO(commonDTO, new HashMap<>());

    KiePMMLTreeModelFactory.setConstructor(droolsDTO, clonedTemplate);

    // Expected arguments of the super(...) invocation, keyed by position
    final Map<Integer, Expression> expectedSuperArguments = new HashMap<>();
    expectedSuperArguments.put(0, new NameExpr(String.format("\"%s\"", treeModel.getModelName())));
    expectedSuperArguments.put(2, new NameExpr(String.format("\"%s\"", treeModel.getAlgorithmName())));

    // Expected assignments inside the constructor body, keyed by assigned name
    final MINING_FUNCTION expectedFunction = MINING_FUNCTION.byName(treeModel.getMiningFunction().value());
    final PMML_MODEL expectedModelType = PMML_MODEL.byName(treeModel.getClass().getSimpleName());
    final Map<String, Expression> expectedAssignments = new HashMap<>();
    expectedAssignments.put("targetField", new StringLiteralExpr(expectedTargetField));
    expectedAssignments.put("miningFunction",
                            new NameExpr(expectedFunction.getClass().getName() + "." + expectedFunction.name()));
    expectedAssignments.put("pmmlMODEL",
                            new NameExpr(expectedModelType.getClass().getName() + "." + expectedModelType.name()));

    final ConstructorDeclaration populatedConstructor = clonedTemplate.getDefaultConstructor().get();
    final String expectedClassName = getSanitizedClassName(treeModel.getModelName());
    assertTrue(commonEvaluateConstructor(populatedConstructor,
                                         expectedClassName,
                                         expectedSuperArguments,
                                         expectedAssignments));
}
use of org.kie.pmml.api.enums.MINING_FUNCTION in project drools by kiegroup.
the class KiePMMLMiningModelFactoryTest method setConstructor.
// Verifies that KiePMMLMiningModelFactory.setConstructor fills the super-invocation argument and the
// targetField / miningFunction / pmmlMODEL / segmentation assignments of the default constructor.
@Test
public void setConstructor() {
    final PMML_MODEL expectedModelType = PMML_MODEL.byName(MINING_MODEL.getClass().getSimpleName());
    final ClassOrInterfaceDeclaration clonedTemplate = MODEL_TEMPLATE.clone();
    final MINING_FUNCTION expectedFunction = MINING_FUNCTION.byName(MINING_MODEL.getMiningFunction().value());
    final CommonCompilationDTO<MiningModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   MINING_MODEL,
                                                                   new HasClassLoaderMock());
    final MiningModelCompilationDTO miningModelDTO = MiningModelCompilationDTO.fromCompilationDTO(commonDTO);

    KiePMMLMiningModelFactory.setConstructor(miningModelDTO, clonedTemplate);

    // Expected argument of the super(...) invocation, keyed by position
    final Map<Integer, Expression> expectedSuperArguments = new HashMap<>();
    expectedSuperArguments.put(0, new NameExpr(String.format("\"%s\"", MINING_MODEL.getModelName())));

    // Expected assignments inside the constructor body, keyed by assigned name
    final Map<String, Expression> expectedAssignments = new HashMap<>();
    expectedAssignments.put("targetField", new StringLiteralExpr(targetFieldName));
    expectedAssignments.put("miningFunction",
                            new NameExpr(expectedFunction.getClass().getName() + "." + expectedFunction.name()));
    expectedAssignments.put("pmmlMODEL",
                            new NameExpr(expectedModelType.getClass().getName() + "." + expectedModelType.name()));
    final ClassOrInterfaceType segmentationType =
            parseClassOrInterfaceType(miningModelDTO.getSegmentationCanonicalClassName());
    final ObjectCreationExpr segmentationCreation = new ObjectCreationExpr();
    segmentationCreation.setType(segmentationType);
    expectedAssignments.put("segmentation", segmentationCreation);

    final ConstructorDeclaration populatedConstructor = clonedTemplate.getDefaultConstructor().get();
    final String expectedClassName = getSanitizedClassName(MINING_MODEL.getModelName());
    assertTrue(commonEvaluateConstructor(populatedConstructor,
                                         expectedClassName,
                                         expectedSuperArguments,
                                         expectedAssignments));
}
Aggregations