Search in sources:

Example 11 with KiePMMLModel

use of org.kie.pmml.commons.model.KiePMMLModel in project drools by kiegroup.

the class PMMLCompilerImpl method getKiePMMLModelsWithSources.

/**
 * Loads the PMML document from {@code inputStream}, retrieves the models together with their
 * sources, checks that every class expected from the PMML models is present among the generated
 * sources, and finally appends a {@link KiePMMLFactoryModel} built from the factory source code.
 *
 * @param factoryClassName name passed to the factory source generation and to the factory model
 * @param packageName base package used to derive the expected class names
 * @param inputStream stream the PMML document is loaded from
 * @param fileName file name handed to {@code KiePMMLUtil.load}
 * @param hasClassloader forwarded to {@code getModelsWithSources}
 * @return the retrieved models followed by the factory model
 * @throws KiePMMLException if an expected class is missing from the generated sources, or when a
 *         {@link KiePMMLInternalException} is raised (it is wrapped)
 * @throws ExternalException wrapping any other exception
 */
@Override
public List<KiePMMLModel> getKiePMMLModelsWithSources(final String factoryClassName, final String packageName, final InputStream inputStream, final String fileName, final HasClassLoader hasClassloader) {
    logger.trace("getModels {} {}", inputStream, hasClassloader);
    try {
        final PMML pmml = KiePMMLUtil.load(inputStream, fileName);
        // One fully qualified class name per PMML model: sanitized package + "." + sanitized class.
        final Set<String> expectedClasses = pmml.getModels()
                .stream()
                .map(model -> {
                    String rawPackageName = String.format(PACKAGE_CLASS_TEMPLATE, packageName, model.getModelName());
                    return getSanitizedPackageName(rawPackageName) + "." + getSanitizedClassName(model.getModelName());
                })
                .collect(Collectors.toSet());
        final List<KiePMMLModel> models = getModelsWithSources(packageName, pmml, hasClassloader);
        // Filled as a side effect while building the map below.
        final Set<String> generatedClasses = new HashSet<>();
        final Map<String, Boolean> interpretedByClass = expectedClasses.stream()
                .collect(Collectors.toMap(className -> className, className -> {
                    final HasSourcesMap sourcesHolder = getHasSourceMap(models, className);
                    generatedClasses.addAll(sourcesHolder.getSourcesMap().keySet());
                    return sourcesHolder.isInterpreted();
                }));
        final boolean everythingGenerated = generatedClasses.containsAll(expectedClasses);
        if (!everythingGenerated) {
            // Keep only the names that were never generated, for the error message.
            expectedClasses.removeAll(generatedClasses);
            throw new KiePMMLException("Expected generated class " + String.join(", ", expectedClasses) + " not found");
        }
        final Map<String, String> factorySources = getFactorySourceCode(factoryClassName, packageName, interpretedByClass);
        models.add(new KiePMMLFactoryModel(factoryClassName, packageName, factorySources));
        return models;
    } catch (KiePMMLInternalException e) {
        throw new KiePMMLException("KiePMMLInternalException", e);
    } catch (KiePMMLException e) {
        throw e;
    } catch (Exception e) {
        throw new ExternalException("ExternalException", e);
    }
}
Also used : HasClassLoader(org.kie.pmml.commons.model.HasClassLoader) LoggerFactory(org.slf4j.LoggerFactory) ExternalException(org.kie.pmml.api.exceptions.ExternalException) HashSet(java.util.HashSet) KiePMMLUtil(org.kie.pmml.compiler.commons.utils.KiePMMLUtil) KiePMMLFactoryModel(org.kie.pmml.commons.model.KiePMMLFactoryModel) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) Map(java.util.Map) KiePMMLModel(org.kie.pmml.commons.model.KiePMMLModel) KiePMMLModelUtils.getSanitizedPackageName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedPackageName) KiePMMLModelRetriever.getFromCommonDataAndTransformationDictionaryAndModel(org.kie.pmml.compiler.commons.implementations.KiePMMLModelRetriever.getFromCommonDataAndTransformationDictionaryAndModel) PMML(org.dmg.pmml.PMML) Logger(org.slf4j.Logger) Set(java.util.Set) KiePMMLModelRetriever.getFromCommonDataAndTransformationDictionaryAndModelWithSources(org.kie.pmml.compiler.commons.implementations.KiePMMLModelRetriever.getFromCommonDataAndTransformationDictionaryAndModelWithSources) Collectors(java.util.stream.Collectors) List(java.util.List) KiePMMLFactoryFactory.getFactorySourceCode(org.kie.pmml.compiler.commons.factories.KiePMMLFactoryFactory.getFactorySourceCode) HasSourcesMap(org.kie.pmml.commons.model.HasSourcesMap) CommonCompilationDTO(org.kie.pmml.compiler.api.dto.CommonCompilationDTO) Optional(java.util.Optional) PACKAGE_CLASS_TEMPLATE(org.kie.pmml.commons.Constants.PACKAGE_CLASS_TEMPLATE) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) KiePMMLModelUtils.getSanitizedClassName(org.kie.pmml.commons.utils.KiePMMLModelUtils.getSanitizedClassName) InputStream(java.io.InputStream) ExternalException(org.kie.pmml.api.exceptions.ExternalException) ExternalException(org.kie.pmml.api.exceptions.ExternalException) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) 
KiePMMLFactoryModel(org.kie.pmml.commons.model.KiePMMLFactoryModel) KiePMMLModel(org.kie.pmml.commons.model.KiePMMLModel) PMML(org.dmg.pmml.PMML) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) HasSourcesMap(org.kie.pmml.commons.model.HasSourcesMap) HashSet(java.util.HashSet)

