Usage of `org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory` in the drools project (kiegroup).
Class `KiePMMLRegressionTableFactoryTest`, method `getRegressionTableBuilders`:
@Test
public void getRegressionTableBuilders() {
    // Build a minimal PMML document: one regression table, CAUCHIT normalization and a
    // categorical target field declared in both the data dictionary and the mining schema.
    regressionTable = getRegressionTable(3.5, "professional");
    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTable);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    String targetField = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetField));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);
    MiningField miningField = new MiningField();
    miningField.setUsageType(MiningField.UsageType.TARGET);
    miningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(miningField);
    regressionModel.setMiningSchema(miningSchema);
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);
    final CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    // Pass the model's real tables: with an empty list the factory returns an empty map and the
    // per-table validation below never executes, so the test would pass without checking anything.
    final RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());
    Map<String, KiePMMLTableSourceCategory> retrieved = KiePMMLRegressionTableFactory.getRegressionTableBuilders(compilationDTO);
    assertNotNull(retrieved);
    // The factory emits one entry per input regression table.
    final int expectedSize = regressionModel.getRegressionTables().size();
    if (retrieved.size() != expectedSize) {
        throw new AssertionError("Expected " + expectedSize + " regression table builders, got " + retrieved.size());
    }
    retrieved.values().forEach(kiePMMLTableSourceCategory -> commonValidateKiePMMLRegressionTable(kiePMMLTableSourceCategory.getSource()));
}
Usage of `org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory` in the drools project (kiegroup).
Class `KiePMMLClassificationTableFactory`, method `getClassificationTableBuilders`:
// Source code generation

/**
 * Generates the sources for a classification regression model.
 * <p>
 * Starts from the per-table entries produced by
 * {@code KiePMMLRegressionTableFactory.getRegressionTableBuilders(compilationDTO)}, then adds the
 * entry returned by {@code getClassificationTableBuilder}, stored with an empty category.
 *
 * @param compilationDTO compilation context holding the regression tables to generate
 * @return insertion-ordered map from entry key to generated source and its category
 */
public static Map<String, KiePMMLTableSourceCategory> getClassificationTableBuilders(final RegressionCompilationDTO compilationDTO) {
    logger.trace("getRegressionTables {}", compilationDTO.getRegressionTables());
    final LinkedHashMap<String, KiePMMLTableSourceCategory> tableBuilders = KiePMMLRegressionTableFactory.getRegressionTableBuilders(compilationDTO);
    // The classification builder receives the per-table entries already generated above.
    final Map.Entry<String, String> classificationEntry = getClassificationTableBuilder(compilationDTO, tableBuilders);
    final String classificationKey = classificationEntry.getKey();
    final String classificationSource = classificationEntry.getValue();
    tableBuilders.put(classificationKey, new KiePMMLTableSourceCategory(classificationSource, ""));
    return tableBuilders;
}
Usage of `org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory` in the drools project (kiegroup).
Class `KiePMMLRegressionTableFactory`, method `getRegressionTableBuilders`:
// Source code generation

/**
 * Generates one source entry per regression table of the given compilation context.
 * <p>
 * Each table is passed to {@code getRegressionTableBuilder}. The returned key maps to a
 * {@link KiePMMLTableSourceCategory} that pairs the generated source with the table's target
 * category. The category is its {@code toString()}, or {@code ""} when the table has no target
 * category. Entries keep the iteration order of {@code compilationDTO.getRegressionTables()}.
 * A later entry with an already-present key replaces the earlier one. The key is presumably the
 * generated class name (TODO confirm in {@code getRegressionTableBuilder}).
 *
 * @param compilationDTO compilation context holding the regression tables to generate
 * @return insertion-ordered map from entry key to generated source and its category
 */
public static LinkedHashMap<String, KiePMMLTableSourceCategory> getRegressionTableBuilders(final RegressionCompilationDTO compilationDTO) {
    logger.trace("getRegressionTables {}", compilationDTO.getRegressionTables());
    final LinkedHashMap<String, KiePMMLTableSourceCategory> builders = new LinkedHashMap<>();
    for (RegressionTable table : compilationDTO.getRegressionTables()) {
        final Map.Entry<String, String> builderEntry = getRegressionTableBuilder(table, compilationDTO);
        final Object rawCategory = table.getTargetCategory();
        final String category;
        if (rawCategory == null) {
            category = "";
        } else {
            category = rawCategory.toString();
        }
        builders.put(builderEntry.getKey(), new KiePMMLTableSourceCategory(builderEntry.getValue(), category));
    }
    return builders;
}
Usage of `org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory` in the drools project (kiegroup).
Class `KiePMMLRegressionModelFactoryTest`, method `getRegressionTablesMap`:
@Test
public void getRegressionTablesMap() {
    final CompilationDTO<RegressionModel> compilationDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    Map<String, KiePMMLTableSourceCategory> retrieved = KiePMMLRegressionModelFactory.getRegressionTablesMap(RegressionCompilationDTO.fromCompilationDTO(compilationDTO));
    // One entry per regression table, plus one extra entry for the classification table.
    int expectedSize = regressionTables.size() + 1;
    assertEquals(expectedSize, retrieved.size());
    final Collection<KiePMMLTableSourceCategory> values = retrieved.values();
    regressionTables.forEach(regressionTable -> {
        // The factory stores the target category as its toString() ("" when absent). Normalize
        // the same way so non-String categories are not compared to a String via equals(Object).
        final Object targetCategory = regressionTable.getTargetCategory();
        final String expectedCategory = targetCategory != null ? targetCategory.toString() : "";
        assertTrue(values.stream().anyMatch(kiePMMLTableSourceCategory -> expectedCategory.equals(kiePMMLTableSourceCategory.getCategory())));
    });
}
Aggregations