Use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in the drools project by kiegroup.
Class KiePMMLRegressionTableFactoryTest, method getRegressionTable.
/**
 * Builds a minimal PMML document (one categorical target field, one regression table,
 * CAUCHIT normalization) and checks that
 * {@code KiePMMLRegressionTableFactory.getRegressionTable} yields a non-null table,
 * which is then handed to {@code commonEvaluateRegressionTable} for further checks.
 */
@Test
public void getRegressionTable() {
    regressionTable = getRegressionTable(3.5, "professional");

    // Model under test: a single regression table with CAUCHIT normalization.
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical target field registered in the data dictionary.
    final String targetName = "targetField";
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create(targetName));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // Mining schema marks the same field as the model's TARGET.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML document = new PMML();
    document.setDataDictionary(dictionary);
    document.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   document,
                                                                   model,
                                                                   new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO,
                                                                                              new ArrayList<>(),
                                                                                              model.getNormalizationMethod());

    final KiePMMLRegressionTable created =
            KiePMMLRegressionTableFactory.getRegressionTable(regressionTable, regressionDTO);
    assertNotNull(created);
    commonEvaluateRegressionTable(created, regressionTable);
}
Use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in the drools project by kiegroup.
Class KiePMMLRegressionTableFactoryTest, method getRegressionTableBuilder.
/**
 * Builds a minimal PMML document (one categorical target field, one regression table,
 * CAUCHIT normalization), asks
 * {@code KiePMMLRegressionTableFactory.getRegressionTableBuilder} for the generated
 * class-name/source pair and passes it to {@code commonValidateCompilation}.
 */
@Test
public void getRegressionTableBuilder() {
    regressionTable = getRegressionTable(3.5, "professional");

    // Model under test: a single regression table with CAUCHIT normalization.
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical target field registered in the data dictionary.
    final String targetName = "targetField";
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create(targetName));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // Mining schema marks the same field as the model's TARGET.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML document = new PMML();
    document.setDataDictionary(dictionary);
    document.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   document,
                                                                   model,
                                                                   new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO,
                                                                                              new ArrayList<>(),
                                                                                              model.getNormalizationMethod());

    final Map.Entry<String, String> generated =
            KiePMMLRegressionTableFactory.getRegressionTableBuilder(regressionTable, regressionDTO);
    assertNotNull(generated);

    // The compilation helper takes a map of sources, so wrap the single entry.
    final Map<String, String> sourcesByClassName = new HashMap<>();
    sourcesByClassName.put(generated.getKey(), generated.getValue());
    commonValidateCompilation(sourcesByClassName);
}
Use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in the drools project by kiegroup.
Class KiePMMLRegressionModelFactory, method getKiePMMLRegressionModelSourcesMap.
// Source code generation
/**
 * Generates the sources for a regression model and for the tables it relies on.
 * <p>
 * The returned map contains one entry per table produced by
 * {@code getRegressionTablesMap} (key taken from that map, value from
 * {@code KiePMMLTableSourceCategory.getSource()}) plus one entry for the model class
 * itself, keyed by {@code getFullClassName(cloneCU)}.
 * <p>
 * The table wired into the model through {@code setStaticGetter} is the only table when
 * exactly one was generated; otherwise it is the first one whose name starts with
 * {@code <packageName>.KiePMMLClassificationTable}.
 *
 * @param compilationDTO compilation context providing the class name, package name,
 *                       fields and model
 * @return a mutable map of generated class name to source code
 * @throws IOException declared by the original signature; which callee raises it is not
 *                     visible from this method
 * @throws KiePMMLException if the model template class cannot be found, or if several
 *                          tables were generated and none is a
 *                          {@code KiePMMLClassificationTable}
 */