Example 12 with KiePMMLModel

use of org.kie.pmml.commons.model.KiePMMLModel in project drools by kiegroup.

the class KnowledgeBaseUtils method getModels.

/**
 * Gathers the models of every PMML resource package found in the given knowledge base.
 * Packages that have no entry for {@code ResourceType.PMML} are skipped.
 *
 * @param knowledgeBase knowledge base whose packages are inspected
 * @return a new list with the values of {@code getAllModels()} of each PMML package found
 */
public static List<KiePMMLModel> getModels(final KieBase knowledgeBase) {
    final List<KiePMMLModel> collected = new ArrayList<>();
    knowledgeBase.getKiePackages()
            .stream()
            .map(kiePackage -> (PMMLPackage) ((InternalKnowledgePackage) kiePackage).getResourceTypePackages().get(ResourceType.PMML))
            .filter(pmmlPackage -> pmmlPackage != null)
            .forEach(pmmlPackage -> collected.addAll(pmmlPackage.getAllModels().values()));
    return collected;
}
Also used : PMMLPackage(org.kie.pmml.evaluator.api.container.PMMLPackage) KiePMMLModel(org.kie.pmml.commons.model.KiePMMLModel) ArrayList(java.util.ArrayList) InternalKnowledgePackage(org.drools.core.definitions.InternalKnowledgePackage)

Example 13 with KiePMMLModel

use of org.kie.pmml.commons.model.KiePMMLModel in project drools by kiegroup.

the class KiePMMLSegmentationFactoryTest method getSegmentationSourcesMapCompiled.

/**
 * Checks that {@code KiePMMLSegmentationFactory.getSegmentationSourcesMapCompiled} adds one nested
 * model per segment and makes the expected nested model classes loadable.
 * <p>
 * Before the call, none of the expected classes may be loadable from the mock's class loader;
 * after the call, all of them must be.
 */
