Example use of org.dmg.pmml.DataDictionary in the drools project by kiegroup.
From the class KiePMMLCompoundPredicateFactoryTest, method getCompoundPredicateVariableDeclaration.
/**
 * Checks the block produced by
 * {@code KiePMMLCompoundPredicateFactory.getCompoundPredicateVariableDeclaration} for an
 * {@code AND} compound predicate that nests two simple predicates and one
 * {@code IS_IN} simple-set predicate.
 * <p>
 * The result is compared with the {@code TEST_01_SOURCE} template after formatting it with
 * the variable name, the comma-separated quoted set values and the fully qualified boolean
 * operator, and is then passed to {@code commonValidateCompilationWithImports}.
 *
 * @throws IOException declared by the test signature
 */
@Test
public void getCompoundPredicateVariableDeclaration() throws IOException {
    final String variableName = "variableName";

    // Nested predicates; helper invocation order is kept as-is.
    final SimplePredicate firstSimple = getSimplePredicate(PARAM_1, value1, operator1);
    final SimplePredicate secondSimple = getSimplePredicate(PARAM_2, value2, operator2);
    final Array.Type arrayType = Array.Type.STRING;
    final List<String> setValues = getStringObjects(arrayType, 4);
    final SimpleSetPredicate setPredicate =
            getSimpleSetPredicate(setValues, arrayType, SimpleSetPredicate.BooleanOperator.IS_IN);

    final CompoundPredicate compoundPredicate = new CompoundPredicate();
    compoundPredicate.setBooleanOperator(CompoundPredicate.BooleanOperator.AND);
    compoundPredicate.getPredicates().add(0, firstSimple);
    compoundPredicate.getPredicates().add(1, secondSimple);
    compoundPredicate.getPredicates().add(2, setPredicate);

    // One DOUBLE data field per predicate, named after the field the predicate refers to.
    final DataField firstField = new DataField();
    firstField.setName(firstSimple.getField());
    firstField.setDataType(DataType.DOUBLE);

    final DataField secondField = new DataField();
    secondField.setName(secondSimple.getField());
    secondField.setDataType(DataType.DOUBLE);

    final DataField setField = new DataField();
    setField.setName(setPredicate.getField());
    setField.setDataType(DataType.DOUBLE);

    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(firstField, secondField, setField);
    final List<Field<?>> fields = getFieldsFromDataDictionary(dataDictionary);

    // Values substituted into the expected-source template.
    final BOOLEAN_OPERATOR booleanOperator =
            BOOLEAN_OPERATOR.byName(compoundPredicate.getBooleanOperator().value());
    final String booleanOperatorString = BOOLEAN_OPERATOR.class.getName() + "." + booleanOperator.name();
    final String quotedValues = setValues.stream()
            .map(item -> "\"" + item + "\"")
            .collect(Collectors.joining(","));

    final BlockStmt retrieved = KiePMMLCompoundPredicateFactory
            .getCompoundPredicateVariableDeclaration(variableName, compoundPredicate, fields);

    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(template, variableName, quotedValues, booleanOperatorString);
    final Statement expected = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    final List<Class<?>> imports = Arrays.asList(
            KiePMMLCompoundPredicate.class,
            KiePMMLSimplePredicate.class,
            KiePMMLSimpleSetPredicate.class,
            Arrays.class,
            Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Example use of org.dmg.pmml.DataDictionary in the drools project by kiegroup.
From the class KiePMMLSimplePredicateFactoryTest, method getSimplePredicateVariableDeclaration.
/**
 * Checks the block produced by
 * {@code KiePMMLSimplePredicateFactory.getSimplePredicateVariableDeclaration} for an
 * {@code EQUAL} predicate on {@code CUSTOM_FIELD} with value {@code "235.435"}, backed by a
 * single {@code DOUBLE} data field.
 * <p>
 * The result is compared with the {@code TEST_01_SOURCE} template after formatting it with
 * the variable name, the field name, the fully qualified operator and the predicate value,
 * and is then passed to {@code commonValidateCompilationWithImports}.
 *
 * @throws IOException declared by the test signature
 */
@Test
public void getSimplePredicateVariableDeclaration() throws IOException {
    final String variableName = "variableName";

    final SimplePredicate simplePredicate = new SimplePredicate();
    simplePredicate.setField(FieldName.create("CUSTOM_FIELD"));
    simplePredicate.setValue("235.435");
    simplePredicate.setOperator(SimplePredicate.Operator.EQUAL);

    // The operator is appended through string concatenation, exactly as before.
    final OPERATOR operator = OPERATOR.byName(simplePredicate.getOperator().value());
    final String operatorString = OPERATOR.class.getName() + "." + operator;

    final DataField dataField = new DataField();
    dataField.setName(simplePredicate.getField());
    dataField.setDataType(DataType.DOUBLE);
    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);

    final BlockStmt retrieved = KiePMMLSimplePredicateFactory.getSimplePredicateVariableDeclaration(
            variableName,
            simplePredicate,
            getFieldsFromDataDictionary(dataDictionary));

    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(
            template,
            variableName,
            simplePredicate.getField().getValue(),
            operatorString,
            simplePredicate.getValue());
    final Statement expected = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    final List<Class<?>> imports = Arrays.asList(KiePMMLSimplePredicate.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Example use of org.dmg.pmml.DataDictionary in the drools project by kiegroup.
From the class ModelUtilsTest, method getOpTypeByTargetsNotFound.
/**
 * Expects {@code ModelUtils.getOpType} to throw {@code KiePMMLInternalException} when asked
 * for {@code "NOT_EXISTING"} on a regression model whose data dictionary, mining schema and
 * targets only contain {@code field0}, {@code field1} and {@code field2}.
 */
@Test(expected = KiePMMLInternalException.class)
public void getOpTypeByTargetsNotFound() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();

    for (int index = 0; index < 3; index++) {
        // Data field, mining field and target all share the same field name.
        final DataField dataField = getRandomDataField();
        dataField.setName(FieldName.create("field" + index));
        dataDictionary.addDataFields(dataField);

        final MiningField miningField = getRandomMiningField();
        miningField.setName(dataField.getName());
        miningSchema.addMiningFields(miningField);

        final Target target = getRandomTarget();
        target.setField(dataField.getName());
        targets.addTargets(target);
    }

    model.setMiningSchema(miningSchema);
    model.setTargets(targets);

    ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, "NOT_EXISTING");
}
Example use of org.dmg.pmml.DataDictionary in the drools project by kiegroup.
From the class ModelUtilsTest, method getTargetFieldsWithTargetFieldsWithoutOptType.
/**
 * Checks {@code ModelUtils.getTargetFields} on a regression model whose three
 * {@code PREDICTED} mining fields have their op type set to {@code null}.
 * <p>
 * One entry is expected per mining field; each entry's name must match a mining field and a
 * data field, and its op type must equal the one declared on that data field.
 */
@Test
public void getTargetFieldsWithTargetFieldsWithoutOptType() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();

    for (int index = 0; index < 3; index++) {
        final String fieldName = "fieldName-" + index;
        final DataField dataField = getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING);
        dataDictionary.addDataFields(dataField);

        final MiningField miningField = getMiningField(fieldName, MiningField.UsageType.PREDICTED);
        // Clear the op type on the mining field; the data field still declares one.
        miningField.setOpType(null);
        miningSchema.addMiningFields(miningField);
    }
    model.setMiningSchema(miningSchema);

    final List<KiePMMLNameOpType> retrieved =
            ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());

    for (final KiePMMLNameOpType nameOpType : retrieved) {
        final boolean matchesMiningField = miningSchema.getMiningFields()
                .stream()
                .anyMatch(candidate -> nameOpType.getName().equals(candidate.getName().getValue()));
        assertTrue(matchesMiningField);

        final Optional<DataField> matchingDataField = dataDictionary.getDataFields()
                .stream()
                .filter(candidate -> nameOpType.getName().equals(candidate.getName().getValue()))
                .findFirst();
        assertTrue(matchingDataField.isPresent());

        final OP_TYPE expected = OP_TYPE.byName(matchingDataField.get().getOpType().value());
        assertEquals(expected, nameOpType.getOpType());
    }
}
Example use of org.dmg.pmml.DataDictionary in the drools project by kiegroup.
From the class ModelUtilsTest, method getOpTypeByDataFieldsNotFound.
/**
 * Expects {@code ModelUtils.getOpType} to throw {@code KiePMMLInternalException} when asked
 * for {@code "NOT_EXISTING"} while the data dictionary only holds {@code field0},
 * {@code field1} and {@code field2} and the model has no mining schema or targets set.
 */
@Test(expected = KiePMMLInternalException.class)
public void getOpTypeByDataFieldsNotFound() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();

    for (int index = 0; index < 3; index++) {
        final DataField dataField = getRandomDataField();
        dataField.setName(FieldName.create("field" + index));
        dataDictionary.addDataFields(dataField);
    }

    ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, "NOT_EXISTING");
}
Aggregations