Usage of org.dmg.pmml.DataDictionary in the kiegroup drools project.
Class ModelUtilsTest, method getTargetFieldsTypeMapWithTargetFieldsWithoutTargets.
/**
 * Verifies {@link ModelUtils#getTargetFieldsTypeMap} for a model that has no {@code Targets} element.
 * Every mining field is {@code PREDICTED}, so the result should contain one entry per mining field,
 * in mining-schema order. Each entry should map the field name to the {@code DATA_TYPE} of the
 * matching {@code DataField}.
 */
@Test
public void getTargetFieldsTypeMapWithTargetFieldsWithoutTargets() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    IntStream.range(0, 3).forEach(i -> {
        final DataField dataField = getRandomDataField();
        dataDictionary.addDataFields(dataField);
        final MiningField miningField = getMiningField(dataField.getName().getValue(), MiningField.UsageType.PREDICTED);
        miningSchema.addMiningFields(miningField);
    });
    model.setMiningSchema(miningSchema);
    Map<String, DATA_TYPE> retrieved = ModelUtils.getTargetFieldsTypeMap(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    // Insertion order only means something for a LinkedHashMap; the loop below depends on it.
    assertTrue(retrieved instanceof LinkedHashMap);
    final Iterator<Map.Entry<String, DATA_TYPE>> iterator = retrieved.entrySet().iterator();
    for (MiningField miningField : miningSchema.getMiningFields()) {
        DataField dataField = dataDictionary.getDataFields().stream()
                .filter(df -> df.getName().equals(miningField.getName()))
                .findFirst()
                .orElseThrow(() -> new AssertionError("No DataField found for mining field "
                        + miningField.getName().getValue()));
        DATA_TYPE expected = DATA_TYPE.byName(dataField.getDataType().value());
        final Map.Entry<String, DATA_TYPE> next = iterator.next();
        // Compare the key as well as the value. If two random fields share a data type,
        // comparing values alone would not catch entries emitted out of order.
        assertEquals(miningField.getName().getValue(), next.getKey());
        assertEquals(expected, next.getValue());
    }
}
Usage of org.dmg.pmml.DataDictionary in the kiegroup drools project.
Class ModelUtilsTest, method getTargetFieldTypeWithoutTargetField.
/**
 * Verifies that {@link ModelUtils#getTargetFieldType} throws when the model has no target.
 * The model's mining schema contains only a single {@code ACTIVE} field, so no target exists.
 * NOTE(review): {@code expected = Exception.class} would also pass if the fixture setup itself
 * threw. Consider narrowing it to the concrete exception type that ModelUtils raises.
 */
@Test(expected = Exception.class)
public void getTargetFieldTypeWithoutTargetField() {
    final String fieldName = "fieldName";
    final MiningField activeMiningField = getMiningField(fieldName, MiningField.UsageType.ACTIVE);
    final DataField stringDataField = getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING);

    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(stringDataField);

    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(activeMiningField);

    final Model regressionModel = new RegressionModel();
    regressionModel.setMiningSchema(schema);

    // Expected to throw: the only mining field is ACTIVE, so there is no target to resolve.
    ModelUtils.getTargetFieldType(getFieldsFromDataDictionary(dictionary), regressionModel);
}
Usage of org.dmg.pmml.DataDictionary in the kiegroup drools project.
Class ModelUtilsTest, method getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType.
/**
 * Verifies {@link ModelUtils#getTargetFields} when the model declares {@code Targets}.
 * The targets are built with {@code getTarget(fieldName, null)}, presumably meaning no op type.
 * Each returned entry must carry the op type set on its {@code PREDICTED} mining field
 * ({@code CONTINUOUS}), not the {@code CATEGORICAL} op type of the data field.
 */
