Use of org.kie.pmml.commons.Constants.PACKAGE_NAME in the drools project by kiegroup.
Class DroolsModelProviderTest, method generateRulesFiles.
/**
 * Checks the paths of the files returned by {@code generateRulesFiles}: one ".java" file per
 * type declaration of the package descriptor, plus the "Rules" and "DomainClassesMetadata"
 * files suffixed with the preferred package UUID, all under the package directory.
 */
@Test
public void generateRulesFiles() {
    final KnowledgeBuilderImpl builder = new KnowledgeBuilderImpl();
    final CommonCompilationDTO<Scorecard> dto =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   scorecard,
                                                                   new HasKnowledgeBuilderMock(builder));
    // The package descriptor is read back from the builder right after this call,
    // so this is presumably what populates it -- confirm against the provider.
    droolsModelProvider.getKiePMMLModelWithSources(dto);

    final String packageName = dto.getPackageName();
    final PackageDescr descr = builder.getPackageDescrs(packageName).get(0);
    final List<GeneratedFile> retrieved = droolsModelProvider.generateRulesFiles(descr);
    assertNotNull(retrieved);

    // Generated paths use '/' separators instead of the dots of the package name.
    final String packageDir = packageName.replace('.', '/') + "/";

    // One source file is expected for every declared type.
    descr.getTypeDeclarations().forEach(declaration ->
            assertTrue(containsGeneratedPath(retrieved,
                                             packageDir + declaration.getTypeName() + ".java")));

    // The two remaining files are named after the preferred package UUID.
    final String uuid = descr.getPreferredPkgUUID().get();
    assertTrue(containsGeneratedPath(retrieved, packageDir + "Rules" + uuid + ".java"));
    assertTrue(containsGeneratedPath(retrieved, packageDir + "DomainClassesMetadata" + uuid + ".java"));
}

/**
 * Tells whether any of the given files has exactly the given path.
 *
 * @param files the generated files to scan
 * @param path the path to look for
 * @return {@code true} as soon as a file with an equal path is found, {@code false} otherwise
 */
private static boolean containsGeneratedPath(List<GeneratedFile> files, String path) {
    for (GeneratedFile file : files) {
        if (file.getPath().equals(path)) {
            return true;
        }
    }
    return false;
}
Use of org.kie.pmml.commons.Constants.PACKAGE_NAME in the drools project by kiegroup.
Class DroolsModelProviderTest, method getKiePMMLModelWithKnowledgeBuilder.
/**
 * Checks that {@code getKiePMMLModel} returns a {@code KiePMMLDroolsModelTest} carrying the
 * original data fields, transformation dictionary, model, package name and model name, and
 * that the knowledge builder holds a {@code CompositePackageDescr} for "packagename".
 */
@Test
public void getKiePMMLModelWithKnowledgeBuilder() {
    final KnowledgeBuilderImpl builder = new KnowledgeBuilderImpl();
    final CommonCompilationDTO<Scorecard> dto =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   scorecard,
                                                                   new HasKnowledgeBuilderMock(builder));

    final KiePMMLDroolsModel model = droolsModelProvider.getKiePMMLModel(dto);
    assertNotNull(model);
    assertTrue(model instanceof KiePMMLDroolsModelTest);
    final KiePMMLDroolsModelTest testModel = (KiePMMLDroolsModelTest) model;

    // Every original data field must have a same-named counterpart with the same data type.
    final List<DataField> expectedFields = pmml.getDataDictionary().getDataFields();
    final List<DataField> actualFields = testModel.dataDictionary.getDataFields();
    assertEquals(expectedFields.size(), testModel.dataDictionary.getDataFields().size());
    for (DataField expected : expectedFields) {
        final Optional<DataField> match = actualFields.stream()
                .filter(candidate -> expected.getName().equals(candidate.getName()))
                .findFirst();
        assertTrue(match.isPresent());
        assertEquals(expected.getDataType(), match.get().getDataType());
    }

    assertEquals(pmml.getTransformationDictionary(), testModel.transformationDictionary);
    assertEquals(scorecard, testModel.model);

    final String sanitizedPackageName = getSanitizedPackageName(PACKAGE_NAME);
    assertEquals(sanitizedPackageName, testModel.getKModulePackageName());
    assertEquals(PACKAGE_NAME, testModel.getName());

    final PackageDescr descr = builder.getPackageDescrs("packagename").get(0);
    assertTrue(descr instanceof CompositePackageDescr);
}
Use of org.kie.pmml.commons.Constants.PACKAGE_NAME in the drools project by kiegroup.
Class KiePMMLClassificationTableFactoryTest, method getClassificationTableBuilders.
/**
 * Builds a small categorical {@code RegressionModel} (two regression tables, three output
 * fields, one target mining field) and checks that
 * {@code KiePMMLClassificationTableFactory.getClassificationTableBuilders} returns three
 * table sources that pass the common validation and compile.
 */
