Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
In class TreeModelImplementationProviderTest, method getKiePMMLModelWithSources:
/**
 * Verifies that the tree-model provider generates a non-empty map of sources for the first model
 * of {@code pmml}, and that those sources compile in memory.
 *
 * <p>Compilation failures are deliberately not caught: they propagate out of the test so JUnit
 * reports the original exception together with its full stack trace, instead of a bare (and
 * possibly {@code null}) message.
 */
@Test
public void getKiePMMLModelWithSources() throws Exception {
    TreeModel treeModel = (TreeModel) pmml.getModels().get(0);
    final CommonCompilationDTO<TreeModel> compilationDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, treeModel, new HasClassLoaderMock());
    final KiePMMLModelWithSources retrieved = PROVIDER.getKiePMMLModelWithSources(compilationDTO);
    assertNotNull(retrieved);
    final Map<String, String> sourcesMap = retrieved.getSourcesMap();
    assertNotNull(sourcesMap);
    assertFalse(sourcesMap.isEmpty());
    ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
    final Map<String, Class<?>> compiled = KieMemoryCompiler.compile(sourcesMap, classLoader);
    for (Class<?> clazz : compiled.values()) {
        // NOTE(review): `clazz` is a java.lang.Class object and Class itself implements
        // Serializable, so this assertion can never fail. If the intent is "every generated
        // class is serializable", the check is Serializable.class.isAssignableFrom(clazz) —
        // confirm all generated classes are meant to implement it before switching.
        assertTrue(clazz instanceof Serializable);
    }
}
Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
In class KiePMMLTreeModelFactoryTest, method setConstructor:
/**
 * Checks that {@code KiePMMLTreeModelFactory.setConstructor} fills the default constructor of the
 * generated tree-model class with the expected {@code targetField}, {@code miningFunction},
 * {@code pmmlMODEL} and {@code nodeFunction} assignments.
 */
@Test
public void setConstructor() {
    final String generatedClassName = getSanitizedClassName(treeModel1.getModelName());
    final CompilationUnit compilationUnit = JavaParserUtils.getKiePMMLModelCompilationUnit(generatedClassName,
            PACKAGE_NAME, KIE_PMML_TREE_MODEL_TEMPLATE_JAVA, KIE_PMML_TREE_MODEL_TEMPLATE);
    final ClassOrInterfaceDeclaration template = compilationUnit.getClassByName(generatedClassName)
            .orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + generatedClassName));
    // Expected value of the targetField assignment (not an input to setConstructor)
    final String expectedTargetField = "whatIdo";
    final String nodeClassName = "full.Node.ClassName";
    final CommonCompilationDTO<TreeModel> compilationDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml1, treeModel1, new HasClassLoaderMock());
    KiePMMLTreeModelFactory.setConstructor(TreeCompilationDTO.fromCompilationDTO(compilationDTO), template, nodeClassName);
    final BlockStmt constructorBody = template.getDefaultConstructor()
            .orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, template.getName())))
            .clone()
            .getBody();

    // targetField: a quoted string literal
    Expression value = assignedValue(constructorBody, "targetField");
    assertTrue(value instanceof StringLiteralExpr);
    assertEquals("\"" + expectedTargetField + "\"", value.toString());

    // miningFunction: the fully qualified MINING_FUNCTION constant of the source model
    value = assignedValue(constructorBody, "miningFunction");
    assertTrue(value instanceof NameExpr);
    final MINING_FUNCTION miningFunction = MINING_FUNCTION.byName(treeModel1.getMiningFunction().value());
    assertEquals(miningFunction.getClass().getName() + "." + miningFunction.name(), value.toString());

    // pmmlMODEL: the fully qualified PMML_MODEL.TREE_MODEL constant
    value = assignedValue(constructorBody, "pmmlMODEL");
    assertTrue(value instanceof NameExpr);
    assertEquals(PMML_MODEL.TREE_MODEL.getClass().getName() + "." + PMML_MODEL.TREE_MODEL.name(), value.toString());

    // nodeFunction: a method reference <nodeClassName>::evaluateNode
    value = assignedValue(constructorBody, "nodeFunction");
    assertTrue(value instanceof MethodReferenceExpr);
    final MethodReferenceExpr nodeFunctionRef = (MethodReferenceExpr) value;
    assertEquals(nodeClassName, nodeFunctionRef.getScope().toString());
    assertEquals("evaluateNode", nodeFunctionRef.getIdentifier());
}

/**
 * Returns the right-hand side of the assignment to {@code fieldName} inside {@code body},
 * asserting first that such an assignment exists.
 */
private Expression assignedValue(BlockStmt body, String fieldName) {
    Optional<AssignExpr> assignment = CommonCodegenUtils.getAssignExpression(body, fieldName);
    assertTrue(assignment.isPresent());
    return assignment.get().getValue();
}
Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
In class ClusteringModelImplementationProviderTest, method getKiePMMLModelWithSources:
/**
 * Verifies that the clustering-model provider, given the PMML loaded from {@code SOURCE_FILE},
 * produces a non-empty map of sources that compile in memory.
 */
