Use of org.dmg.pmml.tree.TreeModel in project drools by kiegroup.
In class KiePMMLTreeModelFactoryTest, method setConstructor:
/**
 * Verifies that {@code KiePMMLTreeModelFactory.setConstructor} populates the default constructor of a
 * clone of the tree-model template with the expected {@code super(...)} arguments and field assignments.
 */
@Test
public void setConstructor() {
    final String targetField = "whatIdo";
    // Work on a clone so the shared template loaded in setUp() is not mutated across tests.
    final ClassOrInterfaceDeclaration modelTemplate = classOrInterfaceDeclaration.clone();
    final KnowledgeBuilderImpl knowledgeBuilder = new KnowledgeBuilderImpl();
    final CommonCompilationDTO<TreeModel> compilationDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, treeModel, new HasKnowledgeBuilderMock(knowledgeBuilder));
    final DroolsCompilationDTO<TreeModel> droolsCompilationDTO = DroolsCompilationDTO.fromCompilationDTO(compilationDTO, new HashMap<>());
    KiePMMLTreeModelFactory.setConstructor(droolsCompilationDTO, modelTemplate);
    // Expected super(...) arguments keyed by argument index (only indices 0 and 2 are specified here).
    final Map<Integer, Expression> superInvocationExpressionsMap = new HashMap<>();
    superInvocationExpressionsMap.put(0, new NameExpr(String.format("\"%s\"", treeModel.getModelName())));
    superInvocationExpressionsMap.put(2, new NameExpr(String.format("\"%s\"", treeModel.getAlgorithmName())));
    final MINING_FUNCTION miningFunction = MINING_FUNCTION.byName(treeModel.getMiningFunction().value());
    final PMML_MODEL pmmlModel = PMML_MODEL.byName(treeModel.getClass().getSimpleName());
    // Expected field assignments in the constructor body; enum values are referenced by their fully qualified names.
    final Map<String, Expression> assignExpressionMap = new HashMap<>();
    assignExpressionMap.put("targetField", new StringLiteralExpr(targetField));
    assignExpressionMap.put("miningFunction", new NameExpr(miningFunction.getClass().getName() + "." + miningFunction.name()));
    assignExpressionMap.put("pmmlMODEL", new NameExpr(pmmlModel.getClass().getName() + "." + pmmlModel.name()));
    // Fail with a descriptive message instead of a bare NoSuchElementException from Optional.get().
    final ConstructorDeclaration constructorDeclaration = modelTemplate.getDefaultConstructor()
            .orElseThrow(() -> new AssertionError("Default constructor missing from model template"));
    assertTrue(commonEvaluateConstructor(constructorDeclaration, getSanitizedClassName(treeModel.getModelName()), superInvocationExpressionsMap, assignExpressionMap));
}
Use of org.dmg.pmml.tree.TreeModel in project drools by kiegroup.
In class KiePMMLTreeModelFactoryTest, method setUp:
/**
 * Loads the {@code SOURCE_1} PMML fixture and the tree-model Java template once for the test class.
 * Asserts that the fixture contains exactly one model and that it is a {@link TreeModel}.
 *
 * @throws Exception propagated from loading the PMML fixture or the template file
 */
@BeforeClass
public static void setUp() throws Exception {
    pmml = TestUtils.loadFromFile(SOURCE_1);
    assertNotNull(pmml);
    assertEquals(1, pmml.getModels().size());
    assertTrue(pmml.getModels().get(0) instanceof TreeModel);
    treeModel = (TreeModel) pmml.getModels().get(0);
    CompilationUnit templateCU = getFromFileName(KIE_PMML_TREE_MODEL_TEMPLATE_JAVA);
    // Report which class/template is missing rather than a bare NoSuchElementException from Optional.get().
    classOrInterfaceDeclaration = templateCU.getClassByName(KIE_PMML_TREE_MODEL_TEMPLATE)
            .orElseThrow(() -> new AssertionError("Class " + KIE_PMML_TREE_MODEL_TEMPLATE
                    + " not found in " + KIE_PMML_TREE_MODEL_TEMPLATE_JAVA));
}
Use of org.dmg.pmml.tree.TreeModel in project drools by kiegroup.
In class TreeModelImplementationProviderTest, method getPMML:
/**
 * Loads the PMML document identified by {@code source}.
 * Asserts that it contains exactly one model and that the model is a {@link TreeModel}.
 *
 * @param source file name passed to {@code FileUtils.getFileInputStream} and {@code KiePMMLUtil.load}
 * @return the loaded, non-null {@link PMML}
 * @throws Exception propagated from opening or loading the file
 */
private PMML getPMML(String source) throws Exception {
    final PMML toReturn;
    // try-with-resources releases the file handle even when loading fails; the original never closed it.
    try (FileInputStream fis = FileUtils.getFileInputStream(source)) {
        toReturn = KiePMMLUtil.load(fis, source);
    }
    assertNotNull(toReturn);
    assertEquals(1, toReturn.getModels().size());
    assertTrue(toReturn.getModels().get(0) instanceof TreeModel);
    return toReturn;
}
Use of org.dmg.pmml.tree.TreeModel in project drools by kiegroup.
In class TreeModelImplementationProviderTest, method getKiePMMLModel:
/**
 * Builds a {@link KiePMMLTreeModel} through the provider from the first model of {@code SOURCE_1}.
 * Checks that the result is non-null and passes the deep-cloneability verification.
 */
@Test
public void getKiePMMLModel() throws Exception {
    final PMML loaded = getPMML(SOURCE_1);
    final HasKnowledgeBuilderMock builderHolder = new HasKnowledgeBuilderMock(new KnowledgeBuilderImpl());
    // getPMML has already asserted that the single model is a TreeModel, so the cast is safe.
    final TreeModel firstModel = (TreeModel) loaded.getModels().get(0);
    final CommonCompilationDTO<TreeModel> dto =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, loaded, firstModel, builderHolder);
    final KiePMMLTreeModel model = PROVIDER.getKiePMMLModel(dto);
    assertNotNull(model);
    commonVerifyIsDeepCloneable(model);
}
Use of org.dmg.pmml.tree.TreeModel in project drools by kiegroup.
In class KiePMMLNodeFactoryTest, method setupClass:
/**
 * Loads both PMML fixtures once for the test class.
 * Caches their data dictionaries, derived fields and the tree nodes the tests inspect.
 *
 * @throws Exception propagated from loading either fixture
 */
@BeforeClass
public static void setupClass() throws Exception {
    // First fixture: only the root node is retained.
    pmml1 = TestUtils.loadFromFile(SOURCE_1);
    final TreeModel firstTree = (TreeModel) pmml1.getModels().get(0);
    dataDictionary1 = pmml1.getDataDictionary();
    derivedFields1 = getDerivedFields(pmml1.getTransformationDictionary(), firstTree.getLocalTransformations());
    node1 = firstTree.getNode();

    // Second fixture: the root, its first child, and the node reached by following first children three levels down.
    pmml2 = TestUtils.loadFromFile(SOURCE_2);
    final TreeModel secondTree = (TreeModel) pmml2.getModels().get(0);
    dataDictionary2 = pmml2.getDataDictionary();
    derivedFields2 = getDerivedFields(pmml2.getTransformationDictionary(), secondTree.getLocalTransformations());
    nodeRoot = secondTree.getNode();
    compoundPredicateNode = nodeRoot.getNodes().get(0);
    nodeLeaf = nodeRoot.getNodes().get(0)
            .getNodes().get(0)
            .getNodes().get(0);
}
Aggregations