@Test
public void getClassificationTableBuilders() {
    // Regression tables and output fields come from the test helpers.
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");
    final OutputField categoricalOutput = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    final OutputField numericOutput = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    final OutputField previousOutput = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Data dictionary holding the single categorical target field.
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // Mining schema marking that same field as the target.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);

    final Output modelOutput = new Output();
    modelOutput.addOutputFields(categoricalOutput, numericOutput, previousOutput);

    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(professionalTable, clericalTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));
    model.setOutput(modelOutput);
    model.setMiningSchema(schema);

    final PMML pmmlDocument = new PMML();
    pmmlDocument.setDataDictionary(dictionary);
    pmmlDocument.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDto =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmmlDocument,
                                                                   model,
                                                                   new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDto =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDto,
                                                                                              model.getRegressionTables(),
                                                                                              model.getNormalizationMethod());

    final Map<String, KiePMMLTableSourceCategory> retrieved =
            KiePMMLClassificationTableFactory.getClassificationTableBuilders(regressionDto);
    assertNotNull(retrieved);
    assertEquals(3, retrieved.size());
    retrieved.values().forEach(tableSource -> commonValidateKiePMMLRegressionTable(tableSource.getSource()));

    // Keep only the source text of each entry before handing the map to the compilation check.
    final Map<String, String> sourcesByName = retrieved.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSource()));
    commonValidateCompilation(sourcesByName);
}
Use of org.kie.pmml.commons.Constants.PACKAGE_NAME in the drools project by kiegroup.
Class KiePMMLMiningModelFactoryTest, method getKiePMMLMiningModelSourcesMapCompiled.
/**
 * Checks that {@code KiePMMLMiningModelFactory.getKiePMMLMiningModelSourcesMapCompiled}
 * collects one nested model per segment and makes the expected segment classes loadable
 * from the mock's class loader; the same classes must NOT be loadable beforehand.
 */
@Test
public void getKiePMMLMiningModelSourcesMapCompiled() {
    final List<KiePMMLModel> nestedModels = new ArrayList<>();
    final HasKnowledgeBuilderMock hasKnowledgeBuilderMock = new HasKnowledgeBuilderMock(KNOWLEDGE_BUILDER);
    final CommonCompilationDTO<MiningModel> source =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   MINING_MODEL,
                                                                   hasKnowledgeBuilderMock);
    final MiningModelCompilationDTO compilationDTO = MiningModelCompilationDTO.fromCompilationDTO(source);

    // Fully qualified name expected for each segment's model:
    // <sanitized segmentation package + "." + segment id>.<sanitized model name>
    final List<String> expectedGeneratedClasses = MINING_MODEL.getSegmentation().getSegments().stream().map(segment -> {
        String modelName = segment.getModel().getModelName();
        String sanitizedPackageName = getSanitizedPackageName(compilationDTO.getSegmentationPackageName() + "." + segment.getId());
        String sanitizedClassName = getSanitizedClassName(modelName);
        return String.format(PACKAGE_CLASS_TEMPLATE, sanitizedPackageName, sanitizedClassName);
    }).collect(Collectors.toList());

    // Precondition: none of the expected classes is loadable before the factory runs.
    expectedGeneratedClasses.forEach(expectedGeneratedClass -> {
        try {
            hasKnowledgeBuilderMock.getClassLoader().loadClass(expectedGeneratedClass);
            fail("Expecting class not found: " + expectedGeneratedClass);
        } catch (Exception e) {
            assertTrue(e instanceof ClassNotFoundException);
        }
    });

    final Map<String, String> retrieved = KiePMMLMiningModelFactory.getKiePMMLMiningModelSourcesMapCompiled(compilationDTO, nestedModels);
    assertNotNull(retrieved);
    int expectedNestedModels = MINING_MODEL.getSegmentation().getSegments().size();
    assertEquals(expectedNestedModels, nestedModels.size());

    // Postcondition: every expected class is now loadable.
    expectedGeneratedClasses.forEach(expectedGeneratedClass -> {
        try {
            hasKnowledgeBuilderMock.getClassLoader().loadClass(expectedGeneratedClass);
        } catch (Exception e) {
            // Attach the original exception as the cause so its stack trace reaches the test
            // report; a printStackTrace() placed after fail(...) could never execute.
            throw new AssertionError("Expecting class to be loaded, but got: " + e.getClass().getName() + " -> " + e.getMessage(), e);
        }
    });
}
Aggregations