Usage of org.kie.pmml.api.enums.PMML_MODEL in the drools project (kiegroup).
Class KiePMMLModelFactoryUtils, method init:
/**
 * Initialize the given <code>ClassOrInterfaceDeclaration</code> with all the <b>common</b> code needed to
 * generate a <code>KiePMMLModel</code>.
 * <p>
 * The default constructor of <code>modelTemplate</code> is set up through <code>setKiePMMLModelConstructor</code>.
 * Its body then receives the assignments for <code>pmmlMODEL</code>, <code>miningFunction</code>,
 * <code>targetField</code> and <code>kiePMMLMiningFields</code>. It also receives the assignment for
 * <code>kiePMMLOutputFields</code>, but only when the compiled model has an <code>Output</code>.
 * The transformation and field-factory methods are added to <code>modelTemplate</code> itself.
 * @param compilationDTO source of the model name, generated class name, fields, transformations and enum values
 * @param modelTemplate the template to modify in place; it must declare a default constructor
 * @throws KiePMMLInternalException if <code>modelTemplate</code> has no default constructor
 */
public static void init(final CompilationDTO<? extends Model> compilationDTO, final ClassOrInterfaceDeclaration modelTemplate) {
    final ConstructorDeclaration constructorDeclaration = modelTemplate.getDefaultConstructor().orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, modelTemplate.getName())));
    final String name = compilationDTO.getModelName();
    final String generatedClassName = compilationDTO.getSimpleClassName();
    final List<MiningField> miningFields = compilationDTO.getKieMiningFields();
    final List<OutputField> outputFields = compilationDTO.getKieOutputFields();
    final List<TargetField> targetFields = compilationDTO.getKieTargetFields();
    // Enum values are emitted as "<enum type>.<CONSTANT>" source references.
    // getDeclaringClass() is used instead of getClass(): a constant with its own class body would
    // otherwise yield the synthetic subclass (e.g. "...$1"), which is not a valid reference.
    // getCanonicalName() is used instead of getName(): a nested enum would otherwise be written with '$'.
    final MINING_FUNCTION miningFunction = compilationDTO.getMINING_FUNCTION();
    final Expression miningFunctionExpression;
    if (miningFunction != null) {
        miningFunctionExpression = new NameExpr(miningFunction.getDeclaringClass().getCanonicalName() + "." + miningFunction.name());
    } else {
        // No mining function available: the generated code assigns null
        miningFunctionExpression = new NullLiteralExpr();
    }
    final PMML_MODEL pmmlModelEnum = compilationDTO.getPMML_MODEL();
    final NameExpr pmmlMODELExpression = new NameExpr(pmmlModelEnum.getDeclaringClass().getCanonicalName() + "." + pmmlModelEnum.name());
    final String targetFieldName = compilationDTO.getTargetFieldName();
    final Expression targetFieldExpression;
    if (targetFieldName != null) {
        targetFieldExpression = new StringLiteralExpr(targetFieldName);
    } else {
        targetFieldExpression = new NullLiteralExpr();
    }
    setKiePMMLModelConstructor(generatedClassName, constructorDeclaration, name, miningFields, outputFields, targetFields);
    addTransformationsInClassOrInterfaceDeclaration(modelTemplate, compilationDTO.getTransformationDictionary(), compilationDTO.getLocalTransformations());
    final BlockStmt body = constructorDeclaration.getBody();
    CommonCodegenUtils.setAssignExpressionValue(body, "pmmlMODEL", pmmlMODELExpression);
    CommonCodegenUtils.setAssignExpressionValue(body, "miningFunction", miningFunctionExpression);
    CommonCodegenUtils.setAssignExpressionValue(body, "targetField", targetFieldExpression);
    // Mining fields: add the factory method and assign "this.<factory>()" in the constructor
    addGetCreatedKiePMMLMiningFieldsMethod(modelTemplate, compilationDTO.getMiningSchema().getMiningFields(), compilationDTO.getFields());
    final MethodCallExpr getCreatedKiePMMLMiningFieldsExpr = new MethodCallExpr();
    getCreatedKiePMMLMiningFieldsExpr.setScope(new ThisExpr());
    getCreatedKiePMMLMiningFieldsExpr.setName(GET_CREATED_KIEPMMLMININGFIELDS);
    CommonCodegenUtils.setAssignExpressionValue(body, "kiePMMLMiningFields", getCreatedKiePMMLMiningFieldsExpr);
    // Output fields are optional in PMML: only wire them when the model declares an Output
    if (compilationDTO.getOutput() != null) {
        addGetCreatedKiePMMLOutputFieldsMethod(modelTemplate, compilationDTO.getOutput().getOutputFields());
        final MethodCallExpr getCreatedKiePMMLOutputFieldsExpr = new MethodCallExpr();
        getCreatedKiePMMLOutputFieldsExpr.setScope(new ThisExpr());
        getCreatedKiePMMLOutputFieldsExpr.setName(GET_CREATED_KIEPMMLOUTPUTFIELDS);
        CommonCodegenUtils.setAssignExpressionValue(body, "kiePMMLOutputFields", getCreatedKiePMMLOutputFieldsExpr);
    }
}
Usage of org.kie.pmml.api.enums.PMML_MODEL in the drools project (kiegroup).
Class KiePMMLModelRetriever, method getFromCommonDataAndTransformationDictionaryAndModelWithSourcesCommon:
/**
 * Resolves the <code>PMML_MODEL</code> kind from the simple name of <code>model</code>'s concrete class.
 * It then applies <code>modelFunction</code> to the providers returned by
 * <code>getModelImplementationProviderStream</code> for that kind and keeps the first result.
 * @param fields not read by this method
 * @param model the PMML model whose class name selects the provider kind
 * @param modelFunction maps a provider to the resulting <code>KiePMMLModel</code>
 * @return the first mapped model, or an empty <code>Optional</code> when the provider stream is empty
 */
