Usage of `org.kie.pmml.commons.model.KiePMMLModel` in project drools (kiegroup):
class `PMMLCompilerImpl`, method `getKiePMMLModelsWithSources`.
/**
 * Loads the PMML document from {@code inputStream}, generates the sources of the models it contains and appends a
 * {@link KiePMMLFactoryModel} built from those models.
 *
 * @param factoryClassName name of the factory class to generate
 * @param packageName      base package used to build the expected class names
 * @param inputStream      stream the PMML document is loaded from
 * @param fileName         name of the PMML file, forwarded to {@code KiePMMLUtil.load}
 * @param hasClassloader   forwarded to {@code getModelsWithSources}
 * @return the generated models followed by the factory model
 * @throws KiePMMLException  if an expected model class was not generated, or wrapping a KiePMMLInternalException
 * @throws ExternalException wrapping any other exception
 */
@Override
public List<KiePMMLModel> getKiePMMLModelsWithSources(final String factoryClassName, final String packageName,
                                                      final InputStream inputStream, final String fileName,
                                                      final HasClassLoader hasClassloader) {
    logger.trace("getModels {} {}", inputStream, hasClassloader);
    try {
        PMML commonPMMLModel = KiePMMLUtil.load(inputStream, fileName);
        // Fully qualified name of the class expected for each model of the PMML document
        final Set<String> expectedClasses = commonPMMLModel.getModels().stream()
                .map(model -> {
                    String modelPackageName = getSanitizedPackageName(String.format(PACKAGE_CLASS_TEMPLATE,
                                                                                    packageName,
                                                                                    model.getModelName()));
                    return modelPackageName + "." + getSanitizedClassName(model.getModelName());
                })
                .collect(Collectors.toSet());
        final List<KiePMMLModel> toReturn = getModelsWithSources(packageName, commonPMMLModel, hasClassloader);
        // Retrieve the sources once per expected class, keeping the stream pipeline free of side effects
        final Map<String, HasSourcesMap> sourcesByExpectedClass = expectedClasses.stream()
                .collect(Collectors.toMap(expectedClass -> expectedClass,
                                          expectedClass -> getHasSourceMap(toReturn, expectedClass)));
        final Map<String, Boolean> expectedClassModelTypeMap = sourcesByExpectedClass.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().isInterpreted()));
        final Set<String> generatedClasses = sourcesByExpectedClass.values().stream()
                .flatMap(retrieved -> retrieved.getSourcesMap().keySet().stream())
                .collect(Collectors.toSet());
        // Computed without mutating expectedClasses: Collectors.toSet() does not guarantee a mutable set.
        // Sorted so the error message is deterministic.
        final List<String> missingClasses = expectedClasses.stream()
                .filter(expectedClass -> !generatedClasses.contains(expectedClass))
                .sorted()
                .collect(Collectors.toList());
        if (!missingClasses.isEmpty()) {
            throw new KiePMMLException("Expected generated class " + String.join(", ", missingClasses) + " not found");
        }
        Map<String, String> factorySourceMap = getFactorySourceCode(factoryClassName, packageName, expectedClassModelTypeMap);
        KiePMMLFactoryModel kiePMMLFactoryModel = new KiePMMLFactoryModel(factoryClassName, packageName, factorySourceMap);
        toReturn.add(kiePMMLFactoryModel);
        return toReturn;
    } catch (KiePMMLInternalException e) {
        throw new KiePMMLException("KiePMMLInternalException", e);
    } catch (KiePMMLException e) {
        throw e;
    } catch (Exception e) {
        throw new ExternalException("ExternalException", e);
    }
}
Usage of `org.kie.pmml.commons.model.KiePMMLModel` in project drools (kiegroup):
class `KnowledgeBaseUtils`, method `getModels`.
/**
 * Gathers the PMML models registered in every package of the given knowledge base.
 *
 * <p>For each package, the {@code ResourceType.PMML} resource-type package is looked up and all of its models are
 * added to the result; packages without one are skipped.</p>
 *
 * @param knowledgeBase the knowledge base whose packages are scanned
 * @return a new mutable list with the models of all PMML packages (empty if none)
 */
public static List<KiePMMLModel> getModels(final KieBase knowledgeBase) {
    final List<KiePMMLModel> collected = new ArrayList<>();
    knowledgeBase.getKiePackages().forEach(kiePackage -> {
        // NOTE(review): assumes every package is an InternalKnowledgePackage; the cast throws otherwise
        final InternalKnowledgePackage internalPackage = (InternalKnowledgePackage) kiePackage;
        final PMMLPackage pmmlPackage =
                (PMMLPackage) internalPackage.getResourceTypePackages().get(ResourceType.PMML);
        if (pmmlPackage == null) {
            return; // this package carries no PMML resources
        }
        collected.addAll(pmmlPackage.getAllModels().values());
    });
    return collected;
}
Usage of `org.kie.pmml.commons.model.KiePMMLModel` in project drools (kiegroup):
class `KiePMMLSegmentationFactoryTest`, method `getSegmentationSourcesMapCompiled`.
/**
 * Verifies {@code KiePMMLSegmentationFactory.getSegmentationSourcesMapCompiled}. The nested model classes must not
 * be loadable before the call. Afterwards a non-null sources map is returned, one nested model is collected per
 * segment, and every expected nested class can be loaded.
 */