@Test
public void getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    for (int idx = 0; idx < 3; idx++) {
        final String name = "fieldName-" + idx;
        dataDictionary.addDataFields(getDataField(name, OpType.CATEGORICAL, DataType.STRING));
        final MiningField predictedField = getMiningField(name, MiningField.UsageType.PREDICTED);
        predictedField.setOpType(OpType.CONTINUOUS);
        miningSchema.addMiningFields(predictedField);
        targets.addTargets(getTarget(name, null));
    }
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);

    final List<KiePMMLNameOpType> retrieved = ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    for (KiePMMLNameOpType nameOpType : retrieved) {
        final Optional<MiningField> matching = miningSchema.getMiningFields()
                .stream()
                .filter(candidate -> nameOpType.getName().equals(candidate.getName().getValue()))
                .findFirst();
        assertTrue(matching.isPresent());
        final OP_TYPE expectedOpType = OP_TYPE.byName(matching.get().getOpType().value());
        assertEquals(expectedOpType, nameOpType.getOpType());
    }
}
Usage of org.dmg.pmml.DataDictionary in the kiegroup drools project.
Class ModelUtilsTest, method getDataTypeFromDerivedFieldsAndDataDictionary.
/**
 * Verifies that {@link ModelUtils#getDataType} resolves the {@code DataType} of a field by name.
 * The lookup runs over a mixed list containing both {@code DataField}s and {@code DerivedField}s.
 * Each derived field ("DER_" + data field name) gets a data type different from its source data
 * field. This means the assertions fail if a lookup returns one kind of field's type for the other.
 */
@Test
public void getDataTypeFromDerivedFieldsAndDataDictionary() {
    final DataDictionary dataDictionary = new DataDictionary();
    IntStream.range(0, 3).forEach(i -> {
        final DataField dataField = getRandomDataField();
        dataDictionary.addDataFields(dataField);
    });
    final List<DerivedField> derivedFields = dataDictionary.getDataFields().stream().map(dataField -> {
        DerivedField toReturn = new DerivedField();
        toReturn.setName(FieldName.create("DER_" + dataField.getName().getValue()));
        DataType dataType = getRandomDataType();
        // Re-draw until the type differs from the source DataField's type, so the two are distinguishable.
        while (dataType.equals(dataField.getDataType())) {
            dataType = getRandomDataType();
        }
        toReturn.setDataType(dataType);
        return toReturn;
    }).collect(Collectors.toList());
    // Use addAll rather than map(Field.class::cast). The cast yields a raw Field and
    // an unchecked conversion into List<Field<?>>.
    final List<Field<?>> fields = new ArrayList<>(dataDictionary.getDataFields().size() + derivedFields.size());
    fields.addAll(dataDictionary.getDataFields());
    fields.addAll(derivedFields);
    dataDictionary.getDataFields().forEach(dataField -> {
        String fieldName = dataField.getName().getValue();
        DataType retrieved = ModelUtils.getDataType(fields, fieldName);
        assertNotNull(retrieved);
        DataType expected = dataField.getDataType();
        assertEquals(expected, retrieved);
    });
    derivedFields.forEach(derivedField -> {
        String fieldName = derivedField.getName().getValue();
        DataType retrieved = ModelUtils.getDataType(fields, fieldName);
        assertNotNull(retrieved);
        DataType expected = derivedField.getDataType();
        assertEquals(expected, retrieved);
    });
}
Usage of org.dmg.pmml.DataDictionary in the kiegroup drools project.
Class ModelUtilsTest, method getTargetFieldsWithTargetFieldsWithOptType.
/**
 * Verifies {@link ModelUtils#getTargetFields} for a model with no {@code Targets} element.
 * Each {@code PREDICTED} mining field must be reported with the op type declared on the mining
 * field itself ({@code CONTINUOUS}), not the {@code CATEGORICAL} op type of the data field.
 */
@Test
public void getTargetFieldsWithTargetFieldsWithOptType() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    for (int n = 0; n < 3; n++) {
        final String name = "fieldName-" + n;
        dataDictionary.addDataFields(getDataField(name, OpType.CATEGORICAL, DataType.STRING));
        final MiningField predicted = getMiningField(name, MiningField.UsageType.PREDICTED);
        // Override the data field's CATEGORICAL op type at the mining-schema level.
        predicted.setOpType(OpType.CONTINUOUS);
        miningSchema.addMiningFields(predicted);
    }
    model.setMiningSchema(miningSchema);

    final List<KiePMMLNameOpType> targetFields = ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(targetFields);
    assertEquals(miningSchema.getMiningFields().size(), targetFields.size());
    for (KiePMMLNameOpType target : targetFields) {
        final String targetName = target.getName();
        final Optional<MiningField> source = miningSchema.getMiningFields()
                .stream()
                .filter(mf -> targetName.equals(mf.getName().getValue()))
                .findFirst();
        assertTrue(source.isPresent());
        assertEquals(OP_TYPE.byName(source.get().getOpType().value()), target.getOpType());
    }
}
Aggregations