@Test
public void getSegmentationSourcesMapCompiled() {
    final HasKnowledgeBuilderMock hasKnowledgeBuilderMock = new HasKnowledgeBuilderMock(KNOWLEDGE_BUILDER);
    final CommonCompilationDTO<MiningModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, MINING_MODEL, hasKnowledgeBuilderMock);
    final MiningModelCompilationDTO compilationDTO = MiningModelCompilationDTO.fromCompilationDTO(source);
    final List<KiePMMLModel> nestedModels = new ArrayList<>();
    final List<String> expectedGeneratedClasses = MINING_MODEL.getSegmentation().getSegments().stream().map(this::getExpectedNestedModelClass).collect(Collectors.toList());
    // Precondition: the classes must not exist yet.
    expectedGeneratedClasses.forEach(expectedGeneratedClass -> {
        try {
            hasKnowledgeBuilderMock.getClassLoader().loadClass(expectedGeneratedClass);
            // fail() raises an AssertionError, which is not an Exception and so escapes the catch below.
            fail("Expecting class not found: " + expectedGeneratedClass);
        } catch (Exception e) {
            assertTrue(e instanceof ClassNotFoundException);
        }
    });
    final Map<String, String> retrieved = KiePMMLSegmentationFactory.getSegmentationSourcesMapCompiled(compilationDTO, nestedModels);
    assertNotNull(retrieved);
    int expectedNestedModels = MINING_MODEL.getSegmentation().getSegments().size();
    assertEquals(expectedNestedModels, nestedModels.size());
    // Postcondition: every expected class can now be loaded.
    expectedGeneratedClasses.forEach(expectedGeneratedClass -> {
        try {
            hasKnowledgeBuilderMock.getClassLoader().loadClass(expectedGeneratedClass);
        } catch (Exception e) {
            // Attach the original exception as the cause so its stack trace appears in the failure
            // report; the previous printStackTrace() call sat after fail() and could never run.
            throw new AssertionError("Expecting class to be loaded, but got: " + e.getClass().getName() + " -> " + e.getMessage(), e);
        }
    });
}
Also used : HasKnowledgeBuilderMock(org.kie.pmml.models.mining.compiler.HasKnowledgeBuilderMock) ArrayList(java.util.ArrayList) IOException(java.io.IOException) JAXBException(javax.xml.bind.JAXBException) SAXException(org.xml.sax.SAXException) MiningModel(org.dmg.pmml.mining.MiningModel) KiePMMLModel(org.kie.pmml.commons.model.KiePMMLModel) MiningModelCompilationDTO(org.kie.pmml.models.mining.compiler.dto.MiningModelCompilationDTO) Test(org.junit.Test)

Example 14 with KiePMMLModel

use of org.kie.pmml.commons.model.KiePMMLModel in project drools by kiegroup.

the class PMMLMiningModelEvaluatorTest method getStep.

/**
 * Exercises {@code evaluator.getStep} with a mocked segment, once with an OK result carrying a
 * result variable and once with a FAIL result carrying none, and checks the info map of the
 * returned {@link PMMLMiningModelStep} in both cases.
 */