@Test
public void getSegmentationSourcesMapCompiled() {
    final HasKnowledgeBuilderMock hasKnowledgeBuilderMock = new HasKnowledgeBuilderMock(KNOWLEDGE_BUILDER);
    final CommonCompilationDTO<MiningModel> source =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, MINING_MODEL,
                                                                   hasKnowledgeBuilderMock);
    final MiningModelCompilationDTO compilationDTO = MiningModelCompilationDTO.fromCompilationDTO(source);
    final List<KiePMMLModel> nestedModels = new ArrayList<>();
    final List<String> expectedGeneratedClasses = MINING_MODEL.getSegmentation().getSegments().stream()
            .map(this::getExpectedNestedModelClass)
            .collect(Collectors.toList());
    // Before compilation none of the nested model classes may be loadable
    expectedGeneratedClasses.forEach(expectedGeneratedClass -> {
        try {
            hasKnowledgeBuilderMock.getClassLoader().loadClass(expectedGeneratedClass);
            fail("Expecting class not found: " + expectedGeneratedClass);
        } catch (ClassNotFoundException expected) {
            // expected: the class has not been compiled yet
        }
    });
    final Map<String, String> retrieved =
            KiePMMLSegmentationFactory.getSegmentationSourcesMapCompiled(compilationDTO, nestedModels);
    assertNotNull(retrieved);
    int expectedNestedModels = MINING_MODEL.getSegmentation().getSegments().size();
    assertEquals(expectedNestedModels, nestedModels.size());
    // After compilation every nested model class must be loadable
    expectedGeneratedClasses.forEach(expectedGeneratedClass -> {
        try {
            hasKnowledgeBuilderMock.getClassLoader().loadClass(expectedGeneratedClass);
        } catch (Exception e) {
            // Preserve the original exception as the cause (the former printStackTrace() after fail() was unreachable)
            throw new AssertionError("Expecting class to be loaded, but got: " + e.getClass().getName() + " -> "
                                             + e.getMessage(), e);
        }
    });
}
Usage of `org.kie.pmml.commons.model.KiePMMLModel` in project drools (kiegroup):
class `PMMLMiningModelEvaluatorTest`, method `getStep`.
/**
 * Checks the step built by {@code evaluator.getStep} for a segment.
 *
 * <p>For an OK result, the info map holds the segment name, model name, result code and result value. For a FAIL
 * result, it holds the same first three entries but no "RESULT" key.</p>
 */
@Test
public void getStep() {
    final String modelName = "MODEL_NAME";
    final String segmentName = "SEGMENT_NAME";
    final String resultObjectName = "RESULT_OBJECT_NAME";
    final String resultObjectValue = "RESULT_OBJECT_VALUE";

    final KiePMMLModel modelMock = mock(KiePMMLModel.class);
    when(modelMock.getName()).thenReturn(modelName);
    final KiePMMLSegment segmentMock = mock(KiePMMLSegment.class);
    when(segmentMock.getName()).thenReturn(segmentName);
    when(segmentMock.getModel()).thenReturn(modelMock);

    // OK result: the result value is expected in the step info
    final PMML4Result okResult = new PMML4Result();
    okResult.setResultCode(OK.getName());
    okResult.setResultObjectName(resultObjectName);
    okResult.getResultVariables().put(resultObjectName, resultObjectValue);
    final PMMLStep okStep = evaluator.getStep(segmentMock, okResult);
    assertNotNull(okStep);
    assertTrue(okStep instanceof PMMLMiningModelStep);
    final Map<String, Object> okInfo = okStep.getInfo();
    assertNotNull(okInfo);
    assertEquals(segmentName, okInfo.get("SEGMENT"));
    assertEquals(modelName, okInfo.get("MODEL"));
    assertEquals(OK.getName(), okInfo.get("RESULT CODE"));
    assertEquals(resultObjectValue, okInfo.get("RESULT"));

    // FAIL result: no result value is expected in the step info
    final PMML4Result failResult = new PMML4Result();
    failResult.setResultCode(FAIL.getName());
    final PMMLStep failStep = evaluator.getStep(segmentMock, failResult);
    assertNotNull(failStep);
    assertTrue(failStep instanceof PMMLMiningModelStep);
    final Map<String, Object> failInfo = failStep.getInfo();
    assertNotNull(failInfo);
    assertEquals(segmentName, failInfo.get("SEGMENT"));
    assertEquals(modelName, failInfo.get("MODEL"));
    assertEquals(FAIL.getName(), failInfo.get("RESULT CODE"));
    assertFalse(failInfo.containsKey("RESULT"));
}
Usage of `org.kie.pmml.commons.model.KiePMMLModel` in project drools (kiegroup):
class `PMMLMiningModelEvaluatorTest`, method `validateNoKiePMMLMiningModel`.
/**
 * Validating a {@code KiePMMLTestingModel} named "NAME" with no extensions must throw
 * {@code KiePMMLModelException}. Per the test name, the testing model is not a mining model.
 */
@Test(expected = KiePMMLModelException.class)
public void validateNoKiePMMLMiningModel() {
    final KiePMMLModel notAMiningModel = new KiePMMLTestingModel("NAME", Collections.emptyList());
    evaluator.validate(notAMiningModel);
}
Aggregations