use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
the class KiePMMLMiningModelFactoryTest method setConstructor.
/**
 * Invokes {@code KiePMMLMiningModelFactory.setConstructor} on a clone of {@code MODEL_TEMPLATE} and
 * checks, through {@code commonEvaluateConstructor}, that the template's default constructor carries
 * the expected super-invocation argument and the expected field assignments
 * ({@code targetField}, {@code miningFunction}, {@code pmmlMODEL}, {@code segmentation}).
 */
@Test
public void setConstructor() {
    // Values derived from the model under test, used later to build the expectations.
    final PMML_MODEL expectedPmmlModel = PMML_MODEL.byName(MINING_MODEL.getClass().getSimpleName());
    final ClassOrInterfaceDeclaration template = MODEL_TEMPLATE.clone();
    final MINING_FUNCTION expectedMiningFunction = MINING_FUNCTION.byName(MINING_MODEL.getMiningFunction().value());

    // Compilation DTOs handed to the factory.
    final CommonCompilationDTO<MiningModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, MINING_MODEL, new HasClassLoaderMock());
    final MiningModelCompilationDTO miningDTO = MiningModelCompilationDTO.fromCompilationDTO(commonDTO);

    // Method under test: mutates the cloned template in place.
    KiePMMLMiningModelFactory.setConstructor(miningDTO, template);

    // Expected first argument of the super(...) call: the model name wrapped in double quotes.
    final String modelName = MINING_MODEL.getModelName();
    final Map<Integer, Expression> expectedSuperArguments = new HashMap<>();
    expectedSuperArguments.put(0, new NameExpr(String.format("\"%s\"", modelName)));

    // Expected "new <segmentation class>()" expression assigned to the segmentation field.
    final ClassOrInterfaceType segmentationType =
            parseClassOrInterfaceType(miningDTO.getSegmentationCanonicalClassName());
    final ObjectCreationExpr expectedSegmentation = new ObjectCreationExpr();
    expectedSegmentation.setType(segmentationType);

    // Expected assignments, keyed by the assigned field name.
    final String miningFunctionReference = expectedMiningFunction.getClass().getName() + "." + expectedMiningFunction.name();
    final String pmmlModelReference = expectedPmmlModel.getClass().getName() + "." + expectedPmmlModel.name();
    final Map<String, Expression> expectedAssignments = new HashMap<>();
    expectedAssignments.put("targetField", new StringLiteralExpr(targetFieldName));
    expectedAssignments.put("miningFunction", new NameExpr(miningFunctionReference));
    expectedAssignments.put("pmmlMODEL", new NameExpr(pmmlModelReference));
    expectedAssignments.put("segmentation", expectedSegmentation);

    final ConstructorDeclaration defaultConstructor = template.getDefaultConstructor().get();
    final boolean constructorMatches = commonEvaluateConstructor(defaultConstructor,
                                                                 getSanitizedClassName(modelName),
                                                                 expectedSuperArguments,
                                                                 expectedAssignments);
    assertTrue(constructorMatches);
}
use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
the class KiePMMLRegressionModelFactoryTest method setStaticGetter.
/**
 * Invokes {@code KiePMMLRegressionModelFactory.setStaticGetter} on a clone of {@code MODEL_TEMPLATE}
 * and checks that the first method named {@code GET_MODEL} in the template is node-equal
 * (per {@code JavaParserUtils.equalsNode}) to the method parsed from the content returned by
 * {@code getFileContent(TEST_01_SOURCE)}.
 *
 * @throws IOException declared because {@code getFileContent} is invoked to load the expected source
 */
