use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class KiePMMLDroolsModelFactoryUtilsTest method getKiePMMLModelCompilationUnit.
/**
 * Builds a PMML holding a single {@code TreeModel} whose mining schema contains one TARGET field,
 * asks {@code KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit} for the compilation unit
 * (using {@code TEMPLATE_SOURCE} / {@code TEMPLATE_CLASS_NAME}) and then checks the package
 * declaration, the assignments found in the default constructor and the field-type-map population.
 */
@Test
public void getKiePMMLModelCompilationUnit() {
    // --- data dictionary: one continuous double field, also used as the target
    final DataDictionary dictionary = new DataDictionary();
    final String targetFieldString = "target.field";
    final FieldName targetName = FieldName.create(targetFieldString);
    dictionary.addDataFields(new DataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE));

    // --- tree model with a mining schema made of the target field only
    final String modelName = "ModelName";
    final TreeModel treeModel = new TreeModel();
    treeModel.setModelName(modelName);
    treeModel.setMiningFunction(MiningFunction.CLASSIFICATION);
    final MiningField targetMiningField = new MiningField(targetName);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    treeModel.setMiningSchema(schema);

    // --- original-type -> generated-type mapping for the target field
    final Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    final KiePMMLOriginalTypeGeneratedType targetGeneratedType =
            new KiePMMLOriginalTypeGeneratedType(targetFieldString, getSanitizedClassName(targetFieldString));
    fieldTypeMap.put(targetFieldString, targetGeneratedType);

    // --- PMML document and compilation DTOs
    final String packageName = "net.test";
    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(treeModel);
    final CommonCompilationDTO<TreeModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(packageName,
                                                                   pmml,
                                                                   treeModel,
                                                                   new HasClassLoaderMock());
    final DroolsCompilationDTO<TreeModel> droolsCompilationDTO =
            DroolsCompilationDTO.fromCompilationDTO(commonDTO, fieldTypeMap);

    // --- code under test
    final CompilationUnit retrieved =
            KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit(droolsCompilationDTO,
                                                                          TEMPLATE_SOURCE,
                                                                          TEMPLATE_CLASS_NAME);

    // --- package declaration must match the DTO package name
    final String retrievedPackageName = retrieved.getPackageDeclaration().get().getNameAsString();
    assertEquals(droolsCompilationDTO.getPackageName(), retrievedPackageName);

    // --- expected assignments inside the default constructor of the generated class
    final ConstructorDeclaration defaultConstructor =
            retrieved.getClassByName(modelName).get().getDefaultConstructor().get();
    final MINING_FUNCTION expectedMiningFunction = MINING_FUNCTION.CLASSIFICATION;
    final PMML_MODEL expectedPmmlModel = PMML_MODEL.byName(treeModel.getClass().getSimpleName());
    final String expectedKModulePackageName = getSanitizedPackageName(packageName + "." + modelName);
    final Map<String, Expression> assignExpressionMap = new HashMap<>();
    assignExpressionMap.put("targetField", new StringLiteralExpr(targetFieldString));
    assignExpressionMap.put("miningFunction",
                            new NameExpr(expectedMiningFunction.getClass().getName() + "." + expectedMiningFunction.name()));
    assignExpressionMap.put("pmmlMODEL",
                            new NameExpr(expectedPmmlModel.getClass().getName() + "." + expectedPmmlModel.name()));
    assignExpressionMap.put("kModulePackageName", new StringLiteralExpr(expectedKModulePackageName));
    assertTrue(commonEvaluateAssignExpr(defaultConstructor.getBody(), assignExpressionMap));

    // The trailing "+ 1" accounts for the super invocation
    final int expectedMethodCallExprs = assignExpressionMap.size() + fieldTypeMap.size() + 1;
    commonEvaluateFieldTypeMap(defaultConstructor.getBody(), fieldTypeMap, expectedMethodCallExprs);
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class KiePMMLDataDictionaryASTFactoryTest method declareTypes.
/**
 * Feeds {@code KiePMMLDataDictionaryASTFactory.declareTypes} with the fields of a dictionary built
 * from {@code getTypeDataField()} and {@code getDottedTypeDataField()} (two of each, alternating)
 * and verifies that one declaration per input field is returned, comparing them position by position.
 */
