Use of org.dmg.pmml.DataField in project drools by kiegroup.
In class KiePMMLAttributeFactoryTest, method getAttributeVariableDeclarationWithComplexPartialScore.
/**
 * Verifies that {@code KiePMMLAttributeFactory.getAttributeVariableDeclaration} turns an Attribute
 * into the statement block described by {@code TEST_01_SOURCE}. The Attribute carries a reason
 * code, a compound predicate built from four generated STRING values, and a complex partial score.
 * It also checks that the generated block passes {@code commonValidateCompilationWithImports}.
 */
@Test
public void getAttributeVariableDeclarationWithComplexPartialScore() throws IOException {
    final String variableName = "variableName";
    final Attribute attribute = new Attribute();
    attribute.setReasonCode(REASON_CODE);
    final Array.Type stringType = Array.Type.STRING;
    final List<String> generatedValues = getStringObjects(stringType, 4);
    final CompoundPredicate compoundPredicate = getCompoundPredicate(generatedValues, stringType);
    attribute.setPredicate(compoundPredicate);
    attribute.setComplexPartialScore(getComplexPartialScore());
    // Each value is wrapped in double quotes so it matches the literal text expected in TEST_01_SOURCE.
    final String quotedValues = generatedValues.stream()
            .map(value -> "\"" + value + "\"")
            .collect(Collectors.joining(","));

    // Declare one DOUBLE DataField for every simple / simple-set child predicate.
    // Other predicate kinds contribute no field.
    final DataDictionary dataDictionary = new DataDictionary();
    for (Predicate child : compoundPredicate.getPredicates()) {
        final DataField declared;
        if (child instanceof SimplePredicate) {
            declared = new DataField();
            declared.setName(((SimplePredicate) child).getField());
        } else if (child instanceof SimpleSetPredicate) {
            declared = new DataField();
            declared.setName(((SimpleSetPredicate) child).getField());
        } else {
            continue;
        }
        declared.setDataType(DataType.DOUBLE);
        dataDictionary.addDataFields(declared);
    }

    final BlockStmt generatedBlock = KiePMMLAttributeFactory.getAttributeVariableDeclaration(variableName, attribute, getFieldsFromDataDictionary(dataDictionary));
    final String template = getFileContent(TEST_01_SOURCE);
    final Statement expectedBlock = JavaParserUtils.parseBlock(String.format(template, variableName, quotedValues));
    assertTrue(JavaParserUtils.equalsNode(expectedBlock, generatedBlock));

    final List<Class<?>> requiredImports = Arrays.asList(KiePMMLAttribute.class, KiePMMLComplexPartialScore.class, KiePMMLCompoundPredicate.class, KiePMMLConstant.class, KiePMMLSimplePredicate.class, KiePMMLSimpleSetPredicate.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(generatedBlock, requiredImports);
}
Use of org.dmg.pmml.DataField in project drools by kiegroup.
In class KiePMMLUtilTest, method populateMissingMiningTargetField.
/**
 * Verifies {@code KiePMMLUtil.populateMissingMiningTargetField} on a model whose mining schema
 * initially has no target mining field. It checks that:
 * <ul>
 *   <li>exactly one target MiningField is added;</li>
 *   <li>its name matches a DataField from the DataDictionary;</li>
 *   <li>the model's first Target now references that field.</li>
 * </ul>
 * <p>
 * The sample input stream is opened with try-with-resources so it is closed as soon as the
 * PMML has been unmarshalled, even if unmarshalling fails.
 */
@Test
public void populateMissingMiningTargetField() throws Exception {
    final PMML pmml;
    try (InputStream inputStream = getFileInputStream(NO_TARGET_FIELD_SAMPLE)) {
        pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
    }
    final Model toPopulate = pmml.getModels().get(0);
    List<MiningField> miningTargetFields = getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
    // Preconditions: no target mining field yet, and the first Target has no field reference.
    assertTrue(miningTargetFields.isEmpty());
    assertNull(toPopulate.getTargets().getTargets().get(0).getField());

    KiePMMLUtil.populateMissingMiningTargetField(toPopulate, pmml.getDataDictionary().getDataFields());

    miningTargetFields = getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
    assertEquals(1, miningTargetFields.size());
    final MiningField targetField = miningTargetFields.get(0);
    // The added target must refer to a field that actually exists in the DataDictionary.
    assertTrue(pmml.getDataDictionary().getDataFields().stream().anyMatch(dataField -> dataField.getName().equals(targetField.getName())));
    assertEquals(targetField.getName(), toPopulate.getTargets().getTargets().get(0).getField());
}
Use of org.dmg.pmml.DataField in project drools by kiegroup.
In class KiePMMLUtilTest, method getTargetDataField.
/**
 * Verifies that {@code KiePMMLUtil.getTargetDataField} finds a DataField for the first model of
 * the NO_TARGET_FIELD_SAMPLE file. It also checks that the field's name equals
 * {@code TARGETFIELD_TEMPLATE} formatted with {@code "golfing"}.
 * <p>
 * The sample input stream is opened with try-with-resources so it is closed as soon as the
 * PMML has been unmarshalled, even if unmarshalling fails.
 */
@Test
public void getTargetDataField() throws Exception {
    final PMML pmml;
    try (InputStream inputStream = getFileInputStream(NO_TARGET_FIELD_SAMPLE)) {
        pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
    }
    final Model model = pmml.getModels().get(0);
    final Optional<DataField> optionalDataField = KiePMMLUtil.getTargetDataField(model);
    assertTrue(optionalDataField.isPresent());
    final DataField retrieved = optionalDataField.get();
    final String expected = String.format(TARGETFIELD_TEMPLATE, "golfing");
    assertEquals(expected, retrieved.getName().getValue());
}
Use of org.dmg.pmml.DataField in project drools by kiegroup.
In class KiePMMLPredicateInstanceFactoryTest, method getKiePMMLPredicate.
/**
 * Converts one predicate of each kind exercised here through
 * {@code KiePMMLPredicateInstanceFactory.getKiePMMLPredicate}: simple, simple-set, compound,
 * False and True. All of them use the same three random data fields, and each result is checked
 * with {@code commonVerifyKiePMMLPredicate}.
 */
@Test
public void getKiePMMLPredicate() {
    final List<Field<?>> fields = IntStream.range(0, 3).mapToObj(index -> getRandomDataField()).collect(Collectors.toList());

    // Simple predicate on the first field.
    final SimplePredicate simplePredicate = getRandomSimplePredicate((DataField) fields.get(0));
    final KiePMMLPredicate fromSimple = KiePMMLPredicateInstanceFactory.getKiePMMLPredicate(simplePredicate, fields);
    commonVerifyKiePMMLPredicate(fromSimple, simplePredicate);

    // Simple-set predicate on the third field.
    final SimpleSetPredicate simpleSetPredicate = getRandomSimpleSetPredicate((DataField) fields.get(2));
    final KiePMMLPredicate fromSimpleSet = KiePMMLPredicateInstanceFactory.getKiePMMLPredicate(simpleSetPredicate, fields);
    commonVerifyKiePMMLPredicate(fromSimpleSet, simpleSetPredicate);

    // Compound predicate generated from the whole field list.
    final CompoundPredicate compoundPredicate = getRandomCompoundPredicate(fields);
    final KiePMMLPredicate fromCompound = KiePMMLPredicateInstanceFactory.getKiePMMLPredicate(compoundPredicate, fields);
    commonVerifyKiePMMLPredicate(fromCompound, compoundPredicate);

    // Constant predicates.
    final False falsePredicate = new False();
    final KiePMMLPredicate fromFalse = KiePMMLPredicateInstanceFactory.getKiePMMLPredicate(falsePredicate, fields);
    commonVerifyKiePMMLPredicate(fromFalse, falsePredicate);

    final True truePredicate = new True();
    final KiePMMLPredicate fromTrue = KiePMMLPredicateInstanceFactory.getKiePMMLPredicate(truePredicate, fields);
    commonVerifyKiePMMLPredicate(fromTrue, truePredicate);
}
Use of org.dmg.pmml.DataField in project drools by kiegroup.
In class KiePMMLSimplePredicateFactoryTest, method getSimplePredicateVariableDeclaration.
/**
 * Verifies that {@code KiePMMLSimplePredicateFactory.getSimplePredicateVariableDeclaration}
 * produces the statement block described by {@code TEST_01_SOURCE}. The input is an EQUAL
 * SimplePredicate on {@code "CUSTOM_FIELD"} with value {@code "235.435"}, where the field is
 * declared as a DOUBLE DataField. It also checks that the generated block passes
 * {@code commonValidateCompilationWithImports}.
 */
@Test
public void getSimplePredicateVariableDeclaration() throws IOException {
    final String variableName = "variableName";

    final SimplePredicate simplePredicate = new SimplePredicate();
    simplePredicate.setField(FieldName.create("CUSTOM_FIELD"));
    simplePredicate.setValue("235.435");
    simplePredicate.setOperator(SimplePredicate.Operator.EQUAL);
    final FieldName fieldName = simplePredicate.getField();
    // Fully-qualified enum constant, as it is expected to appear in the generated source.
    final String operatorString = OPERATOR.class.getName() + "." + OPERATOR.byName(simplePredicate.getOperator().value());

    // The predicate's field must be declared in the dictionary for the factory to resolve it.
    final DataField declaredField = new DataField();
    declaredField.setName(fieldName);
    declaredField.setDataType(DataType.DOUBLE);
    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(declaredField);

    final BlockStmt generatedBlock = KiePMMLSimplePredicateFactory.getSimplePredicateVariableDeclaration(variableName, simplePredicate, getFieldsFromDataDictionary(dataDictionary));
    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(template, variableName, fieldName.getValue(), operatorString, simplePredicate.getValue());
    final Statement expectedBlock = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expectedBlock, generatedBlock));

    final List<Class<?>> requiredImports = Arrays.asList(KiePMMLSimplePredicate.class, Collections.class);
    commonValidateCompilationWithImports(generatedBlock, requiredImports);
}
Aggregations