Use of org.kie.pmml.commons.model.KiePMMLModel in the drools project by kiegroup.
Class PMMLCompilerService, method getKiePMMLModelsFromResourceWithSources.
/**
 * Compiles the given PMML {@link Resource} into {@link KiePMMLModel}s, generating their sources during compilation.
 * <p>
 * The {@link PMMLCompiler} is obtained through
 * {@code kbuilderImpl.getCachedOrCreate(PMML_COMPILER_CACHE_KEY, PMMLCompilerService::getCompiler)}.
 * The factory class name and package name are derived from the resource.
 * The models returned by the compiler are then passed to {@code populateWithPMMLRuleMappers} before being returned.
 *
 * @param kbuilderImpl the knowledge builder used to obtain the compiler; it is also wrapped in a
 *                     {@link HasKnowledgeBuilderImpl} and handed to the compiler
 * @param resource the PMML resource to compile; its input stream is opened here and closed before this method returns
 * @return the list of models produced by {@link PMMLCompiler#getKiePMMLModelsWithSources}, populated with rule mappers
 * @throws ExternalException wrapping any {@link IOException} raised in the compilation block
 *                           (e.g. while opening or closing the resource stream)
 */
public static List<KiePMMLModel> getKiePMMLModelsFromResourceWithSources(KnowledgeBuilderImpl kbuilderImpl, Resource resource) {
    PMMLCompiler pmmlCompiler = kbuilderImpl.getCachedOrCreate(PMML_COMPILER_CACHE_KEY, PMMLCompilerService::getCompiler);
    String[] classNamePackageName = getFactoryClassNamePackageName(resource);
    String factoryClassName = classNamePackageName[0];
    String packageName = classNamePackageName[1];
    // try-with-resources: the stream returned by Resource#getInputStream() must be released once compilation is done,
    // otherwise the underlying resource (e.g. a file handle) stays open
    try (java.io.InputStream inputStream = resource.getInputStream()) {
        final List<KiePMMLModel> toReturn = pmmlCompiler.getKiePMMLModelsWithSources(factoryClassName, packageName, inputStream, getFileName(resource.getSourcePath()), new HasKnowledgeBuilderImpl(kbuilderImpl));
        populateWithPMMLRuleMappers(toReturn, resource);
        return toReturn;
    } catch (IOException e) {
        throw new ExternalException("ExternalException", e);
    }
}
Use of org.kie.pmml.commons.model.KiePMMLModel in the drools project by kiegroup.
Class PostProcess, method updateTargetValueType.
/**
 * Converts the prediction stored for the model's target field to the {@link DATA_TYPE} declared for that field
 * (as defined inside <code>DataDictionary/MiningSchema</code>).
 * <p>
 * The data type is looked up first, among {@link KiePMMLModel#getMiningFields()}, by matching each mining field's
 * name against {@link KiePMMLModel#getTargetField()}. A missing mining field therefore fails even when there is no
 * prediction. A {@code null} prediction is left untouched. Any other prediction is replaced in
 * {@code toUpdate}'s result variables by {@code dataType.getActualValue(prediction)}.
 *
 * @param model the model providing the target field name and its mining fields
 * @param toUpdate the result whose target-field entry is converted in place
 * @throws KiePMMLException if no mining field is named after the target field
 */
static void updateTargetValueType(final KiePMMLModel model, final PMML4Result toUpdate) {
    final String targetFieldName = model.getTargetField();
    final DATA_TYPE targetDataType = model.getMiningFields()
            .stream()
            .filter(miningField -> targetFieldName.equals(miningField.getName()))
            .map(MiningField::getDataType)
            .findFirst()
            .orElseThrow(() -> new KiePMMLException("Failed to find DATA_TYPE for " + targetFieldName));
    final Object rawPrediction = toUpdate.getResultVariables().get(targetFieldName);
    if (rawPrediction == null) {
        // Nothing to convert
        return;
    }
    toUpdate.getResultVariables().put(targetFieldName, targetDataType.getActualValue(rawPrediction));
}
Use of org.kie.pmml.commons.model.KiePMMLModel in the drools project by kiegroup.
Class PMMLRuntimeInternalImplTest, method getKiePMMLModelMock.
/**
 * Builds a mocked {@link KiePMMLModel} named {@code MODEL_NAME}, of type {@code PMML_MODEL.TEST_MODEL}.
 * Its target field is {@code "targetFieldName"}, and it has a single mining field of type {@code DATA_TYPE.FLOAT}
 * with that same name.
 *
 * @return the configured model mock
 */
