Usage of `org.kie.pmml.api.exceptions.KiePMMLException` in the drools project (kiegroup).
Class `KiePMMLSegmentFactory`, method `setConstructor`:
/**
 * Fills in the constructor of a generated segment class.
 * <p>
 * Applies the super-name invocation, sets the {@code "model"} argument of the explicit
 * constructor invocation found in the body to the instantiation expression built for
 * {@code kiePMMLModelClass}, and sets the values assigned to {@code "weight"} and {@code "id"}.
 *
 * @param segmentName            name of the segment; also used as the {@code "id"} value
 * @param generatedClassName     name of the generated class
 * @param constructorDeclaration constructor modified in place
 * @param kiePMMLModelClass      class name passed to {@code getInstantiationExpression}
 * @param isInterpreted          forwarded to {@code getInstantiationExpression}
 * @param weight                 value used for the {@code "weight"} assignment
 * @throws KiePMMLException if no explicit constructor invocation is found in the constructor body
 */
static void setConstructor(final String segmentName, final String generatedClassName, final ConstructorDeclaration constructorDeclaration, final String kiePMMLModelClass, final boolean isInterpreted, final double weight) {
    setConstructorSuperNameInvocation(generatedClassName, constructorDeclaration, segmentName);
    final BlockStmt constructorBody = constructorDeclaration.getBody();
    final ExplicitConstructorInvocationStmt superInvocation =
            CommonCodegenUtils.getExplicitConstructorInvocationStmt(constructorBody)
                    .orElseThrow(() -> new KiePMMLException(String.format(MISSING_CONSTRUCTOR_IN_BODY, constructorBody)));
    // The argument is set from the printed form of the expression, not from the node itself.
    final String modelArgument = getInstantiationExpression(kiePMMLModelClass, isInterpreted).toString();
    CommonCodegenUtils.setExplicitConstructorInvocationStmtArgument(superInvocation, "model", modelArgument);
    final DoubleLiteralExpr weightLiteral = new DoubleLiteralExpr(weight);
    CommonCodegenUtils.setAssignExpressionValue(constructorBody, "weight", weightLiteral);
    final StringLiteralExpr idLiteral = new StringLiteralExpr(segmentName);
    CommonCodegenUtils.setAssignExpressionValue(constructorBody, "id", idLiteral);
}
Usage of `org.kie.pmml.api.exceptions.KiePMMLException` in the drools project (kiegroup).
Class `KiePMMLSegmentFactory`, method `getSegmentSourcesMap` (interpreted/template variant):
/**
 * Generates the Java source of one segment class from the segment template.
 *
 * @param segmentCompilationDTO compilation data for the segment (id, package, weight, predicate, fields)
 * @param isInterpreted         forwarded to {@code setConstructor}
 * @return a mutable map with a single entry: the generated class's full name mapped to its source
 * @throws KiePMMLException         if the template compilation unit does not contain the expected class
 * @throws KiePMMLInternalException if the template class has no default constructor
 */
static Map<String, String> getSegmentSourcesMap(final SegmentCompilationDTO segmentCompilationDTO, final boolean isInterpreted) {
    logger.debug(GET_SEGMENT, segmentCompilationDTO.getSegment());
    final String modelClassName = segmentCompilationDTO.getPackageCanonicalClassName();
    final String segmentClassName = getSanitizedClassName(segmentCompilationDTO.getId());
    final CompilationUnit segmentUnit = JavaParserUtils.getKiePMMLModelCompilationUnit(
            segmentClassName,
            segmentCompilationDTO.getPackageName(),
            KIE_PMML_SEGMENT_TEMPLATE_JAVA,
            KIE_PMML_SEGMENT_TEMPLATE);
    final ClassOrInterfaceDeclaration segmentClass = segmentUnit.getClassByName(segmentClassName)
            .orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + segmentClassName));
    final ConstructorDeclaration defaultConstructor = segmentClass.getDefaultConstructor()
            .orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, segmentClass.getName())));
    final double segmentWeight = segmentCompilationDTO.getWeight().doubleValue();
    setConstructor(segmentCompilationDTO.getId(), segmentClassName, defaultConstructor, modelClassName, isInterpreted, segmentWeight);
    populateGetPredicateMethod(segmentCompilationDTO.getPredicate(), segmentCompilationDTO.getFields(), segmentClass);
    // Kept as a HashMap (mutable) so callers may add further entries.
    final Map<String, String> sources = new HashMap<>();
    sources.put(getFullClassName(segmentUnit), segmentUnit.toString());
    return sources;
}
Usage of `org.kie.pmml.api.exceptions.KiePMMLException` in the drools project (kiegroup).
Class `KiePMMLSegmentFactory`, method `getSegmentSourcesMap` (nested-models variant):
/**
 * Generates the sources for a segment whose nested model is produced together with its sources.
 * <p>
 * After the sources are generated, the fields of the segment's model are added to
 * {@code segmentCompilationDTO}.
 *
 * @param segmentCompilationDTO compilation data for the segment; mutated via {@code addFields}
 * @param nestedModels          forwarded to {@code getSegmentSourcesMapCommon}
 * @return the map produced by {@code getSegmentSourcesMapCommon}
 * @throws KiePMMLException if no {@code KiePMMLModel} could be obtained for the segment's model
 */