@Test
public void getKiePMMLModelWithSources() throws Exception {
    final PMML loadedPmml = TestUtils.loadFromFile(SOURCE_FILE);
    final ClusteringModel clusteringModel = getModel(loadedPmml);
    final CommonCompilationDTO<ClusteringModel> dto =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, loadedPmml, clusteringModel, new HasClassLoaderMock());
    final KiePMMLModelWithSources modelWithSources = PROVIDER.getKiePMMLModelWithSources(dto);
    assertNotNull(modelWithSources);
    final Map<String, String> generatedSources = modelWithSources.getSourcesMap();
    assertNotNull(generatedSources);
    assertFalse(generatedSources.isEmpty());
    final Map<String, Class<?>> compiledClasses =
            KieMemoryCompiler.compile(generatedSources, Thread.currentThread().getContextClassLoader());
    // NOTE(review): each element is a java.lang.Class object, which itself implements
    // Serializable, so this check always passes; Serializable.class.isAssignableFrom(...) would
    // test the generated classes themselves — confirm intent before changing.
    compiledClasses.values().forEach(compiledClass -> assertTrue(compiledClass instanceof Serializable));
}
Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
In class KiePMMLDroolsModelFactoryUtilsTest, method getKiePMMLModelCompilationUnit:
/**
 * Verifies that {@code KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit} builds a
 * compilation unit in the DTO's package. It also checks that the default constructor of the
 * generated class assigns the expected {@code targetField}, {@code miningFunction},
 * {@code pmmlMODEL} and {@code kModulePackageName} values, and handles the field-type map entries.
 *
 * <p>Missing package/class/constructor declarations fail with a descriptive
 * {@link AssertionError} rather than an anonymous {@code NoSuchElementException}.
 */
@Test
public void getKiePMMLModelCompilationUnit() {
    DataDictionary dataDictionary = new DataDictionary();
    String targetFieldString = "target.field";
    FieldName targetFieldName = FieldName.create(targetFieldString);
    dataDictionary.addDataFields(new DataField(targetFieldName, OpType.CONTINUOUS, DataType.DOUBLE));
    String modelName = "ModelName";
    TreeModel model = new TreeModel();
    model.setModelName(modelName);
    model.setMiningFunction(MiningFunction.CLASSIFICATION);
    MiningField targetMiningField = new MiningField(targetFieldName);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(targetMiningField);
    model.setMiningSchema(miningSchema);
    Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    fieldTypeMap.put(targetFieldString, new KiePMMLOriginalTypeGeneratedType(targetFieldString, getSanitizedClassName(targetFieldString)));
    String packageName = "net.test";
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(model);
    final CommonCompilationDTO<TreeModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(packageName, pmml, model, new HasClassLoaderMock());
    final DroolsCompilationDTO<TreeModel> droolsCompilationDTO = DroolsCompilationDTO.fromCompilationDTO(source, fieldTypeMap);
    CompilationUnit retrieved = KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit(droolsCompilationDTO, TEMPLATE_SOURCE, TEMPLATE_CLASS_NAME);
    String retrievedPackageName = retrieved.getPackageDeclaration()
            .orElseThrow(() -> new AssertionError("Missing package declaration in generated compilation unit"))
            .getNameAsString();
    assertEquals(droolsCompilationDTO.getPackageName(), retrievedPackageName);
    ConstructorDeclaration constructorDeclaration = retrieved.getClassByName(modelName)
            .orElseThrow(() -> new AssertionError("Missing generated class " + modelName))
            .getDefaultConstructor()
            .orElseThrow(() -> new AssertionError("Missing default constructor in generated class " + modelName));
    MINING_FUNCTION miningFunction = MINING_FUNCTION.CLASSIFICATION;
    PMML_MODEL pmmlModel = PMML_MODEL.byName(model.getClass().getSimpleName());
    Map<String, Expression> assignExpressionMap = new HashMap<>();
    assignExpressionMap.put("targetField", new StringLiteralExpr(targetFieldString));
    assignExpressionMap.put("miningFunction", new NameExpr(miningFunction.getClass().getName() + "." + miningFunction.name()));
    assignExpressionMap.put("pmmlMODEL", new NameExpr(pmmlModel.getClass().getName() + "." + pmmlModel.name()));
    String expectedKModulePackageName = getSanitizedPackageName(packageName + "." + modelName);
    assignExpressionMap.put("kModulePackageName", new StringLiteralExpr(expectedKModulePackageName));
    assertTrue(commonEvaluateAssignExpr(constructorDeclaration.getBody(), assignExpressionMap));
    // The trailing "+ 1" accounts for the super invocation.
    int expectedMethodCallExprs = assignExpressionMap.size() + fieldTypeMap.size() + 1;
    commonEvaluateFieldTypeMap(constructorDeclaration.getBody(), fieldTypeMap, expectedMethodCallExprs);
}
Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
In class DroolsModelProviderTest, method getKiePMMLModelNoKnowledgeBuilder:
/**
 * Expects {@link KiePMMLException} when {@code droolsModelProvider.getKiePMMLModel} is called
 * with a compilation DTO built for the {@code scorecard} fixture.
 * NOTE(review): per the test name this presumably exercises the "no knowledge builder" path —
 * confirm against the provider implementation.
 */
@Test(expected = KiePMMLException.class)
public void getKiePMMLModelNoKnowledgeBuilder() {
    final HasClassLoaderMock classLoaderMock = new HasClassLoaderMock();
    final CommonCompilationDTO<Scorecard> scorecardDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, scorecard, classLoaderMock);
    droolsModelProvider.getKiePMMLModel(scorecardDTO);
}
Aggregations