public static Map<String, String> getKiePMMLRegressionModelSourcesMap(final RegressionCompilationDTO compilationDTO) throws IOException {
    logger.trace("getKiePMMLRegressionModelSourcesMap {} {} {}", compilationDTO.getFields(), compilationDTO.getModel(), compilationDTO.getPackageName());
    String className = compilationDTO.getSimpleClassName();
    CompilationUnit cloneCU = JavaParserUtils.getKiePMMLModelCompilationUnit(className, compilationDTO.getPackageName(), KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA, KIE_PMML_REGRESSION_MODEL_TEMPLATE);
    ClassOrInterfaceDeclaration modelTemplate = cloneCU.getClassByName(className)
            .orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + className));
    Map<String, KiePMMLTableSourceCategory> tablesSourceMap = getRegressionTablesMap(compilationDTO);
    // With a single table there is nothing to choose; otherwise pick the classification table.
    String nestedTable = tablesSourceMap.size() == 1
            ? tablesSourceMap.keySet().iterator().next()
            : tablesSourceMap.keySet().stream()
                    .filter(tableName -> tableName.startsWith(compilationDTO.getPackageName() + ".KiePMMLClassificationTable"))
                    .findFirst()
                    .orElseThrow(() -> new KiePMMLException("Failed to find expected " + "KiePMMLClassificationTable"));
    setStaticGetter(compilationDTO, modelTemplate, nestedTable);
    Map<String, String> tableSources = tablesSourceMap.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSource()));
    // Collectors.toMap makes no promise that its result is mutable, so copy it into a
    // HashMap before adding the model's own source.
    Map<String, String> toReturn = new HashMap<>(tableSources);
    toReturn.put(getFullClassName(cloneCU), cloneCU.toString());
    return toReturn;
}
Use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in the drools project by kiegroup.
Class KiePMMLRegressionModelFactory, method getRegressionTables.
// not-public KiePMMLRegressionModel instantiation
/**
 * Creates the table(s) backing a regression model.
 * <ul>
 *   <li>When {@code compilationDTO.isRegression()} is true, only the first
 *   {@code RegressionTable} of the model is used, together with the model's own
 *   normalization method; every table returned by
 *   {@code KiePMMLRegressionTableFactory.getRegressionTables} is added to the result.</li>
 *   <li>Otherwise all the model's regression tables are passed, with
 *   {@code NormalizationMethod.NONE}, to
 *   {@code KiePMMLClassificationTableFactory.getClassificationTable}; the single
 *   resulting table is stored under its own name.</li>
 * </ul>
 *
 * @param compilationDTO compilation context holding the regression model
 * @return a new {@code HashMap} with the created tables
 */
static Map<String, AbstractKiePMMLTable> getRegressionTables(final RegressionCompilationDTO compilationDTO) {
    final Map<String, AbstractKiePMMLTable> tablesByName = new HashMap<>();

    if (!compilationDTO.isRegression()) {
        // Classification: one aggregate table built from every regression table.
        final List<RegressionTable> allTables = compilationDTO.getModel().getRegressionTables();
        final RegressionCompilationDTO classificationDTO =
                RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(compilationDTO,
                                                                                                  allTables,
                                                                                                  RegressionModel.NormalizationMethod.NONE);
        final KiePMMLClassificationTable classificationTable =
                KiePMMLClassificationTableFactory.getClassificationTable(classificationDTO);
        tablesByName.put(classificationTable.getName(), classificationTable);
        return tablesByName;
    }

    // Regression: only the first table of the model is taken into account.
    final RegressionTable firstTable = compilationDTO.getModel().getRegressionTables().get(0);
    final List<RegressionTable> onlyFirstTable = Collections.singletonList(firstTable);
    final RegressionCompilationDTO singleTableDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(compilationDTO,
                                                                                              onlyFirstTable,
                                                                                              compilationDTO.getModel().getNormalizationMethod());
    tablesByName.putAll(KiePMMLRegressionTableFactory.getRegressionTables(singleTableDTO));
    return tablesByName;
}
Use of org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO in the drools project by kiegroup.
Class KiePMMLClassificationTableFactoryTest, method getClassificationTableBuilder.
/**
 * Builds a PMML document with two regression tables ("professional" and "clerical"),
 * three output fields and a categorical target, then checks that
 * {@code KiePMMLClassificationTableFactory.getClassificationTableBuilder} returns a
 * non-null entry when given one {@code KiePMMLTableSourceCategory} per regression table.
 */
@Test
public void getClassificationTableBuilder() {
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");
    final OutputField categoricalOutput = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    final OutputField numericOutput = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    final OutputField predictedOutput = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Categorical target field registered in the data dictionary.
    final String targetName = "targetField";
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create(targetName));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // Model under test: two regression tables, CAUCHIT normalization, three outputs.
    final RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(professionalTable, clericalTable);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    final Output modelOutput = new Output();
    modelOutput.addOutputFields(categoricalOutput, numericOutput, predictedOutput);
    regressionModel.setOutput(modelOutput);

    // Mining schema marks the same field as the model's TARGET.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    regressionModel.setMiningSchema(schema);

    final PMML document = new PMML();
    document.setDataDictionary(dictionary);
    document.addModels(regressionModel);

    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   document,
                                                                   regressionModel,
                                                                   new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO,
                                                                                              regressionModel.getRegressionTables(),
                                                                                              regressionModel.getNormalizationMethod());

    // One (empty-source) entry per regression table, keyed by package + upper-cased category.
    final LinkedHashMap<String, KiePMMLTableSourceCategory> tablesByClassName = new LinkedHashMap<>();
    for (RegressionTable table : regressionModel.getRegressionTables()) {
        final String category = table.getTargetCategory().toString();
        final String className = compilationDTO.getPackageName() + "." + category.toUpperCase();
        tablesByClassName.put(className, new KiePMMLTableSourceCategory("", category));
    }

    final Map.Entry<String, String> generated =
            KiePMMLClassificationTableFactory.getClassificationTableBuilder(compilationDTO, tablesByClassName);
    assertNotNull(generated);
}
Aggregations