public static Map<String, String> getSegmentSourcesMap(final SegmentCompilationDTO segmentCompilationDTO, final List<KiePMMLModel> nestedModels) {
    logger.debug(GET_SEGMENT, segmentCompilationDTO.getSegment());
    final KiePMMLModel segmentModel = getFromCommonDataAndTransformationDictionaryAndModelWithSources(segmentCompilationDTO)
            .orElseThrow(() -> new KiePMMLException("Failed to get the KiePMMLModel for segment "
                    + segmentCompilationDTO.getModel().getModelName()));
    final Map<String, String> sources = getSegmentSourcesMapCommon(segmentCompilationDTO, nestedModels, segmentModel);
    // Fields are registered only after source generation; keep this ordering.
    segmentCompilationDTO.addFields(getFieldsFromModel(segmentCompilationDTO.getModel()));
    return sources;
}
Usage of `org.kie.pmml.api.exceptions.KiePMMLException` in the drools project (kiegroup).
Class `HasKnowledgeBuilderMock`, method `compileAndLoadClass`:
/**
 * Compiles the given sources in memory, defines the resulting classes in the project class
 * loader, and loads the requested class from it.
 *
 * @param sourcesMap    sources to compile (presumably fully-qualified class name to source text — confirm)
 * @param fullClassName fully-qualified name of the class to load after compilation
 * @return the loaded class
 * @throws IllegalStateException if {@code getClassLoader()} does not return a
 *                               {@code ProjectClassLoader} (including when it returns {@code null})
 * @throws KiePMMLException      wrapping any exception thrown while loading {@code fullClassName}
 */
@Override
public Class<?> compileAndLoadClass(Map<String, String> sourcesMap, String fullClassName) {
    ClassLoader classLoader = getClassLoader();
    if (!(classLoader instanceof ProjectClassLoader)) {
        // instanceof is also false for null: describe it without dereferencing, so the caller
        // gets the intended IllegalStateException rather than a NullPointerException.
        String received = classLoader == null ? "null" : classLoader.getClass().getName();
        throw new IllegalStateException("Expected ProjectClassLoader, received " + received);
    }
    ProjectClassLoader projectClassLoader = (ProjectClassLoader) classLoader;
    final Map<String, byte[]> byteCode = KieMemoryCompiler.compileNoLoad(sourcesMap, projectClassLoader);
    // Each compiled entry is defined explicitly in the project class loader before loading.
    byteCode.forEach(projectClassLoader::defineClass);
    try {
        return projectClassLoader.loadClass(fullClassName);
    } catch (Exception e) {
        // Broad catch kept intentionally: any failure (checked or unchecked) is reported as KiePMMLException.
        throw new KiePMMLException(e);
    }
}
Usage of `org.kie.pmml.api.exceptions.KiePMMLException` in the drools project (kiegroup).
Class `KiePMMLClassificationTableFactory`, method `setStaticGetter`:
// not-public code-generation
/**
 * Rewrites the body of {@code staticGetterMethod} so that it builds the classification table
 * named {@code variableName}.
 * <p>
 * The new body starts with statements that declare and populate a map, named from
 * {@code VARIABLE_NAME_TEMPLATE}, {@code CATEGORICAL_TABLE_MAP} and {@code variableName}. The map
 * goes from each category to a call of {@code KiePMMLRegressionTableFactory.GETKIEPMML_TABLE} on
 * the corresponding generated table class. The original statements follow, after the arguments of
 * the builder chain that initializes the {@code TO_RETURN} variable have been set in place.
 *
 * @param compilationDTO      source of the normalization method, op type, target field and binary flag
 * @param regressionTablesMap generated table class name mapped to its table source/category; iteration
 *                            order determines the order of the generated map entries
 * @param staticGetterMethod  method whose body is replaced
 * @param variableName        name passed to {@code builder(...)} and used to name the category map
 * @throws KiePMMLException if the method has no body, or the {@code TO_RETURN} variable or its
 *                          initializer is missing
 */