static Optional<KiePMMLModel> getFromCommonDataAndTransformationDictionaryAndModelWithSourcesCommon(final List<Field<?>> fields, final Model model, final Function<ModelImplementationProvider<Model, KiePMMLModel>, KiePMMLModel> modelFunction) {
    logger.trace("getFromCommonDataAndTransformationDictionaryAndModelWithSourcesCommon {}", model);
    final String modelClassSimpleName = model.getClass().getSimpleName();
    final PMML_MODEL resolvedModelType = PMML_MODEL.byName(modelClassSimpleName);
    logger.debug("pmmlModelType {}", resolvedModelType);
    return getModelImplementationProviderStream(resolvedModelType)
            .map(modelFunction)
            .findFirst();
}
Usage of org.kie.pmml.api.enums.PMML_MODEL in the drools project (kiegroup).
Class KiePMMLTreeModelFactoryTest, method setConstructor:
@Test
public void setConstructor() {
    // Arrange: target field expected in the generated constructor, a fresh template, and the DTOs
    final String expectedTargetField = "whatIdo";
    final ClassOrInterfaceDeclaration template = classOrInterfaceDeclaration.clone();
    final KnowledgeBuilderImpl builder = new KnowledgeBuilderImpl();
    final CommonCompilationDTO<TreeModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, treeModel, new HasKnowledgeBuilderMock(builder));
    final DroolsCompilationDTO<TreeModel> droolsDTO = DroolsCompilationDTO.fromCompilationDTO(commonDTO, new HashMap<>());
    // Act
    KiePMMLTreeModelFactory.setConstructor(droolsDTO, template);
    // Assert: expected super(...) arguments, keyed by position
    final Map<Integer, Expression> expectedSuperArgs = new HashMap<>();
    expectedSuperArgs.put(0, new NameExpr(String.format("\"%s\"", treeModel.getModelName())));
    expectedSuperArgs.put(2, new NameExpr(String.format("\"%s\"", treeModel.getAlgorithmName())));
    // Expected field assignments inside the constructor body
    final MINING_FUNCTION expectedMiningFunction = MINING_FUNCTION.byName(treeModel.getMiningFunction().value());
    final PMML_MODEL expectedPmmlModel = PMML_MODEL.byName(treeModel.getClass().getSimpleName());
    final Map<String, Expression> expectedAssignments = new HashMap<>();
    expectedAssignments.put("targetField", new StringLiteralExpr(expectedTargetField));
    expectedAssignments.put("miningFunction", new NameExpr(expectedMiningFunction.getClass().getName() + "." + expectedMiningFunction.name()));
    expectedAssignments.put("pmmlMODEL", new NameExpr(expectedPmmlModel.getClass().getName() + "." + expectedPmmlModel.name()));
    final ConstructorDeclaration generatedConstructor = template.getDefaultConstructor().get();
    final String expectedClassName = getSanitizedClassName(treeModel.getModelName());
    assertTrue(commonEvaluateConstructor(generatedConstructor, expectedClassName, expectedSuperArgs, expectedAssignments));
}
Usage of org.kie.pmml.api.enums.PMML_MODEL in the drools project (kiegroup).
Class KiePMMLMiningModelFactoryTest, method setConstructor:
@Test
public void setConstructor() {
    // Arrange: enum values the generated constructor is expected to reference, plus a fresh template
    final PMML_MODEL expectedPmmlModel = PMML_MODEL.byName(MINING_MODEL.getClass().getSimpleName());
    final ClassOrInterfaceDeclaration template = MODEL_TEMPLATE.clone();
    final MINING_FUNCTION expectedMiningFunction = MINING_FUNCTION.byName(MINING_MODEL.getMiningFunction().value());
    final CommonCompilationDTO<MiningModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, MINING_MODEL, new HasClassLoaderMock());
    final MiningModelCompilationDTO miningDTO = MiningModelCompilationDTO.fromCompilationDTO(commonDTO);
    // Act
    KiePMMLMiningModelFactory.setConstructor(miningDTO, template);
    // Assert: expected super(...) arguments, keyed by position
    final Map<Integer, Expression> expectedSuperArgs = new HashMap<>();
    expectedSuperArgs.put(0, new NameExpr(String.format("\"%s\"", MINING_MODEL.getModelName())));
    // Expected field assignments inside the constructor body
    final Map<String, Expression> expectedAssignments = new HashMap<>();
    expectedAssignments.put("targetField", new StringLiteralExpr(targetFieldName));
    expectedAssignments.put("miningFunction", new NameExpr(expectedMiningFunction.getClass().getName() + "." + expectedMiningFunction.name()));
    expectedAssignments.put("pmmlMODEL", new NameExpr(expectedPmmlModel.getClass().getName() + "." + expectedPmmlModel.name()));
    // The segmentation field is expected to be initialized with "new <segmentation class>()"
    final ClassOrInterfaceType segmentationType = parseClassOrInterfaceType(miningDTO.getSegmentationCanonicalClassName());
    final ObjectCreationExpr segmentationCreation = new ObjectCreationExpr();
    segmentationCreation.setType(segmentationType);
    expectedAssignments.put("segmentation", segmentationCreation);
    final ConstructorDeclaration generatedConstructor = template.getDefaultConstructor().get();
    final String expectedClassName = getSanitizedClassName(MINING_MODEL.getModelName());
    assertTrue(commonEvaluateConstructor(generatedConstructor, expectedClassName, expectedSuperArgs, expectedAssignments));
}
Usage of org.kie.pmml.api.enums.PMML_MODEL in the drools project (kiegroup).
Class KiePMMLDroolsModelFactoryUtils, method setConstructor:
/**
 * Define the <b>targetField</b>, the <b>miningFunction</b>, the <b>pmmlMODEL</b> and the
 * <b>kModulePackageName</b> inside the constructor.
 * <p>
 * The constructor is first renamed to <code>tableName</code>. Each value is then set in its body through
 * <code>CommonCodegenUtils.setAssignExpressionValue</code>.
 * @param model the PMML model; the simple name of its concrete class selects the <code>PMML_MODEL</code> constant
 * @param constructorDeclaration the constructor to modify in place
 * @param tableName the new name given to the constructor
 * @param targetField value emitted as a string literal for <code>targetField</code>
 * @param miningFunction value emitted as an enum constant reference for <code>miningFunction</code>
 * @param kModulePackageName value emitted as a string literal for <code>kModulePackageName</code>
 */
