Usage of org.kie.pmml.models.regression.model.AbstractKiePMMLTable in the drools project (kiegroup).
Class KiePMMLRegressionModelFactory, method getRegressionTables:
/**
 * Package-private helper for {@code KiePMMLRegressionModel} instantiation: builds the nested
 * tables of the model, keyed by table name.
 * <p>
 * When the compilation DTO describes a regression, only the first {@link RegressionTable} of the
 * model is translated (via {@code KiePMMLRegressionTableFactory}), using the model's own
 * normalization method. Otherwise all regression tables are handed to
 * {@code KiePMMLClassificationTableFactory}, with {@link RegressionModel.NormalizationMethod#NONE},
 * to build a single {@link KiePMMLClassificationTable}.
 * <p>
 * NOTE(review): the regression branch assumes the model declares at least one RegressionTable;
 * {@code get(0)} fails on an empty list.
 *
 * @param compilationDTO the regression compilation data
 * @return a mutable {@link HashMap} from table name to table
 */
static Map<String, AbstractKiePMMLTable> getRegressionTables(final RegressionCompilationDTO compilationDTO) {
    final Map<String, AbstractKiePMMLTable> tablesByName = new HashMap<>();
    if (!compilationDTO.isRegression()) {
        // Classification: every regression table contributes to one classification table, and
        // normalization is forced to NONE at this level.
        final List<RegressionTable> allTables = compilationDTO.getModel().getRegressionTables();
        final RegressionCompilationDTO classificationDTO =
                RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(compilationDTO,
                                                                                                   allTables,
                                                                                                   RegressionModel.NormalizationMethod.NONE);
        final KiePMMLClassificationTable classificationTable =
                KiePMMLClassificationTableFactory.getClassificationTable(classificationDTO);
        tablesByName.put(classificationTable.getName(), classificationTable);
        return tablesByName;
    }
    // Regression: only the first table is used, with the model-level normalization method.
    final List<RegressionTable> firstTableOnly =
            Collections.singletonList(compilationDTO.getModel().getRegressionTables().get(0));
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(compilationDTO,
                                                                                               firstTableOnly,
                                                                                               compilationDTO.getModel().getNormalizationMethod());
    tablesByName.putAll(KiePMMLRegressionTableFactory.getRegressionTables(regressionDTO));
    return tablesByName;
}
Usage of org.kie.pmml.models.regression.model.AbstractKiePMMLTable in the drools project (kiegroup).
Class KiePMMLRegressionModelFactory, method getKiePMMLRegressionModelClasses:
/**
 * Instantiates a {@link KiePMMLRegressionModel} from the given compilation data.
 * <p>
 * The nested table is resolved from {@link #getRegressionTables(RegressionCompilationDTO)}. When
 * exactly one table is produced it is used directly. Otherwise the first
 * {@link KiePMMLClassificationTable} among the produced tables is used.
 *
 * @param compilationDTO the regression compilation data
 * @return the built model
 * @throws KiePMMLException if no suitable nested table is found, or if building the model fails
 *         (other exceptions raised while building are wrapped as the cause)
 * @throws IOException declared for source compatibility only; the body wraps checked exceptions
 *         into {@link KiePMMLException}
 * @throws IllegalAccessException declared for source compatibility only (see above)
 * @throws InstantiationException declared for source compatibility only (see above)
 */
public static KiePMMLRegressionModel getKiePMMLRegressionModelClasses(final RegressionCompilationDTO compilationDTO) throws IOException, IllegalAccessException, InstantiationException {
    logger.trace("getKiePMMLRegressionModelClasses {} {}", compilationDTO.getFields(), compilationDTO.getModel());
    final Map<String, AbstractKiePMMLTable> regressionTablesMap = getRegressionTables(compilationDTO);
    try {
        // A single table is used as-is; with several, the classification table is the nested one.
        final AbstractKiePMMLTable nestedTable = regressionTablesMap.size() == 1 ?
                regressionTablesMap.values().iterator().next() :
                regressionTablesMap.values().stream()
                        .filter(KiePMMLClassificationTable.class::isInstance)
                        .findFirst()
                        .orElseThrow(() -> new KiePMMLException("Failed to find expected " + KiePMMLClassificationTable.class.getSimpleName()));
        return KiePMMLRegressionModel.builder(compilationDTO.getModelName(), compilationDTO.getMINING_FUNCTION())
                .withAbstractKiePMMLTable(nestedTable)
                .withTargetField(compilationDTO.getTargetFieldName())
                .withMiningFields(compilationDTO.getKieMiningFields())
                .withOutputFields(compilationDTO.getKieOutputFields())
                .withKiePMMLMiningFields(compilationDTO.getKiePMMLMiningFields())
                .withKiePMMLOutputFields(compilationDTO.getKiePMMLOutputFields())
                .withKiePMMLTargets(compilationDTO.getKiePMMLTargetFields())
                .withKiePMMLTransformationDictionary(compilationDTO.getKiePMMLTransformationDictionary())
                .withKiePMMLLocalTransformations(compilationDTO.getKiePMMLLocalTransformations())
                .build();
    } catch (KiePMMLException e) {
        // Already the domain exception (e.g. the missing-classification-table case above):
        // rethrow unchanged so its message is not buried as the cause of another KiePMMLException.
        throw e;
    } catch (Exception e) {
        throw new KiePMMLException(e);
    }
}
Usage of org.kie.pmml.models.regression.model.AbstractKiePMMLTable in the drools project (kiegroup).
Class KiePMMLRegressionModelFactoryTest, method getKiePMMLRegressionModelClasses:
@Test
public void getKiePMMLRegressionModelClasses() throws IOException, IllegalAccessException, InstantiationException {
    // Arrange: wrap the shared fixture model in a compilation DTO.
    final CompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   regressionModel,
                                                                   new HasClassLoaderMock());
    // Act: build the KiePMMLRegressionModel from its regression-specific DTO.
    final KiePMMLRegressionModel model =
            KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses(RegressionCompilationDTO.fromCompilationDTO(commonDTO));
    // Assert: the model's top-level attributes mirror the source PMML model...
    assertNotNull(model);
    assertEquals(regressionModel.getModelName(), model.getName());
    assertEquals(MINING_FUNCTION.byName(regressionModel.getMiningFunction().value()), model.getMiningFunction());
    assertEquals(miningFields.get(0).getName().getValue(), model.getTargetField());
    // ...and the nested table is the classification table built from the fixture.
    final AbstractKiePMMLTable nestedTable = model.getRegressionTable();
    assertNotNull(nestedTable);
    assertTrue(nestedTable instanceof KiePMMLClassificationTable);
    evaluateCategoricalRegressionTable((KiePMMLClassificationTable) nestedTable);
}
Aggregations