Use of com.github.javaparser.ast.type.ClassOrInterfaceType in the kiegroup drools project.
Class KiePMMLSegmentationFactory, method setConstructor.
/**
 * Populates the constructor of a generated segmentation class in place.
 * <p>
 * The method:
 * <ul>
 *   <li>delegates to {@code setConstructorSuperNameInvocation};</li>
 *   <li>sets the {@code multipleModelMethod} argument of the explicit constructor invocation to a
 *   fully-qualified reference to the given enum constant;</li>
 *   <li>for every assignment in the body whose target is the simple name {@code segments}, appends
 *   one {@code segments.add(new <segmentClass>());} statement per entry of {@code segmentsClasses}
 *   (in list order) to the end of the constructor body.</li>
 * </ul>
 *
 * @param generatedClassName name of the generated class
 * @param segmentationName value forwarded to {@code setConstructorSuperNameInvocation}
 * @param constructorDeclaration constructor to modify in place
 * @param multipleModelMethod enum constant to reference in the explicit constructor invocation
 * @param segmentsClasses class names to instantiate and add to {@code segments}, in order
 * @throws KiePMMLException if the constructor body contains no explicit constructor invocation
 */
static void setConstructor(final String generatedClassName, final String segmentationName, final ConstructorDeclaration constructorDeclaration, final MULTIPLE_MODEL_METHOD multipleModelMethod, final List<String> segmentsClasses) {
    setConstructorSuperNameInvocation(generatedClassName, constructorDeclaration, segmentationName);
    final BlockStmt body = constructorDeclaration.getBody();
    final ExplicitConstructorInvocationStmt superStatement = CommonCodegenUtils.getExplicitConstructorInvocationStmt(body)
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_CONSTRUCTOR_IN_BODY, body)));
    // getDeclaringClass(), not getClass(): for an enum constant declared with a body, getClass() is an
    // anonymous subclass whose canonical name is null, which would generate "null.<NAME>".
    final String multipleModelMethodReference = multipleModelMethod.getDeclaringClass().getCanonicalName() + "." + multipleModelMethod.name();
    CommonCodegenUtils.setExplicitConstructorInvocationStmtArgument(superStatement, "multipleModelMethod", multipleModelMethodReference);
    // findAll returns its own list, so appending statements to body inside the loop is safe.
    final List<AssignExpr> assignExprs = body.findAll(AssignExpr.class);
    for (AssignExpr assignExpr : assignExprs) {
        final Expression target = assignExpr.getTarget();
        // Skip non-simple-name targets (e.g. "this.x = ...") instead of letting asNameExpr() throw.
        if (!target.isNameExpr() || !target.asNameExpr().getNameAsString().equals("segments")) {
            continue;
        }
        for (String segmentClass : segmentsClasses) {
            ClassOrInterfaceType kiePMMLSegmentClass = parseClassOrInterfaceType(segmentClass);
            ObjectCreationExpr objectCreationExpr = new ObjectCreationExpr();
            objectCreationExpr.setType(kiePMMLSegmentClass);
            NodeList<Expression> arguments = NodeList.nodeList(objectCreationExpr);
            MethodCallExpr methodCallExpr = new MethodCallExpr();
            // Clone the target: a JavaParser node has exactly one parent, so reusing the same
            // instance would detach it from the assignment and share it among all generated calls.
            methodCallExpr.setScope(target.asNameExpr().clone());
            methodCallExpr.setName("add");
            methodCallExpr.setArguments(arguments);
            ExpressionStmt expressionStmt = new ExpressionStmt();
            expressionStmt.setExpression(methodCallExpr);
            body.addStatement(expressionStmt);
        }
    }
}
Use of com.github.javaparser.ast.type.ClassOrInterfaceType in the kiegroup drools project.
Class KiePMMLClassificationTableFactory, method getProbabilityMapFunctionSupportedExpression.
/**
 * Builds the <b>probabilityMapFunction</b> {@code MethodReferenceExpr}.
 * <p>
 * The reference's scope is the name {@code KiePMMLClassificationTable} cast to
 * {@code SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>>}
 * (with {@code SerializableFunction} and {@code LinkedHashMap} written by canonical name). Its
 * identifier is {@code get<NAME>ProbabilityMap}, where {@code <NAME>} is the normalization method's
 * enum name, suffixed with {@code Binary} only for {@code NONE} on a binary model.
 *
 * @param normalizationMethod normalization method whose name selects the target method
 * @param isBinary whether the model is binary; only relevant when the method is {@code NONE}
 * @return the method reference expression described above
 */