static void setConstructor(final Model model, final ConstructorDeclaration constructorDeclaration, final SimpleName tableName, final String targetField, final MINING_FUNCTION miningFunction, final String kModulePackageName) {
    constructorDeclaration.setName(tableName);
    final BlockStmt body = constructorDeclaration.getBody();
    // NOTE(review): targetField is passed straight to StringLiteralExpr; confirm callers never pass null
    CommonCodegenUtils.setAssignExpressionValue(body, "targetField", new StringLiteralExpr(targetField));
    // getDeclaringClass() (not getClass()) so that a constant with its own class body still yields the enum type;
    // getCanonicalName() (not getName()) so that a nested enum is referenced with '.' instead of '$'
    CommonCodegenUtils.setAssignExpressionValue(body, "miningFunction", new NameExpr(miningFunction.getDeclaringClass().getCanonicalName() + "." + miningFunction.name()));
    final PMML_MODEL pmmlModel = PMML_MODEL.byName(model.getClass().getSimpleName());
    CommonCodegenUtils.setAssignExpressionValue(body, "pmmlMODEL", new NameExpr(pmmlModel.getDeclaringClass().getCanonicalName() + "." + pmmlModel.name()));
    CommonCodegenUtils.setAssignExpressionValue(body, "kModulePackageName", new StringLiteralExpr(kModulePackageName));
}
Aggregations