@Test
public void setStaticGetter() throws IOException {
    final String nestedTable = "NestedTable";
    final ClassOrInterfaceDeclaration modelTemplate = MODEL_TEMPLATE.clone();
    final CommonCompilationDTO<RegressionModel> source =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    // No regression tables are supplied: an empty list is passed to the DTO factory.
    final RegressionCompilationDTO compilationDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source,
                                                                                              new ArrayList<>(),
                                                                                              regressionModel.getNormalizationMethod());

    // Method under test: mutates the cloned template in place.
    KiePMMLRegressionModelFactory.setStaticGetter(compilationDTO, modelTemplate, nestedTable);

    // The whole generated method is compared against the stored expectation; the expression maps
    // that used to be built here were never consulted by any assertion and have been dropped.
    final MethodDeclaration retrieved = modelTemplate.getMethodsByName(GET_MODEL).get(0);
    final MethodDeclaration expected = JavaParserUtils.parseMethod(getFileContent(TEST_01_SOURCE));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));
}
use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
the class KiePMMLClassificationTableFactoryTest method getClassificationTableBuilders.
/**
 * Assembles a {@code PMML} instance holding one {@code RegressionModel} (CAUCHIT normalization, two
 * regression tables, three output fields, one target mining field) and checks that
 * {@code KiePMMLClassificationTableFactory.getClassificationTableBuilders} returns a non-null map of
 * three entries whose sources pass {@code commonValidateKiePMMLRegressionTable} and, collected by
 * key, {@code commonValidateCompilation}.
 */
@Test
public void getClassificationTableBuilders() {
    // Building blocks created through the test helpers.
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");
    final OutputField categoricalOutput = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    final OutputField numericOutput = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    final OutputField predictedOutput = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Data dictionary with a single categorical field.
    final String targetFieldName = "targetField";
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create(targetFieldName));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // Regression model wired with tables, outputs and mining schema.
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(professionalTable, clericalTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));
    final Output modelOutput = new Output();
    modelOutput.addOutputFields(categoricalOutput, numericOutput, predictedOutput);
    model.setOutput(modelOutput);
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML pmmlDocument = new PMML();
    pmmlDocument.setDataDictionary(dictionary);
    pmmlDocument.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmmlDocument, model, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO,
                                                                                              model.getRegressionTables(),
                                                                                              model.getNormalizationMethod());

    // Method under test.
    final Map<String, KiePMMLTableSourceCategory> builders =
            KiePMMLClassificationTableFactory.getClassificationTableBuilders(regressionDTO);

    assertNotNull(builders);
    assertEquals(3, builders.size());
    for (KiePMMLTableSourceCategory tableSource : builders.values()) {
        commonValidateKiePMMLRegressionTable(tableSource.getSource());
    }
    // Re-key the generated sources by their map key so they can be compiled together.
    final Map<String, String> sourcesByKey = builders.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSource()));
    commonValidateCompilation(sourcesByKey);
}
use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
the class KiePMMLClassificationTableFactoryTest method getClassificationTable.
/**
 * Assembles a {@code PMML} instance holding one {@code RegressionModel} (CAUCHIT normalization, two
 * regression tables, three output fields, one target mining field) and checks the
 * {@code KiePMMLClassificationTable} returned by
 * {@code KiePMMLClassificationTableFactory.getClassificationTable}: category-table map size and keys,
 * normalization method name, op type, binary flag and target field.
 */
