Usage of org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD in the kiegroup drools project.
Example: method setStaticGetter of class KiePMMLClassificationTableFactory.
// not-public code-generation
/**
 * Rewrites the body of the template static getter so that it builds the classification table
 * identified by {@code variableName}.
 * <p>
 * The new body starts with the declaration of a populated map (named from
 * {@code CATEGORICAL_TABLE_MAP} and {@code variableName}). The map's key type is {@code String}
 * and its value type is {@link KiePMMLRegressionTable}. For every entry of
 * {@code regressionTablesMap`, it maps the entry's category to a call of the method named by
 * {@code KiePMMLRegressionTableFactory.GETKIEPMML_TABLE} on the generated class.
 * <p>
 * The arguments of the chained builder calls in the initializer of the {@code TO_RETURN} variable
 * are then overwritten with values taken from {@code compilationDTO}, {@code regressionTablesMap}
 * and {@code variableName}. Finally, the original statements of the getter are appended after
 * the map declaration.
 *
 * @param compilationDTO      source of the normalization methods, op type, target field name and
 *                            binary flag written into the builder chain
 * @param regressionTablesMap generated table class names mapped to their source/category holder
 * @param staticGetterMethod  the template getter; its body is replaced in place
 * @param variableName        value passed to {@code builder(...)}; also used to name the generated map
 * @throws KiePMMLException if the getter has no body, the body has no {@code TO_RETURN} variable,
 *                          or that variable has no initializer
 */
static void setStaticGetter(final RegressionCompilationDTO compilationDTO, final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap, final MethodDeclaration staticGetterMethod, final String variableName) {
    final BlockStmt classificationTableBody = staticGetterMethod.getBody()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, staticGetterMethod)));
    final VariableDeclarator variableDeclarator = getVariableDeclarator(classificationTableBody, TO_RETURN)
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, TO_RETURN, classificationTableBody)));
    final BlockStmt newBody = new BlockStmt();
    // category -> "<generated table class>.<GETKIEPMML_TABLE>()" expression, one per regression table
    final Map<String, Expression> regressionTableCategoriesMap = new LinkedHashMap<>();
    regressionTablesMap.forEach((className, tableSourceCategory) ->
            regressionTableCategoriesMap.put(tableSourceCategory.getCategory(),
                    new MethodCallExpr(new NameExpr(className), KiePMMLRegressionTableFactory.GETKIEPMML_TABLE)));
    // populate map: declared first in the new body so the builder chain can reference it by name
    final String categoryTableMapName = String.format(VARIABLE_NAME_TEMPLATE, CATEGORICAL_TABLE_MAP, variableName);
    createPopulatedLinkedHashMap(newBody, categoryTableMapName, Arrays.asList(String.class.getSimpleName(), KiePMMLRegressionTable.class.getName()), regressionTableCategoriesMap);
    // the TO_RETURN initializer is expected to be a builder method-call chain
    final MethodCallExpr initializer = variableDeclarator.getInitializer()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, TO_RETURN, classificationTableBody)))
            .asMethodCallExpr();
    getChainedMethodCallExprFrom("builder", initializer).setArgument(0, new StringLiteralExpr(variableName));
    final REGRESSION_NORMALIZATION_METHOD regressionNormalizationMethod = compilationDTO.getDefaultREGRESSION_NORMALIZATION_METHOD();
    getChainedMethodCallExprFrom("withRegressionNormalizationMethod", initializer)
            .setArgument(0, getEnumConstantNameExpr(regressionNormalizationMethod));
    final OP_TYPE opType = compilationDTO.getOP_TYPE();
    getChainedMethodCallExprFrom("withOpType", initializer).setArgument(0, getEnumConstantNameExpr(opType));
    getChainedMethodCallExprFrom("withCategoryTableMap", initializer).setArgument(0, new NameExpr(categoryTableMapName));
    final boolean isBinary = compilationDTO.isBinary(regressionTablesMap.size());
    final Expression probabilityMapFunctionExpression = getProbabilityMapFunctionExpression(compilationDTO.getModelNormalizationMethod(), isBinary);
    getChainedMethodCallExprFrom("withProbabilityMapFunction", initializer).setArgument(0, probabilityMapFunctionExpression);
    getChainedMethodCallExprFrom("withIsBinary", initializer).setArgument(0, getExpressionForObject(isBinary));
    getChainedMethodCallExprFrom("withTargetField", initializer).setArgument(0, getExpressionForObject(compilationDTO.getTargetFieldName()));
    getChainedMethodCallExprFrom("withTargetCategory", initializer).setArgument(0, getExpressionForObject(null));
    // append the template's statements (now carrying the rewritten initializer) after the map declaration
    classificationTableBody.getStatements().forEach(newBody::addStatement);
    staticGetterMethod.setBody(newBody);
}

/**
 * Builds a {@link NameExpr} of the form {@code EnumType.CONSTANT} for the given enum constant.
 * <p>
 * {@link Enum#getDeclaringClass()} is used instead of {@link Object#getClass()} for a reason.
 * A constant declared with its own class body is an anonymous subclass whose simple name is
 * empty, which would produce invalid source such as {@code .CONSTANT}.
 *
 * @param enumConstant the constant to reference in generated code
 * @return a name expression referencing {@code enumConstant} through its enum type's simple name
 */
private static NameExpr getEnumConstantNameExpr(final Enum<?> enumConstant) {
    return new NameExpr(enumConstant.getDeclaringClass().getSimpleName() + "." + enumConstant.name());
}
Aggregations