@Test
public void declareTypes() {
    final List<DataField> inputFields = Arrays.asList(getTypeDataField(),
                                                      getDottedTypeDataField(),
                                                      getTypeDataField(),
                                                      getDottedTypeDataField());
    final DataDictionary dictionary = new DataDictionary(inputFields);
    final Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();

    final List<KiePMMLDroolsType> retrieved = KiePMMLDataDictionaryASTFactory
            .factory(fieldTypeMap)
            .declareTypes(getFieldsFromDataDictionary(dictionary));

    assertNotNull(retrieved);
    assertEquals(inputFields.size(), retrieved.size());
    // The i-th declaration is checked against the i-th input field
    for (int index = 0; index < inputFields.size(); index++) {
        commonVerifyTypeDeclarationDescr(inputFields.get(index), fieldTypeMap, retrieved.get(index));
    }
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class KiePMMLAttributeFactoryTest method getAttributeVariableDeclarationWithComplexPartialScore.
/**
 * Creates an {@code Attribute} carrying a reason code, a compound predicate and a complex partial
 * score, registers a DOUBLE {@code DataField} for every simple / simple-set predicate nested in the
 * compound one, and compares the block returned by
 * {@code KiePMMLAttributeFactory.getAttributeVariableDeclaration} with the one parsed from
 * {@code TEST_01_SOURCE}. The retrieved block is finally validated together with the listed imports.
 *
 * @throws IOException propagated from {@code getFileContent}
 */
@Test
public void getAttributeVariableDeclarationWithComplexPartialScore() throws IOException {
    final String variableName = "variableName";

    // --- attribute under test
    final Attribute attribute = new Attribute();
    attribute.setReasonCode(REASON_CODE);
    final Array.Type arrayType = Array.Type.STRING;
    final List<String> values = getStringObjects(arrayType, 4);
    final CompoundPredicate compoundPredicate = getCompoundPredicate(values, arrayType);
    attribute.setPredicate(compoundPredicate);
    attribute.setComplexPartialScore(getComplexPartialScore());

    // Comma-separated, double-quoted rendering of the values, used to fill the expected template
    final String valuesString = values.stream()
            .map(value -> "\"" + value + "\"")
            .collect(Collectors.joining(","));

    // --- data dictionary: one DOUBLE field per (simple | simple-set) predicate
    final DataDictionary dictionary = new DataDictionary();
    for (Predicate nested : compoundPredicate.getPredicates()) {
        final DataField field;
        if (nested instanceof SimplePredicate) {
            field = new DataField();
            field.setName(((SimplePredicate) nested).getField());
        } else if (nested instanceof SimpleSetPredicate) {
            field = new DataField();
            field.setName(((SimpleSetPredicate) nested).getField());
        } else {
            // Any other kind of predicate contributes no field
            continue;
        }
        field.setDataType(DataType.DOUBLE);
        dictionary.addDataFields(field);
    }

    // --- code under test
    final BlockStmt retrieved = KiePMMLAttributeFactory.getAttributeVariableDeclaration(
            variableName,
            attribute,
            getFieldsFromDataDictionary(dictionary));

    // --- comparison against the expected source
    final String template = getFileContent(TEST_01_SOURCE);
    final Statement expected = JavaParserUtils.parseBlock(String.format(template, variableName, valuesString));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    final List<Class<?>> imports = Arrays.asList(KiePMMLAttribute.class,
                                                 KiePMMLComplexPartialScore.class,
                                                 KiePMMLCompoundPredicate.class,
                                                 KiePMMLConstant.class,
                                                 KiePMMLSimplePredicate.class,
                                                 KiePMMLSimpleSetPredicate.class,
                                                 Arrays.class,
                                                 Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class KiePMMLClassificationTableFactoryTest method getClassificationTableBuilder.
/**
 * Assembles a {@code RegressionModel} with two regression tables, three output fields and a
 * categorical TARGET field, wraps it in a PMML / {@code RegressionCompilationDTO}, and checks that
 * {@code KiePMMLClassificationTableFactory.getClassificationTableBuilder} returns a non-null entry
 * when given a map with one {@code KiePMMLTableSourceCategory} per regression table.
 */
@Test
public void getClassificationTableBuilder() {
    // --- building blocks
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");
    final OutputField categoryOutput = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    final OutputField numericOutput = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    final OutputField predictedOutput = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // --- data dictionary with the categorical target field
    final String targetField = "targetField";
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create(targetField));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // --- regression model: tables, output, mining schema
    final RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(professionalTable, clericalTable);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    final Output output = new Output();
    output.addOutputFields(categoryOutput, numericOutput, predictedOutput);
    regressionModel.setOutput(output);
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    regressionModel.setMiningSchema(schema);

    // --- PMML document and compilation DTOs
    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(regressionModel);
    final CommonCompilationDTO<RegressionModel> source =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   regressionModel,
                                                                   new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                    source,
                    regressionModel.getRegressionTables(),
                    regressionModel.getNormalizationMethod());

    // --- one entry per regression table, keyed by "<package>.<UPPERCASED target category>"
    final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap = new LinkedHashMap<>();
    for (RegressionTable table : regressionModel.getRegressionTables()) {
        final String category = table.getTargetCategory().toString();
        regressionTablesMap.put(compilationDTO.getPackageName() + "." + category.toUpperCase(),
                                new KiePMMLTableSourceCategory("", category));
    }

    // --- code under test
    final Map.Entry<String, String> retrieved =
            KiePMMLClassificationTableFactory.getClassificationTableBuilder(compilationDTO, regressionTablesMap);
    assertNotNull(retrieved);
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class KiePMMLRegressionTableFactoryTest method getRegressionTableBuilders.
/**
 * Stores a freshly created regression table in the {@code regressionTable} member, builds a
 * {@code RegressionModel} around it with a categorical TARGET field, and checks that
 * {@code KiePMMLRegressionTableFactory.getRegressionTableBuilders} returns a non-null map whose
 * sources all pass {@code commonValidateKiePMMLRegressionTable}.
 * Note that the DTO is created with an empty regression-table list, not with the model's own tables.
 */
@Test
public void getRegressionTableBuilders() {
    // --- regression model built around a single table
    regressionTable = getRegressionTable(3.5, "professional");
    final RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTable);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));

    // --- data dictionary with the categorical target field
    final String targetField = "targetField";
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create(targetField));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // --- mining schema referencing the same field as TARGET
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    regressionModel.setMiningSchema(schema);

    // --- PMML document and compilation DTOs
    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(regressionModel);
    final CommonCompilationDTO<RegressionModel> source =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   regressionModel,
                                                                   new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                    source,
                    new ArrayList<>(),
                    regressionModel.getNormalizationMethod());

    // --- code under test
    final Map<String, KiePMMLTableSourceCategory> retrieved =
            KiePMMLRegressionTableFactory.getRegressionTableBuilders(compilationDTO);
    assertNotNull(retrieved);
    for (KiePMMLTableSourceCategory tableSourceCategory : retrieved.values()) {
        commonValidateKiePMMLRegressionTable(tableSourceCategory.getSource());
    }
}
Aggregations