static MethodReferenceExpr getProbabilityMapFunctionSupportedExpression(final RegressionModel.NormalizationMethod normalizationMethod, final boolean isBinary) {
    final String baseName = normalizationMethod.name();
    final boolean binaryWithoutNormalization = isBinary && RegressionModel.NormalizationMethod.NONE.equals(normalizationMethod);
    final String methodName = String.format("get%sProbabilityMap", binaryWithoutNormalization ? baseName + "Binary" : baseName);

    // LinkedHashMap<String, Double> is both the input and the output of the function type.
    final ClassOrInterfaceType mapType = getTypedClassOrInterfaceTypeByTypeNames(LinkedHashMap.class.getCanonicalName(),
            Arrays.asList(String.class.getSimpleName(), Double.class.getSimpleName()));
    final ClassOrInterfaceType functionType = getTypedClassOrInterfaceTypeByTypes(SerializableFunction.class.getCanonicalName(),
            Arrays.asList(mapType, mapType));

    final CastExpr scope = new CastExpr();
    scope.setType(functionType);
    scope.setExpression("KiePMMLClassificationTable");

    final MethodReferenceExpr reference = new MethodReferenceExpr();
    reference.setScope(scope);
    reference.setIdentifier(methodName);
    return reference;
}
Use of com.github.javaparser.ast.type.ClassOrInterfaceType in the kiegroup drools project.
Class KiePMMLMiningModelFactoryTest, test method setConstructor.
@Test
public void setConstructor() {
    // Enum lookups for the expected constructor assignments.
    final PMML_MODEL expectedPmmlModel = PMML_MODEL.byName(MINING_MODEL.getClass().getSimpleName());
    final ClassOrInterfaceDeclaration template = MODEL_TEMPLATE.clone();
    final MINING_FUNCTION expectedMiningFunction = MINING_FUNCTION.byName(MINING_MODEL.getMiningFunction().value());

    // Run the method under test against a fresh copy of the template.
    final CommonCompilationDTO<MiningModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, MINING_MODEL, new HasClassLoaderMock());
    final MiningModelCompilationDTO miningDTO = MiningModelCompilationDTO.fromCompilationDTO(commonDTO);
    KiePMMLMiningModelFactory.setConstructor(miningDTO, template);

    // Expected super(...) arguments, keyed by position.
    final Map<Integer, Expression> expectedSuperArguments = new HashMap<>();
    expectedSuperArguments.put(0, new NameExpr(String.format("\"%s\"", MINING_MODEL.getModelName())));

    // Expected field assignments, keyed by field name.
    final Map<String, Expression> expectedAssignments = new HashMap<>();
    expectedAssignments.put("targetField", new StringLiteralExpr(targetFieldName));
    expectedAssignments.put("miningFunction", new NameExpr(expectedMiningFunction.getClass().getName() + "." + expectedMiningFunction.name()));
    expectedAssignments.put("pmmlMODEL", new NameExpr(expectedPmmlModel.getClass().getName() + "." + expectedPmmlModel.name()));
    final ObjectCreationExpr segmentationCreation = new ObjectCreationExpr();
    segmentationCreation.setType(parseClassOrInterfaceType(miningDTO.getSegmentationCanonicalClassName()));
    expectedAssignments.put("segmentation", segmentationCreation);

    final ConstructorDeclaration generatedConstructor = template.getDefaultConstructor().get();
    final String expectedName = getSanitizedClassName(MINING_MODEL.getModelName());
    assertTrue(commonEvaluateConstructor(generatedConstructor, expectedName, expectedSuperArguments, expectedAssignments));
}
Use of com.github.javaparser.ast.type.ClassOrInterfaceType in the kiegroup drools project.
Class KiePMMLRegressionTableFactory, method getCategoricalPredictorExpression.
/**
 * Builds the <b>CategoricalPredictor</b> expression.
 * <p>
 * The result is a {@code CastExpr} that casts a lambda to {@code SerializableFunction<String, Double>}
 * ({@code SerializableFunction} written by canonical name), i.e. the generated code has the shape
 * {@code (SerializableFunction<String, Double>) input -> KiePMMLRegressionTable.evaluateCategoricalPredictor(input, <mapName>)}.
 *
 * @param categoricalPredictorMapName name of the variable passed as the second argument of
 *                                    {@code evaluateCategoricalPredictor}
 * @return the cast lambda expression
 */
