Usage of `org.kie.pmml.models.regression.model.KiePMMLRegressionTable` in the drools project (kiegroup).
From class `KiePMMLRegressionTableFactoryTest`, method `getRegressionTable`.
/**
 * Builds a minimal PMML document around a single "professional" regression table and checks that
 * {@link KiePMMLRegressionTableFactory#getRegressionTable} returns a non-null table, which is then
 * verified via {@code commonEvaluateRegressionTable}.
 */
@Test
public void getRegressionTable() {
    // Model under test: one regression table, CAUCHIT normalization.
    regressionTable = getRegressionTable(3.5, "professional");
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical target field, declared in the data dictionary...
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // ...and referenced as TARGET by the model's mining schema.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, model, new HasClassLoaderMock());
    // The DTO is built with an empty table list; the table under test is handed to the factory directly below.
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO, new ArrayList<>(), model.getNormalizationMethod());

    final KiePMMLRegressionTable result = KiePMMLRegressionTableFactory.getRegressionTable(regressionTable, regressionDTO);
    assertNotNull(result);
    commonEvaluateRegressionTable(result, regressionTable);
}
Usage of `org.kie.pmml.models.regression.model.KiePMMLRegressionTable` in the drools project (kiegroup).
From class `KiePMMLRegressionTableFactory`, method `getRegressionTables`.
// KiePMMLRegressionTable instantiation
/**
 * Converts every PMML {@code RegressionTable} held by the given DTO into a {@link KiePMMLRegressionTable}.
 * <p>
 * Each converted table is keyed by the {@code toString()} of its source table's target category, or by the
 * empty string when the source table has no target category. Tables are visited in the order the DTO
 * returns them; if two tables map to the same key, the later value replaces the earlier one.
 *
 * @param compilationDTO provides the regression tables to convert and the compilation context
 * @return insertion-ordered map from target category to converted table
 */
public static LinkedHashMap<String, KiePMMLRegressionTable> getRegressionTables(final RegressionCompilationDTO compilationDTO) {
    logger.trace("getRegressionTables {}", compilationDTO.getRegressionTables());
    final LinkedHashMap<String, KiePMMLRegressionTable> tablesByCategory = new LinkedHashMap<>();
    for (final RegressionTable sourceTable : compilationDTO.getRegressionTables()) {
        final KiePMMLRegressionTable converted = getRegressionTable(sourceTable, compilationDTO);
        final Object category = sourceTable.getTargetCategory();
        final String key;
        if (category == null) {
            // A table without a target category is stored under the empty-string key.
            key = "";
        } else {
            key = category.toString();
        }
        tablesByCategory.put(key, converted);
    }
    return tablesByCategory;
}
Usage of `org.kie.pmml.models.regression.model.KiePMMLRegressionTable` in the drools project (kiegroup).
From class `KiePMMLRegressionModelFactoryTest`, method `evaluateCategoricalRegressionTable`.
/**
 * Asserts that a classification table built from the test's categorical regression model matches the
 * fixture. It checks that:
 * <ul>
 *   <li>the normalization method matches the fixture;</li>
 *   <li>the op type is CATEGORICAL;</li>
 *   <li>the category map has exactly one entry per original regression table, keyed by target category;</li>
 *   <li>each entry passes {@code evaluateRegressionTable}.</li>
 * </ul>
 *
 * @param regressionTable the classification table under test
 */
private void evaluateCategoricalRegressionTable(KiePMMLClassificationTable regressionTable) {
    assertEquals(REGRESSION_NORMALIZATION_METHOD.byName(regressionModel.getNormalizationMethod().value()), regressionTable.getRegressionNormalizationMethod());
    assertEquals(OP_TYPE.CATEGORICAL, regressionTable.getOpType());
    final Map<String, KiePMMLRegressionTable> categoryTableMap = regressionTable.getCategoryTableMap();
    // containsKey alone would not detect unexpected extra categories in the map.
    assertEquals(regressionTables.size(), categoryTableMap.size());
    for (RegressionTable originalRegressionTable : regressionTables) {
        final String targetCategory = originalRegressionTable.getTargetCategory().toString();
        assertTrue(categoryTableMap.containsKey(targetCategory));
        evaluateRegressionTable(categoryTableMap.get(targetCategory), originalRegressionTable);
    }
}
Usage of `org.kie.pmml.models.regression.model.KiePMMLRegressionTable` in the drools project (kiegroup).
From class `KiePMMLRegressionTableFactoryTest`, method `getRegressionTables`.
/**
 * Builds a minimal PMML document with two regression tables ("professional" and "hobby") and checks that
 * {@link KiePMMLRegressionTableFactory#getRegressionTables} returns one entry per table. Each entry must be
 * keyed by the table's target category and pass {@code commonEvaluateRegressionTable}.
 */
@Test
public void getRegressionTables() {
    // Model under test: two regression tables, CAUCHIT normalization.
    regressionTable = getRegressionTable(3.5, "professional");
    final RegressionTable hobbyTable = getRegressionTable(3.9, "hobby");
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable, hobbyTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical target field, declared in the data dictionary...
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // ...and referenced as TARGET by the model's mining schema.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, model, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO, model.getRegressionTables(), model.getNormalizationMethod());

    final Map<String, KiePMMLRegressionTable> result = KiePMMLRegressionTableFactory.getRegressionTables(regressionDTO);
    assertNotNull(result);
    assertEquals(model.getRegressionTables().size(), result.size());
    for (RegressionTable expected : model.getRegressionTables()) {
        final String category = expected.getTargetCategory().toString();
        assertTrue(result.containsKey(category));
        commonEvaluateRegressionTable(result.get(category), expected);
    }
}
Aggregations