Use of org.dmg.pmml.DataField in project drools by kiegroup.
In class KiePMMLUtil, method populateMissingOutputFieldDataType.
/**
 * Populates the missing <b>dataType</b> property of the given <code>OutputField</code>s.
 * Such property was optional until the 4.4.1 spec, so older models may omit it.
 * <p>
 * Only fields whose dataType is {@code null} are touched; resolution rules:
 * <ul>
 *     <li>if {@code targetField} is set it must match one of the target fields selected by
 *     {@code getMiningTargetFields(miningFields)}, otherwise a {@link KiePMMLException} is thrown;</li>
 *     <li>a {@code probability} output is always set to {@link DataType#DOUBLE}, whether or not it
 *     references a target field;</li>
 *     <li>otherwise the referenced target field (the explicit one, or the first target field when the
 *     feature is {@code predictedValue} or unspecified) provides the dataType of its matching
 *     {@link DataField};</li>
 *     <li>fields that cannot be resolved keep a {@code null} dataType (partial implementation).</li>
 * </ul>
 * @param toPopulate the output fields to update in place
 * @param miningFields the mining schema fields, from which the target fields are selected
 * @param dataFields the data dictionary fields providing the dataType of the referenced target field
 * @throws KiePMMLException if an explicit targetField is not among the target fields, or a referenced
 * target field has no matching DataField
 */