static CastExpr getCategoricalPredictorExpression(final String categoricalPredictorMapName) {
    final String inputName = "input";

    // KiePMMLRegressionTable.evaluateCategoricalPredictor(input, <mapName>)
    final NodeList<Expression> callArguments = new NodeList<>(new NameExpr(inputName), new NameExpr(categoricalPredictorMapName));
    final MethodCallExpr evaluateCall = new MethodCallExpr();
    evaluateCall.setScope(new NameExpr(KiePMMLRegressionTable.class.getSimpleName()));
    evaluateCall.setName("evaluateCategoricalPredictor");
    evaluateCall.setArguments(callArguments);

    // input -> <call>, with an untyped lambda parameter
    final LambdaExpr lambda = new LambdaExpr();
    lambda.setParameters(NodeList.nodeList(new Parameter(new UnknownType(), inputName)));
    lambda.setBody(new ExpressionStmt(evaluateCall));

    final CastExpr castToFunction = new CastExpr();
    castToFunction.setType(getTypedClassOrInterfaceTypeByTypeNames(SerializableFunction.class.getCanonicalName(),
            Arrays.asList(String.class.getSimpleName(), Double.class.getSimpleName())));
    castToFunction.setExpression(lambda);
    return castToFunction;
}
Use of com.github.javaparser.ast.type.ClassOrInterfaceType in the kiegroup drools project.
Class Expressions, method genContextType.
/**
 * Builds an expression that instantiates the type held by {@code MapBackedTypeT}.
 * <p>
 * The expression is constructed with the name {@code "[anonymous]"} and a map built from {@code fields}.
 * The generated code has the shape
 * {@code new <MapBackedTypeT>("[anonymous]", java.util.stream.Stream.of(new java.util.AbstractMap.SimpleImmutableEntry<java.lang.String, org.kie.dmn.feel.lang.Type>("key", value), ...)
 * .collect(java.util.stream.Collectors.toMap(java.util.Map.Entry::getKey, java.util.Map.Entry::getValue)))}.
 * <p>
 * The value expressions of {@code fields} are attached directly (not copied) to the generated AST.
 *
 * @param fields field names mapped to expressions producing their types; iteration order of the map
 *               determines argument order in {@code Stream.of(...)}
 * @return the object-creation expression described above
 */
public static Expression genContextType(Map<String, Expression> fields) {
    final ClassOrInterfaceType sie = parseClassOrInterfaceType(java.util.AbstractMap.SimpleImmutableEntry.class.getCanonicalName());
    sie.setTypeArguments(parseClassOrInterfaceType(String.class.getCanonicalName()), parseClassOrInterfaceType(org.kie.dmn.feel.lang.Type.class.getCanonicalName()));
    // Clone the entry type per entry: a JavaParser node has exactly one parent, so sharing one instance
    // across several ObjectCreationExpr nodes leaves the AST with inconsistent parent links.
    List<Expression> entryParams = fields.entrySet().stream()
            .map(e -> new ObjectCreationExpr(null, sie.clone(), new NodeList<>(stringLiteral(e.getKey()), e.getValue())))
            .collect(Collectors.toList());
    MethodCallExpr mOf = new MethodCallExpr(new NameExpr(java.util.stream.Stream.class.getCanonicalName()), "of");
    entryParams.forEach(mOf::addArgument);
    MethodCallExpr mCollect = new MethodCallExpr(mOf, "collect");
    mCollect.addArgument(new MethodCallExpr(new NameExpr(java.util.stream.Collectors.class.getCanonicalName()), "toMap")
            .addArgument(new MethodReferenceExpr(new NameExpr(java.util.Map.Entry.class.getCanonicalName()), new NodeList<>(), "getKey"))
            .addArgument(new MethodReferenceExpr(new NameExpr(java.util.Map.Entry.class.getCanonicalName()), new NodeList<>(), "getValue")));
    // MapBackedTypeT is defined outside this method and reused on every call; clone it so each
    // generated expression owns its type node instead of stealing it from the previous result.
    return new ObjectCreationExpr(null, MapBackedTypeT.clone(), new NodeList<>(stringLiteral("[anonymous]"), mCollect));
}
Aggregations