@Test
public void getStep() {
    // Mocked segment exposing a name and a mocked model with its own name.
    final String expectedModelName = "MODEL_NAME";
    final String expectedSegmentName = "SEGMENT_NAME";
    final KiePMMLModel model = mock(KiePMMLModel.class);
    when(model.getName()).thenReturn(expectedModelName);
    final KiePMMLSegment segment = mock(KiePMMLSegment.class);
    when(segment.getName()).thenReturn(expectedSegmentName);
    when(segment.getModel()).thenReturn(model);

    // OK result: the "RESULT" entry is expected to hold the result variable's value.
    final String variableName = "RESULT_OBJECT_NAME";
    final String variableValue = "RESULT_OBJECT_VALUE";
    final String okCode = OK.getName();
    final PMML4Result okResult = new PMML4Result();
    okResult.setResultCode(okCode);
    okResult.setResultObjectName(variableName);
    okResult.getResultVariables().put(variableName, variableValue);
    final PMMLStep okStep = evaluator.getStep(segment, okResult);
    assertNotNull(okStep);
    assertTrue(okStep instanceof PMMLMiningModelStep);
    final Map<String, Object> okInfo = okStep.getInfo();
    assertNotNull(okInfo);
    assertEquals(expectedSegmentName, okInfo.get("SEGMENT"));
    assertEquals(expectedModelName, okInfo.get("MODEL"));
    assertEquals(okCode, okInfo.get("RESULT CODE"));
    assertEquals(variableValue, okInfo.get("RESULT"));

    // FAIL result without any result object: no "RESULT" entry is expected.
    final String failCode = FAIL.getName();
    final PMML4Result failResult = new PMML4Result();
    failResult.setResultCode(failCode);
    final PMMLStep failStep = evaluator.getStep(segment, failResult);
    assertNotNull(failStep);
    assertTrue(failStep instanceof PMMLMiningModelStep);
    final Map<String, Object> failInfo = failStep.getInfo();
    assertNotNull(failInfo);
    assertEquals(expectedSegmentName, failInfo.get("SEGMENT"));
    assertEquals(expectedModelName, failInfo.get("MODEL"));
    assertEquals(failCode, failInfo.get("RESULT CODE"));
    assertFalse(failInfo.containsKey("RESULT"));
}
Also used : PMML4Result(org.kie.api.pmml.PMML4Result) PMMLStep(org.kie.pmml.api.models.PMMLStep) KiePMMLModel(org.kie.pmml.commons.model.KiePMMLModel) KiePMMLSegment(org.kie.pmml.models.mining.model.segmentation.KiePMMLSegment) ResultCode(org.kie.pmml.api.enums.ResultCode) PMMLContextTest(org.kie.pmml.commons.testingutility.PMMLContextTest) Test(org.junit.Test)

Example 15 with KiePMMLModel

use of org.kie.pmml.commons.model.KiePMMLModel in project drools by kiegroup.

the class PMMLMiningModelEvaluatorTest method validateNoKiePMMLMiningModel.

/**
 * Expects {@code evaluator.validate} to throw {@link KiePMMLModelException} when it is handed a
 * {@link KiePMMLTestingModel} instead of a mining model.
 */
@Test(expected = KiePMMLModelException.class)
public void validateNoKiePMMLMiningModel() {
    final KiePMMLModel notAMiningModel = new KiePMMLTestingModel("NAME", Collections.emptyList());
    evaluator.validate(notAMiningModel);
}
Also used : KiePMMLModel(org.kie.pmml.commons.model.KiePMMLModel) KiePMMLTestingModel(org.kie.pmml.commons.testingutility.KiePMMLTestingModel) PMMLContextTest(org.kie.pmml.commons.testingutility.PMMLContextTest) Test(org.junit.Test)

Aggregations

KiePMMLModel (org.kie.pmml.commons.model.KiePMMLModel)37 Test (org.junit.Test)23 ArrayList (java.util.ArrayList)17 MiningModel (org.dmg.pmml.mining.MiningModel)9 KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException)9 KiePMMLTestingModel (org.kie.pmml.commons.testingutility.KiePMMLTestingModel)8 CommonCompilationDTO (org.kie.pmml.compiler.api.dto.CommonCompilationDTO)7 HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock)7 List (java.util.List)6 Collectors (java.util.stream.Collectors)6 HasKnowledgeBuilderMock (org.kie.pmml.models.mining.compiler.HasKnowledgeBuilderMock)6 MiningModelCompilationDTO (org.kie.pmml.models.mining.compiler.dto.MiningModelCompilationDTO)6 HasSourcesMap (org.kie.pmml.commons.model.HasSourcesMap)5 File (java.io.File)4 IOException (java.io.IOException)4 HashMap (java.util.HashMap)4 InternalKnowledgePackage (org.drools.core.definitions.InternalKnowledgePackage)4 Assert.assertEquals (org.junit.Assert.assertEquals)4 Assert.assertTrue (org.junit.Assert.assertTrue)4 MINING_FUNCTION (org.kie.pmml.api.enums.MINING_FUNCTION)4