static void populateMissingOutputFieldDataType(List<OutputField> toPopulate, List<MiningField> miningFields, List<DataField> dataFields) {
    // partial implementation to fix missing "dataType" inside OutputField; "dataType" became mandatory only in 4.4.1 version
    List<MiningField> targetFields = getMiningTargetFields(miningFields);
    toPopulate.stream().filter(outputField -> outputField.getDataType() == null).forEach(outputField -> {
        MiningField referencedField = null;
        if (outputField.getTargetField() != null) {
            // validate the explicit reference first, so a dangling targetField is always reported
            referencedField = targetFields.stream()
                    .filter(targetField -> outputField.getTargetField().equals(targetField.getName()))
                    .findFirst()
                    .orElseThrow(() -> new KiePMMLException("Failed to find a target field for OutputField " + outputField.getName().getValue()));
        }
        if (ResultFeature.PROBABILITY.equals(outputField.getResultFeature())) {
            // a probability is numeric even when it refers to a (possibly categorical) target field,
            // so the target's own dataType must not be copied here
            outputField.setDataType(DataType.DOUBLE);
            return;
        }
        if (referencedField == null && (outputField.getResultFeature() == null || outputField.getResultFeature().equals(ResultFeature.PREDICTED_VALUE))) {
            // default predictedValue; it is allowed to not have any "target" field inside MiningSchema
            referencedField = targetFields.stream().findFirst().orElse(null);
        }
        if (referencedField != null) {
            FieldName targetFieldName = referencedField.getName();
            DataField dataField = dataFields.stream()
                    .filter(df -> df.getName().equals(targetFieldName))
                    .findFirst()
                    .orElseThrow(() -> new KiePMMLException("Failed to find a DataField field for " + "MiningField " + targetFieldName.toString()));
            outputField.setDataType(dataField.getDataType());
        }
    });
}
Use of org.dmg.pmml.DataField in project drools by kiegroup.
In class KiePMMLCompoundPredicateFactoryTest, method getCompoundPredicateVariableDeclaration.
@Test
public void getCompoundPredicateVariableDeclaration() throws IOException {
    final String variableName = "variableName";
    // Children of the compound predicate: two simple predicates and one "isIn" set predicate over strings
    final SimplePredicate firstSimple = getSimplePredicate(PARAM_1, value1, operator1);
    final SimplePredicate secondSimple = getSimplePredicate(PARAM_2, value2, operator2);
    final Array.Type setArrayType = Array.Type.STRING;
    final List<String> setValues = getStringObjects(setArrayType, 4);
    final SimpleSetPredicate simpleSet = getSimpleSetPredicate(setValues, setArrayType, SimpleSetPredicate.BooleanOperator.IS_IN);

    // Declare every field referenced by the children inside the data dictionary
    final DataField firstField = new DataField();
    firstField.setName(firstSimple.getField());
    firstField.setDataType(DataType.DOUBLE);
    final DataField secondField = new DataField();
    secondField.setName(secondSimple.getField());
    secondField.setDataType(DataType.DOUBLE);
    final DataField setField = new DataField();
    setField.setName(simpleSet.getField());
    setField.setDataType(DataType.DOUBLE);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(firstField, secondField, setField);

    // AND-combine the children, keeping their declaration order
    final CompoundPredicate compoundPredicate = new CompoundPredicate();
    compoundPredicate.setBooleanOperator(CompoundPredicate.BooleanOperator.AND);
    compoundPredicate.getPredicates().add(firstSimple);
    compoundPredicate.getPredicates().add(secondSimple);
    compoundPredicate.getPredicates().add(simpleSet);

    // Placeholders expected inside the TEST_01_SOURCE template
    final BOOLEAN_OPERATOR expectedOperator = BOOLEAN_OPERATOR.byName(compoundPredicate.getBooleanOperator().value());
    final String booleanOperatorString = BOOLEAN_OPERATOR.class.getName() + "." + expectedOperator.name();
    final String valuesString = setValues.stream()
            .map(valueString -> "\"" + valueString + "\"")
            .collect(Collectors.joining(","));

    final List<Field<?>> fields = getFieldsFromDataDictionary(dictionary);
    final BlockStmt retrieved = KiePMMLCompoundPredicateFactory.getCompoundPredicateVariableDeclaration(variableName, compoundPredicate, fields);
    final String template = getFileContent(TEST_01_SOURCE);
    final Statement expected = JavaParserUtils.parseBlock(String.format(template, variableName, valuesString, booleanOperatorString));
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    final List<Class<?>> imports = Arrays.asList(KiePMMLCompoundPredicate.class, KiePMMLSimplePredicate.class, KiePMMLSimpleSetPredicate.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Use of org.dmg.pmml.DataField in project openscoring by openscoring.
In class ModelUtil, method encodeInputFields.
/**
 * Converts the model's input fields into their REST representation.
 * Each field is identified by its PMML name, labelled with the display name of the backing
 * {@link DataField}, and carries the input field's data/op types plus the encoded values of the DataField.
 * The backing field of every input field is cast to {@link DataField}.
 *
 * @param inputFields the model input fields, encoded in iteration order
 * @return a new mutable list with one encoded field per input field
 */
private static List<Field> encodeInputFields(List<InputField> inputFields) {
    List<Field> result = new ArrayList<>(inputFields.size());
    for (InputField inputField : inputFields) {
        FieldName fieldName = inputField.getName();
        DataField backingField = (DataField) inputField.getField();

        Field encoded = new Field(fieldName.getValue());
        encoded.setName(backingField.getDisplayName());
        encoded.setDataType(inputField.getDataType());
        encoded.setOpType(inputField.getOpType());
        encoded.setValues(encodeValues(backingField));

        result.add(encoded);
    }
    return result;
}
Use of org.dmg.pmml.DataField in project shifu by ShifuML.
In class PMMLAdapterCommonUtil, method getDataDicHeaders.
/**
 * Collects the names of all fields declared in the PMML data dictionary.
 *
 * @param pmml
 *            the pmml model whose data dictionary is read
 * @return the data field names, in the order they appear in the dictionary
 */
public static String[] getDataDicHeaders(final PMML pmml) {
    final List<DataField> dataFields = pmml.getDataDictionary().getDataFields();
    final String[] headers = new String[dataFields.size()];
    int index = 0;
    for (DataField dataField : dataFields) {
        headers[index++] = dataField.getName().getValue();
    }
    return headers;
}
Use of org.dmg.pmml.DataField in project drools by kiegroup.
In class DMNImportPMMLInfo, method from.
/**
 * Builds the import information for a PMML document imported by a DMN model.
 * <p>
 * Every {@link DataField} of the PMML data dictionary is registered as a simple DMN type in the model's
 * type registry. Its allowed values come from the field's valid {@code Value}s or, when it has none, from
 * its {@code Interval}s. Output field types are then registered for each PMML model, and each model is
 * converted to a {@link DMNPMMLModelInfo}.
 *
 * @param is the PMML document to unmarshal
 * @param cc the compiler configuration providing the FEEL profiles and the root class loader
 * @param model the DMN model whose type registry receives the new types
 * @param i the DMN import element referencing the PMML document
 * @return the import info on the right; on the left, an {@link Exception} wrapping anything thrown
 */
public static Either<Exception, DMNImportPMMLInfo> from(InputStream is, DMNCompilerConfigurationImpl cc, DMNModelImpl model, Import i) {
    try {
        final PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        PMMLHeaderInfo h = PMMLInfo.pmmlToHeaderInfo(pmml, pmml.getHeader());
        // the FEEL helper depends only on the compiler configuration: build it once, not once per DataField
        List<FEELProfile> helperFEELProfiles = cc.getFeelProfiles();
        DMNFEELHelper feel = new DMNFEELHelper(cc.getRootClassLoader(), helperFEELProfiles);
        for (DataField df : pmml.getDataDictionary().getDataFields()) {
            String dfName = df.getName().getValue();
            BuiltInType ft = getBuiltInTypeByDataType(df.getDataType());
            List<UnaryTest> av = pmmlDataFieldAllowedValues(df, ft, feel);
            DMNType type = new SimpleTypeImpl(i.getNamespace(), dfName, null, false, av, model.getTypeRegistry().resolveType(model.getDefinitions().getURIFEEL(), ft.getName()), ft);
            model.getTypeRegistry().registerType(type);
        }
        pmml.getModels().stream().forEach(m -> registerOutputFieldType(m, model, i));
        List<DMNPMMLModelInfo> models = pmml.getModels().stream().map(m -> PMMLInfo.pmmlToModelInfo(m)).map(proto -> DMNPMMLModelInfo.from(proto, model, i)).collect(Collectors.toList());
        DMNImportPMMLInfo info = new DMNImportPMMLInfo(i, models, h);
        return Either.ofRight(info);
    } catch (Throwable e) {
        // boundary method: every failure is reported to the caller as a Left instead of propagating
        return Either.ofLeft(new Exception("Unable to process DMNImportPMMLInfo", e));
    }
}

/**
 * Translates the domain of a PMML DataField into FEEL unary tests usable as DMN allowed values.
 * An empty list means "no constraint".
 *
 * @param df the PMML data field
 * @param ft the FEEL type mapped from the data field's dataType
 * @param feel the helper used to parse the unary tests
 * @return the allowed-value tests; empty when the type is unknown or the domain is unconstrained
 */
private static List<UnaryTest> pmmlDataFieldAllowedValues(DataField df, BuiltInType ft, DMNFEELHelper feel) {
    List<UnaryTest> av = new ArrayList<>();
    if (ft == BuiltInType.UNKNOWN) {
        return av;
    }
    // PMML Values flagged "invalid" or "missing" are not part of the valid domain (e.g. a -999 missing marker)
    List<Value> validValues = df.getValues() == null
            ? Collections.<Value>emptyList()
            : df.getValues().stream()
                    .filter(v -> v.getProperty() == null || v.getProperty() == Value.Property.VALID)
                    .collect(Collectors.toList());
    if (!validValues.isEmpty()) {
        String lov = validValues.stream().map(Value::getValue).map(o -> ft == BuiltInType.STRING ? "\"" + o.toString() + "\"" : o.toString()).collect(Collectors.joining(","));
        return feel.evaluateUnaryTests(lov, Collections.emptyMap());
    }
    if (df.getIntervals() != null) {
        for (Interval interval : df.getIntervals()) {
            String utString = pmmlIntervalToUnaryTest(interval);
            if (utString == null) {
                // an interval unbounded on both sides admits every value: the field has no constraint
                return new ArrayList<>();
            }
            av.addAll(feel.evaluateUnaryTests(utString, Collections.emptyMap()));
        }
    }
    return av;
}

/**
 * Renders a PMML Interval as a FEEL unary test. Missing margins mean an open-ended interval, which is
 * rendered as a comparison test ({@code <}, {@code <=}, {@code >}, {@code >=}).
 *
 * @param interval the PMML interval
 * @return the FEEL unary test text, or {@code null} when both margins are missing (unbounded interval)
 * @throws IllegalArgumentException if the interval has no closure
 */
private static String pmmlIntervalToUnaryTest(Interval interval) {
    Interval.Closure closure = interval.getClosure();
    if (closure == null) {
        throw new IllegalArgumentException("PMML Interval without closure");
    }
    boolean leftClosed = closure == Interval.Closure.CLOSED_CLOSED || closure == Interval.Closure.CLOSED_OPEN;
    boolean rightClosed = closure == Interval.Closure.CLOSED_CLOSED || closure == Interval.Closure.OPEN_CLOSED;
    Number left = interval.getLeftMargin();
    Number right = interval.getRightMargin();
    if (left == null && right == null) {
        return null;
    }
    if (left == null) {
        return (rightClosed ? "<= " : "< ") + right;
    }
    if (right == null) {
        return (leftClosed ? ">= " : "> ") + left;
    }
    return (leftClosed ? "[" : "(") + left + ".." + right + (rightClosed ? "]" : ")");
}
Aggregations