static void setStaticGetter(final RegressionCompilationDTO compilationDTO, final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap, final MethodDeclaration staticGetterMethod, final String variableName) {
    final BlockStmt classificationTableBody = staticGetterMethod.getBody().orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, staticGetterMethod)));
    final VariableDeclarator variableDeclarator = getVariableDeclarator(classificationTableBody, TO_RETURN).orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, TO_RETURN, classificationTableBody)));
    final BlockStmt newBody = new BlockStmt();
    // category -> "<className>.<GETKIEPMML_TABLE>()" call, preserving regressionTablesMap order
    final Map<String, Expression> regressionTableCategoriesMap = new LinkedHashMap<>();
    regressionTablesMap.forEach((className, tableSourceCategory) -> {
        MethodCallExpr methodCallExpr = new MethodCallExpr();
        methodCallExpr.setScope(new NameExpr(className));
        methodCallExpr.setName(KiePMMLRegressionTableFactory.GETKIEPMML_TABLE);
        regressionTableCategoriesMap.put(tableSourceCategory.getCategory(), methodCallExpr);
    });
    // populate map
    String categoryTableMapName = String.format(VARIABLE_NAME_TEMPLATE, CATEGORICAL_TABLE_MAP, variableName);
    createPopulatedLinkedHashMap(newBody, categoryTableMapName, Arrays.asList(String.class.getSimpleName(), KiePMMLRegressionTable.class.getName()), regressionTableCategoriesMap);
    final MethodCallExpr initializer = variableDeclarator.getInitializer().orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, TO_RETURN, classificationTableBody))).asMethodCallExpr();
    final MethodCallExpr builder = getChainedMethodCallExprFrom("builder", initializer);
    builder.setArgument(0, new StringLiteralExpr(variableName));
    // Qualify enum constants with getDeclaringClass(), not getClass(): a constant with a
    // constant-specific body is an anonymous subclass whose simple name is "", which would
    // generate ".NAME" instead of "EnumType.NAME". For plain enums both give the same result.
    final REGRESSION_NORMALIZATION_METHOD regressionNormalizationMethod = compilationDTO.getDefaultREGRESSION_NORMALIZATION_METHOD();
    getChainedMethodCallExprFrom("withRegressionNormalizationMethod", initializer).setArgument(0, new NameExpr(regressionNormalizationMethod.getDeclaringClass().getSimpleName() + "." + regressionNormalizationMethod.name()));
    final OP_TYPE opType = compilationDTO.getOP_TYPE();
    getChainedMethodCallExprFrom("withOpType", initializer).setArgument(0, new NameExpr(opType.getDeclaringClass().getSimpleName() + "." + opType.name()));
    getChainedMethodCallExprFrom("withCategoryTableMap", initializer).setArgument(0, new NameExpr(categoryTableMapName));
    final boolean isBinary = compilationDTO.isBinary(regressionTablesMap.size());
    final Expression probabilityMapFunctionExpression = getProbabilityMapFunctionExpression(compilationDTO.getModelNormalizationMethod(), isBinary);
    getChainedMethodCallExprFrom("withProbabilityMapFunction", initializer).setArgument(0, probabilityMapFunctionExpression);
    getChainedMethodCallExprFrom("withIsBinary", initializer).setArgument(0, getExpressionForObject(isBinary));
    getChainedMethodCallExprFrom("withTargetField", initializer).setArgument(0, getExpressionForObject(compilationDTO.getTargetFieldName()));
    // The target category is always set from a null value here.
    getChainedMethodCallExprFrom("withTargetCategory", initializer).setArgument(0, getExpressionForObject(null));
    // Append the (now updated) original statements after the map-population statements.
    classificationTableBody.getStatements().forEach(newBody::addStatement);
    staticGetterMethod.setBody(newBody);
}
Aggregations