private KiePMMLModel getKiePMMLModelMock() {
    // The mining field and the model's target field share this name, so the target resolves to that mining field
    final String targetName = "targetFieldName";

    final MiningField targetMiningField = mock(MiningField.class);
    when(targetMiningField.getName()).thenReturn(targetName);
    when(targetMiningField.getDataType()).thenReturn(DATA_TYPE.FLOAT);

    final KiePMMLModel modelMock = mock(KiePMMLModel.class);
    when(modelMock.getName()).thenReturn(MODEL_NAME);
    when(modelMock.getPmmlMODEL()).thenReturn(PMML_MODEL.TEST_MODEL);
    when(modelMock.getTargetField()).thenReturn(targetName);
    when(modelMock.getMiningFields()).thenReturn(Collections.singletonList(targetMiningField));
    return modelMock;
}
Use of org.kie.pmml.commons.model.KiePMMLModel in the drools project by kiegroup.
Class KiePMMLSegmentFactoryTest, method getSegmentSourcesMapCompiled.
/**
 * Verifies that {@code KiePMMLSegmentFactory.getSegmentSourcesMapCompiled} produces the expected sources and nested
 * models for the first segment of {@code MINING_MODEL}. It also checks that the nested model class cannot be loaded
 * before the call and can be loaded afterwards.
 */
@Test
public void getSegmentSourcesMapCompiled() throws Exception {
    final Segment segment = MINING_MODEL.getSegmentation().getSegments().get(0);
    final List<KiePMMLModel> nestedModels = new ArrayList<>();
    final String expectedNestedModelGeneratedClass = getExpectedNestedModelClass(segment);
    final HasKnowledgeBuilderMock hasKnowledgeBuilderMock = new HasKnowledgeBuilderMock(KNOWLEDGE_BUILDER);
    // Precondition: the nested model class must not be loadable yet.
    // Only ClassNotFoundException is expected here; any other exception propagates with its own stack trace
    // instead of being reduced to a bare assertion failure.
    try {
        hasKnowledgeBuilderMock.getClassLoader().loadClass(expectedNestedModelGeneratedClass);
        fail("Expecting class not found: " + expectedNestedModelGeneratedClass);
    } catch (ClassNotFoundException expected) {
        // expected: the class should only become available after compilation below
    }
    final CommonCompilationDTO<MiningModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, MINING_MODEL, hasKnowledgeBuilderMock);
    final MiningModelCompilationDTO compilationDTO = MiningModelCompilationDTO.fromCompilationDTO(source);
    final SegmentCompilationDTO segmentCompilationDTO = SegmentCompilationDTO.fromGeneratedPackageNameAndFields(compilationDTO, segment, compilationDTO.getFields());
    final Map<String, String> retrieved = KiePMMLSegmentFactory.getSegmentSourcesMapCompiled(segmentCompilationDTO, nestedModels);
    commonEvaluateNestedModels(nestedModels);
    commonEvaluateMap(retrieved, segment);
    // Postcondition: loadClass throws ClassNotFoundException (failing the test) if the class was not made available
    hasKnowledgeBuilderMock.getClassLoader().loadClass(expectedNestedModelGeneratedClass);
}
Use of org.kie.pmml.commons.model.KiePMMLModel in the drools project by kiegroup.
Class KiePMMLSegmentFactoryTest, method getSegmentSourcesMap.
/**
 * Verifies that {@code KiePMMLSegmentFactory.getSegmentSourcesMap} produces the expected sources map and nested
 * models for the first segment of {@code MINING_MODEL}.
 */
@Test
public void getSegmentSourcesMap() {
    final Segment firstSegment = MINING_MODEL.getSegmentation().getSegments().get(0);
    final HasKnowledgeBuilderMock knowledgeBuilderMock = new HasKnowledgeBuilderMock(KNOWLEDGE_BUILDER);

    final CommonCompilationDTO<MiningModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, MINING_MODEL, knowledgeBuilderMock);
    final MiningModelCompilationDTO miningModelDTO = MiningModelCompilationDTO.fromCompilationDTO(commonDTO);
    final SegmentCompilationDTO segmentDTO =
            SegmentCompilationDTO.fromGeneratedPackageNameAndFields(miningModelDTO, firstSegment, miningModelDTO.getFields());

    // Starts empty and is handed to the factory; its contents are checked by commonEvaluateNestedModels
    final List<KiePMMLModel> nestedModels = new ArrayList<>();
    final Map<String, String> sourcesMap = KiePMMLSegmentFactory.getSegmentSourcesMap(segmentDTO, nestedModels);

    commonEvaluateNestedModels(nestedModels);
    commonEvaluateMap(sourcesMap, firstSegment);
}
Aggregations