use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class ModelUtilsTest method getTargetFieldName.
/**
 * Checks {@code ModelUtils.getTargetFieldName} in two consecutive scenarios on the same model:
 * with a mining schema whose only field has usage type ACTIVE the result must be empty;
 * after swapping in a schema whose only field has usage type PREDICTED the result must
 * carry that field's name.
 */
@Test
public void getTargetFieldName() {
    final String expectedName = "fieldName";

    // Scenario 1: the only mining field is ACTIVE, so no target name is expected.
    final MiningField activeField = getMiningField(expectedName, MiningField.UsageType.ACTIVE);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(getDataField(expectedName, OpType.CATEGORICAL, DataType.STRING));
    final MiningSchema activeSchema = new MiningSchema();
    activeSchema.addMiningFields(activeField);
    final Model regression = new RegressionModel();
    regression.setMiningSchema(activeSchema);
    final List<Field<?>> dictionaryFields = getFieldsFromDataDictionary(dictionary);
    final Optional<String> whenActive = ModelUtils.getTargetFieldName(dictionaryFields, regression);
    assertFalse(whenActive.isPresent());

    // Scenario 2: replace the schema with one holding a PREDICTED field of the same name.
    final MiningField predictedField = getMiningField(expectedName, MiningField.UsageType.PREDICTED);
    final MiningSchema predictedSchema = new MiningSchema();
    predictedSchema.addMiningFields(predictedField);
    regression.setMiningSchema(predictedSchema);
    final Optional<String> whenPredicted = ModelUtils.getTargetFieldName(dictionaryFields, regression);
    assertTrue(whenPredicted.isPresent());
    assertEquals(expectedName, whenPredicted.get());
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class PMMLModelTestUtils method getPMMLWithMiningRandomTestModel.
/**
 * Creates a {@code PMML} instance, sets on it the dictionary obtained from
 * {@code getRandomDataDictionary()}, and adds the model obtained from
 * {@code getRandomMiningModel(...)} called with that same dictionary.
 *
 * @return the populated {@code PMML} instance
 */
public static PMML getPMMLWithMiningRandomTestModel() {
    final PMML pmml = new PMML();
    final DataDictionary randomDictionary = getRandomDataDictionary();
    pmml.setDataDictionary(randomDictionary);
    // The model is generated from the very dictionary that was just attached.
    pmml.addModels(getRandomMiningModel(randomDictionary));
    return pmml;
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class ModelUtilsTest method getOpTypeByMiningFieldsNotFound.
/**
 * Expects a {@code KiePMMLInternalException} when {@code ModelUtils.getOpType} is asked for
 * the name "NOT_EXISTING" while the data dictionary and mining schema only contain
 * fields named "field0", "field1" and "field2".
 */
@Test(expected = KiePMMLInternalException.class)
public void getOpTypeByMiningFieldsNotFound() {
    final DataDictionary dictionary = new DataDictionary();
    final MiningSchema schema = new MiningSchema();
    for (int index = 0; index < 3; index++) {
        final String name = "field" + index;
        // Random data field, renamed so that the name is predictable.
        final DataField randomDataField = getRandomDataField();
        randomDataField.setName(FieldName.create(name));
        dictionary.addDataFields(randomDataField);
        // Matching mining field sharing the data field's name.
        final MiningField randomMiningField = getRandomMiningField();
        randomMiningField.setName(randomDataField.getName());
        schema.addMiningFields(randomMiningField);
    }
    final Model regression = new RegressionModel();
    regression.setMiningSchema(schema);
    ModelUtils.getOpType(getFieldsFromDataDictionary(dictionary), regression, "NOT_EXISTING");
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class ModelUtilsTest method getOpTypeByDataFields.
/**
 * Checks that, with an empty mining schema on the model, {@code ModelUtils.getOpType}
 * returns for every data field in the dictionary a non-null {@code OP_TYPE} equal to
 * {@code OP_TYPE.byName} applied to the value of that data field's own op-type.
 */
@Test
public void getOpTypeByDataFields() {
    final DataDictionary dictionary = new DataDictionary();
    for (int count = 0; count < 3; count++) {
        dictionary.addDataFields(getRandomDataField());
    }
    // The schema is left without mining fields; only the dictionary is populated.
    final Model regression = new RegressionModel();
    regression.setMiningSchema(new MiningSchema());
    for (DataField candidate : dictionary.getDataFields()) {
        final OP_TYPE actual = ModelUtils.getOpType(getFieldsFromDataDictionary(dictionary),
                                                    regression,
                                                    candidate.getName().getValue());
        assertNotNull(actual);
        assertEquals(OP_TYPE.byName(candidate.getOpType().value()), actual);
    }
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class KiePMMLClassificationTableFactoryTest method getClassificationTableBuilders.
/**
 * Assembles a {@code RegressionModel} with two regression tables, three output fields, the
 * CAUCHIT normalization method and one categorical TARGET field, wraps it in a {@code PMML},
 * and checks the map returned by
 * {@code KiePMMLClassificationTableFactory.getClassificationTableBuilders}: it must be
 * non-null, contain three entries, and every source in it is passed to the common
 * regression-table and compilation validators.
 */
@Test
public void getClassificationTableBuilders() {
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");
    final OutputField categoryOutput = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    final OutputField numericOutput = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    final OutputField previousOutput = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Data dictionary holding the single categorical field.
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // Regression model: normalization, tables, name and outputs.
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(professionalTable, clericalTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));
    final Output modelOutput = new Output();
    modelOutput.addOutputFields(categoryOutput, numericOutput, previousOutput);
    model.setOutput(modelOutput);

    // Mining schema marking the dictionary field as TARGET.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDto =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, model, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDto =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDto,
                                                                                              model.getRegressionTables(),
                                                                                              model.getNormalizationMethod());
    final Map<String, KiePMMLTableSourceCategory> builders =
            KiePMMLClassificationTableFactory.getClassificationTableBuilders(regressionDto);
    assertNotNull(builders);
    assertEquals(3, builders.size());
    for (KiePMMLTableSourceCategory category : builders.values()) {
        commonValidateKiePMMLRegressionTable(category.getSource());
    }
    // Flatten to key -> source text so the sources can be validated together.
    final Map<String, String> sourcesByKey = builders.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSource()));
    commonValidateCompilation(sourcesByKey);
}
Aggregations