Usage of org.dmg.pmml.DataField in the drools project (kiegroup).
From class ModelUtilsTest, method getTargetFieldsWithTargetFieldsWithoutOptType.
@Test
public void getTargetFieldsWithTargetFieldsWithoutOptType() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    // Three PREDICTED mining fields with no opType, each backed by a
    // CATEGORICAL/STRING DataField of the same name.
    for (int index = 0; index < 3; index++) {
        final String name = "fieldName-" + index;
        final DataField backingField = getDataField(name, OpType.CATEGORICAL, DataType.STRING);
        dataDictionary.addDataFields(backingField);
        final MiningField predicted = getMiningField(name, MiningField.UsageType.PREDICTED);
        predicted.setOpType(null);
        miningSchema.addMiningFields(predicted);
    }
    model.setMiningSchema(miningSchema);

    final List<KiePMMLNameOpType> targets =
            ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(targets);
    final List<MiningField> miningFields = miningSchema.getMiningFields();
    assertEquals(miningFields.size(), targets.size());
    for (KiePMMLNameOpType target : targets) {
        final String targetName = target.getName();
        assertTrue(miningFields.stream()
                           .anyMatch(mf -> targetName.equals(mf.getName().getValue())));
        final Optional<DataField> matching = dataDictionary.getDataFields()
                .stream()
                .filter(df -> targetName.equals(df.getName().getValue()))
                .findFirst();
        assertTrue(matching.isPresent());
        // The MiningField opType was cleared, so the expected opType is the DataField's one.
        final OP_TYPE expected = OP_TYPE.byName(matching.get().getOpType().value());
        assertEquals(expected, target.getOpType());
    }
}
Usage of org.dmg.pmml.DataField in the drools project (kiegroup).
From class ModelUtilsTest, method getOpTypeByDataFieldsNotFound.
@Test(expected = KiePMMLInternalException.class)
public void getOpTypeByDataFieldsNotFound() {
    final Model model = new RegressionModel();
    // Dictionary contains only "field0", "field1" and "field2".
    final DataDictionary dataDictionary = new DataDictionary();
    for (int counter = 0; counter < 3; counter++) {
        final DataField randomField = getRandomDataField();
        randomField.setName(FieldName.create("field" + counter));
        dataDictionary.addDataFields(randomField);
    }
    // No field is named "NOT_EXISTING", so a KiePMMLInternalException is expected.
    ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, "NOT_EXISTING");
}
Usage of org.dmg.pmml.DataField in the drools project (kiegroup).
From class ModelUtilsTest, method getTargetFieldTypeWithTargetField.
@Test
public void getTargetFieldTypeWithTargetField() {
    final String fieldName = "fieldName";
    // One PREDICTED mining field referencing one CATEGORICAL/STRING data field.
    final MiningField predicted = getMiningField(fieldName, MiningField.UsageType.PREDICTED);
    final DataField stringField = getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING);

    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(stringField);
    final MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(predicted);
    final Model model = new RegressionModel();
    model.setMiningSchema(miningSchema);

    final DATA_TYPE targetType =
            ModelUtils.getTargetFieldType(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(targetType);
    assertEquals(DATA_TYPE.STRING, targetType);
}
Usage of org.dmg.pmml.DataField in the drools project (kiegroup).
From class ModelUtilsTest, method getOpTypeByMiningFields.
/**
 * Verifies that {@code ModelUtils.getOpType} returns the opType declared on the
 * {@link MiningField} when the mining schema contains the requested field.
 * <p>
 * The DataField and MiningField opTypes are forced to differ; otherwise two random
 * opTypes could coincide and the assertion would pass even if only the DataField
 * were consulted.
 */
@Test
public void getOpTypeByMiningFields() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    IntStream.range(0, 3).forEach(i -> {
        final DataField dataField = getRandomDataField();
        // Deliberately different from the MiningField opType set below.
        dataField.setOpType(OpType.CATEGORICAL);
        dataDictionary.addDataFields(dataField);
        final MiningField miningField = getRandomMiningField();
        miningField.setName(dataField.getName());
        miningField.setOpType(OpType.CONTINUOUS);
        miningSchema.addMiningFields(miningField);
    });
    model.setMiningSchema(miningSchema);
    miningSchema.getMiningFields().forEach(miningField -> {
        OP_TYPE retrieved = ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, miningField.getName().getValue());
        assertNotNull(retrieved);
        OP_TYPE expected = OP_TYPE.byName(miningField.getOpType().value());
        assertEquals(expected, retrieved);
    });
}
Usage of org.dmg.pmml.DataField in the drools project (kiegroup).
From class PMMLModelTestUtils, method getRandomMiningModel.
/**
 * Builds a randomly-named {@link MiningModel} on top of the given {@link DataDictionary}.
 * <p>
 * Every {@link DataField} except the last one becomes an {@code ACTIVE} {@link MiningField};
 * the last one becomes the {@code PREDICTED} field. An {@link Output} containing a single
 * {@link OutputField} named {@code "OUTPUT_" + <last field name>} (same data type as the last
 * DataField, random opType) is attached. The model holds one {@link Segment} that wraps the
 * model returned by {@code getRandomTestModel(dataDictionary)}.
 *
 * @param dataDictionary source of the fields; must contain at least one {@link DataField}
 * @return the populated {@link MiningModel}
 * @throws IllegalArgumentException if {@code dataDictionary} contains no {@link DataField}
 */
public static MiningModel getRandomMiningModel(DataDictionary dataDictionary) {
    List<DataField> dataFields = dataDictionary.getDataFields();
    if (dataFields.isEmpty()) {
        // Without this guard, dataFields.get(-1) below fails with an opaque IndexOutOfBoundsException.
        throw new IllegalArgumentException("DataDictionary must contain at least one DataField");
    }
    MiningModel toReturn = new MiningModel();
    int predictedIndex = dataFields.size() - 1;
    MiningSchema miningSchema = new MiningSchema();
    IntStream.range(0, predictedIndex).forEach(i -> {
        DataField dataField = dataFields.get(i);
        MiningField miningField = new MiningField();
        miningField.setName(dataField.getName());
        miningField.setUsageType(MiningField.UsageType.ACTIVE);
        miningSchema.addMiningFields(miningField);
    });
    DataField lastDataField = dataFields.get(predictedIndex);
    MiningField predictedMiningField = new MiningField();
    predictedMiningField.setName(lastDataField.getName());
    predictedMiningField.setUsageType(MiningField.UsageType.PREDICTED);
    miningSchema.addMiningFields(predictedMiningField);
    Output output = new Output();
    OutputField outputField = new OutputField();
    outputField.setName(FieldName.create("OUTPUT_" + lastDataField.getName().getValue()));
    outputField.setDataType(lastDataField.getDataType());
    outputField.setOpType(getRandomOpType());
    // Previously the OutputField was built but never attached, leaving the Output empty.
    output.addOutputFields(outputField);
    toReturn.setModelName(RandomStringUtils.random(6, true, false));
    toReturn.setMiningSchema(miningSchema);
    toReturn.setOutput(output);
    TestModel testModel = getRandomTestModel(dataDictionary);
    Segment segment = new Segment();
    segment.setModel(testModel);
    Segmentation segmentation = new Segmentation();
    segmentation.addSegments(segment);
    toReturn.setSegmentation(segmentation);
    return toReturn;
}
Aggregations