@Test
public void getClassificationTable() {
    RegressionTable regressionTableProf = getRegressionTable(3.5, "professional");
    RegressionTable regressionTableCler = getRegressionTable(27.4, "clerical");
    OutputField outputFieldCat = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    OutputField outputFieldNum = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    OutputField outputFieldPrev = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Data dictionary with a single categorical field.
    String targetField = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetField));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);

    // Regression model wired with tables, outputs and mining schema.
    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTableProf, regressionTableCler);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    Output output = new Output();
    output.addOutputFields(outputFieldCat, outputFieldNum, outputFieldPrev);
    regressionModel.setOutput(output);
    MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(targetMiningField);
    regressionModel.setMiningSchema(miningSchema);

    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);

    final CommonCompilationDTO<RegressionModel> source =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source,
                                                                                              regressionModel.getRegressionTables(),
                                                                                              regressionModel.getNormalizationMethod());

    // Method under test.
    KiePMMLClassificationTable retrieved = KiePMMLClassificationTableFactory.getClassificationTable(compilationDTO);

    assertNotNull(retrieved);
    // One category table per regression table, keyed by the table's target category.
    assertEquals(regressionModel.getRegressionTables().size(), retrieved.getCategoryTableMap().size());
    regressionModel.getRegressionTables().forEach(regressionTable ->
            assertTrue(retrieved.getCategoryTableMap().containsKey(regressionTable.getTargetCategory().toString())));
    assertEquals(regressionModel.getNormalizationMethod().value(), retrieved.getRegressionNormalizationMethod().getName());
    assertEquals(OP_TYPE.CATEGORICAL, retrieved.getOpType());
    // The expectation treats the table as binary exactly when the model has two regression tables.
    boolean isBinary = regressionModel.getRegressionTables().size() == 2;
    assertEquals(isBinary, retrieved.isBinary());
    assertEquals(targetMiningField.getName().getValue(), retrieved.getTargetField());
}
use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in project drools by kiegroup.
the class KiePMMLClassificationTableFactoryTest method setStaticGetter.
/**
 * Assembles a {@code PMML} instance holding one {@code RegressionModel} (CAUCHIT normalization, two
 * regression tables, three output fields, one target mining field), invokes
 * {@code KiePMMLClassificationTableFactory.setStaticGetter} on a clone of
 * {@code STATIC_GETTER_METHOD} and checks that the resulting method is node-equal
 * (per {@code JavaParserUtils.equalsNode}) to the method parsed from the content returned by
 * {@code getFileContent(TEST_02_SOURCE)}.
 *
 * @throws IOException declared because {@code getFileContent} is invoked to load the expected source
 */
@Test
public void setStaticGetter() throws IOException {
    String variableName = "variableName";
    RegressionTable regressionTableProf = getRegressionTable(3.5, "professional");
    RegressionTable regressionTableCler = getRegressionTable(27.4, "clerical");
    OutputField outputFieldCat = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    OutputField outputFieldNum = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    OutputField outputFieldPrev = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Data dictionary with a single categorical field.
    String targetField = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetField));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);

    // Regression model wired with tables, outputs and mining schema.
    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTableProf, regressionTableCler);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    Output output = new Output();
    output.addOutputFields(outputFieldCat, outputFieldNum, outputFieldPrev);
    regressionModel.setOutput(output);
    MiningField miningField = new MiningField();
    miningField.setUsageType(MiningField.UsageType.TARGET);
    miningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(miningField);
    regressionModel.setMiningSchema(miningSchema);

    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);

    final CommonCompilationDTO<RegressionModel> source =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source,
                                                                                              regressionModel.getRegressionTables(),
                                                                                              regressionModel.getNormalizationMethod());

    // Insertion-ordered map: one entry per regression table, in the order the tables were added.
    final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap = new LinkedHashMap<>();
    regressionModel.getRegressionTables().forEach(regressionTable -> {
        final String category = regressionTable.getTargetCategory().toString();
        // Locale.ROOT keeps the key independent of the JVM default locale; a plain toUpperCase()
        // would, for instance, map 'i' to a dotted capital under a Turkish locale and alter the key.
        final String key = "defpack." + category.toUpperCase(java.util.Locale.ROOT);
        regressionTablesMap.put(key, new KiePMMLTableSourceCategory("", category));
    });

    // Method under test: mutates the cloned method declaration in place.
    final MethodDeclaration staticGetterMethod = STATIC_GETTER_METHOD.clone();
    KiePMMLClassificationTableFactory.setStaticGetter(compilationDTO, regressionTablesMap, staticGetterMethod, variableName);

    final MethodDeclaration expected = JavaParserUtils.parseMethod(getFileContent(TEST_02_SOURCE));
    assertTrue(JavaParserUtils.equalsNode(expected, staticGetterMethod));
}
Aggregations