Usage of org.kie.pmml.api.enums.DATA_TYPE in the drools project by kiegroup.
In class ModelUtilsTest, method getTargetFieldsTypeMapWithTargetFieldsWithoutTargets:
/**
 * Verifies {@link ModelUtils#getTargetFieldsTypeMap} when every mining field is PREDICTED:
 * the result must be a non-null LinkedHashMap with one entry per mining field, and the
 * entries, iterated in order, must carry the DATA_TYPE of the matching data-dictionary field
 * in mining-schema order.
 */
@Test
public void getTargetFieldsTypeMapWithTargetFieldsWithoutTargets() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    // Register three random data fields, each exposed as a PREDICTED mining field.
    for (int count = 0; count < 3; count++) {
        final DataField randomField = getRandomDataField();
        dataDictionary.addDataFields(randomField);
        miningSchema.addMiningFields(getMiningField(randomField.getName().getValue(), MiningField.UsageType.PREDICTED));
    }
    model.setMiningSchema(miningSchema);

    final Map<String, DATA_TYPE> retrieved = ModelUtils.getTargetFieldsTypeMap(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    assertTrue(retrieved instanceof LinkedHashMap);

    // Walk the result in lock-step with the mining schema.
    final Iterator<Map.Entry<String, DATA_TYPE>> entries = retrieved.entrySet().iterator();
    for (MiningField target : miningSchema.getMiningFields()) {
        final DataField sourceField = dataDictionary.getDataFields()
                .stream()
                .filter(candidate -> candidate.getName().equals(target.getName()))
                .findFirst()
                .get();
        final DATA_TYPE expectedType = DATA_TYPE.byName(sourceField.getDataType().value());
        assertEquals(expectedType, entries.next().getValue());
    }
}
Usage of org.kie.pmml.api.enums.DATA_TYPE in the drools project by kiegroup.
In class KiePMMLMiningFieldFactory, method getMiningFieldVariableDeclaration:
/**
 * Builds the statements that declare a mining-field variable populated from the given PMML
 * {@code MiningField}.
 * <p>
 * A clone of the {@code GETKIEPMMLMININGFIELD} template method is taken. Its {@code MININGFIELD}
 * variable is renamed to {@code variableName}, and the arguments of the chained calls in the
 * variable initializer are filled in. The chained calls are {@code builder}, {@code withFieldUsageType},
 * {@code withOpType}, {@code withDataType}, {@code withMissingValueTreatmentMethod},
 * {@code withInvalidValueTreatmentMethod}, {@code withMissingValueReplacement},
 * {@code withInvalidValueReplacement}, {@code withAllowedValues} and {@code withIntervals}.
 *
 * @param variableName name given to the declared variable
 * @param miningField  source mining field; each absent optional attribute (usage type, op type,
 *                     treatment methods, replacements) is rendered as a {@code null} literal
 * @param fields       fields searched by the mining field's name; the matches supply the data type
 *                     and, when a data field is found, the allowed values and intervals
 * @return a new {@link BlockStmt} holding the populated template statements
 * @throws KiePMMLException if the template method has no body, or its body lacks the expected
 *                          variable or that variable's initializer
 */
static BlockStmt getMiningFieldVariableDeclaration(final String variableName, final MiningField miningField, final List<Field<?>> fields) {
    final MethodDeclaration methodDeclaration = MININGFIELD_TEMPLATE.getMethodsByName(GETKIEPMMLMININGFIELD).get(0).clone();
    final BlockStmt miningFieldBody = methodDeclaration.getBody().orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, methodDeclaration)));
    final VariableDeclarator variableDeclarator = getVariableDeclarator(miningFieldBody, MININGFIELD).orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, MININGFIELD, miningFieldBody)));
    variableDeclarator.setName(variableName);
    // Report the template body that lacks the initializer (previously an empty, not-yet-filled block was reported).
    final MethodCallExpr initializer = variableDeclarator.getInitializer().orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, MININGFIELD, miningFieldBody))).asMethodCallExpr();
    final MethodCallExpr builder = getChainedMethodCallExprFrom("builder", initializer);
    final String fieldName = miningField.getName().getValue();
    final StringLiteralExpr nameExpr = new StringLiteralExpr(fieldName);

    // Optional enum-valued attributes become fully-qualified enum constant references, or null literals.
    Expression fieldUsageTypeExpr;
    if (miningField.getUsageType() != null) {
        final FIELD_USAGE_TYPE fieldUsageType = FIELD_USAGE_TYPE.byName(miningField.getUsageType().value());
        fieldUsageTypeExpr = new NameExpr(FIELD_USAGE_TYPE.class.getName() + "." + fieldUsageType.name());
    } else {
        fieldUsageTypeExpr = new NullLiteralExpr();
    }
    Expression opTypeExpr;
    if (miningField.getOpType() != null) {
        final OP_TYPE opType = OP_TYPE.byName(miningField.getOpType().value());
        opTypeExpr = new NameExpr(OP_TYPE.class.getName() + "." + opType.name());
    } else {
        opTypeExpr = new NullLiteralExpr();
    }

    // The data type is not read from the mining field itself but from the field(s) mapped by name.
    final List<Field<?>> mappedFields = getMappedFields(fields, fieldName);
    final DataType dataType = getDataType(mappedFields, fieldName);
    final DATA_TYPE kieDataType = DATA_TYPE.byName(dataType.value());
    final Expression dataTypeExpr = new NameExpr(DATA_TYPE.class.getName() + "." + kieDataType.name());

    Expression missingValueTreatmentMethodExpr;
    if (miningField.getMissingValueTreatment() != null) {
        final MISSING_VALUE_TREATMENT_METHOD missingValueTreatmentMethod = MISSING_VALUE_TREATMENT_METHOD.byName(miningField.getMissingValueTreatment().value());
        missingValueTreatmentMethodExpr = new NameExpr(MISSING_VALUE_TREATMENT_METHOD.class.getName() + "." + missingValueTreatmentMethod.name());
    } else {
        missingValueTreatmentMethodExpr = new NullLiteralExpr();
    }
    Expression invalidValueTreatmentMethodExpr;
    if (miningField.getInvalidValueTreatment() != null) {
        final INVALID_VALUE_TREATMENT_METHOD invalidValueTreatmentMethod = INVALID_VALUE_TREATMENT_METHOD.byName(miningField.getInvalidValueTreatment().value());
        invalidValueTreatmentMethodExpr = new NameExpr(INVALID_VALUE_TREATMENT_METHOD.class.getName() + "." + invalidValueTreatmentMethod.name());
    } else {
        invalidValueTreatmentMethodExpr = new NullLiteralExpr();
    }

    // Replacement values are emitted as string literals of their toString() form.
    Expression missingValueReplacementExpr;
    if (miningField.getMissingValueReplacement() != null) {
        missingValueReplacementExpr = new StringLiteralExpr(miningField.getMissingValueReplacement().toString());
    } else {
        missingValueReplacementExpr = new NullLiteralExpr();
    }
    Expression invalidValueReplacementExpr;
    if (miningField.getInvalidValueReplacement() != null) {
        invalidValueReplacementExpr = new StringLiteralExpr(miningField.getInvalidValueReplacement().toString());
    } else {
        invalidValueReplacementExpr = new NullLiteralExpr();
    }

    // Allowed values and intervals come from the mapped DataField, if any; otherwise they are left empty.
    final DataField dataField = getMappedDataField(mappedFields);
    final NodeList<Expression> allowedValuesExpressions = dataField != null ? getAllowedValuesExpressions(dataField) : new NodeList<>();
    final NodeList<Expression> intervalsExpressions = dataField != null ? getIntervalsExpressions(dataField) : new NodeList<>();

    builder.setArgument(0, nameExpr);
    getChainedMethodCallExprFrom("withFieldUsageType", initializer).setArgument(0, fieldUsageTypeExpr);
    getChainedMethodCallExprFrom("withOpType", initializer).setArgument(0, opTypeExpr);
    getChainedMethodCallExprFrom("withDataType", initializer).setArgument(0, dataTypeExpr);
    getChainedMethodCallExprFrom("withMissingValueTreatmentMethod", initializer).setArgument(0, missingValueTreatmentMethodExpr);
    getChainedMethodCallExprFrom("withInvalidValueTreatmentMethod", initializer).setArgument(0, invalidValueTreatmentMethodExpr);
    getChainedMethodCallExprFrom("withMissingValueReplacement", initializer).setArgument(0, missingValueReplacementExpr);
    getChainedMethodCallExprFrom("withInvalidValueReplacement", initializer).setArgument(0, invalidValueReplacementExpr);
    // For these two calls the template's first argument is itself a method call whose arguments are replaced.
    getChainedMethodCallExprFrom("withAllowedValues", initializer).getArgument(0).asMethodCallExpr().setArguments(allowedValuesExpressions);
    getChainedMethodCallExprFrom("withIntervals", initializer).getArgument(0).asMethodCallExpr().setArguments(intervalsExpressions);

    final BlockStmt toReturn = new BlockStmt();
    miningFieldBody.getStatements().forEach(toReturn::addStatement);
    return toReturn;
}
Usage of org.kie.pmml.api.enums.DATA_TYPE in the drools project by kiegroup.
In class ModelUtilsTest, method getTargetFieldsTypeMapWithoutTargetFieldsWithoutTargets:
/**
 * Verifies that {@link ModelUtils#getTargetFieldsTypeMap} returns a non-null, empty map
 * when every mining field is ACTIVE (none is PREDICTED).
 */
@Test
public void getTargetFieldsTypeMapWithoutTargetFieldsWithoutTargets() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    // Register three random data fields, each exposed only as an ACTIVE mining field.
    for (int count = 0; count < 3; count++) {
        final DataField randomField = getRandomDataField();
        dataDictionary.addDataFields(randomField);
        miningSchema.addMiningFields(getMiningField(randomField.getName().getValue(), MiningField.UsageType.ACTIVE));
    }
    model.setMiningSchema(miningSchema);

    final Map<String, DATA_TYPE> retrieved = ModelUtils.getTargetFieldsTypeMap(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertTrue(retrieved.isEmpty());
}
Usage of org.kie.pmml.api.enums.DATA_TYPE in the drools project by kiegroup.
In class PMMLModelTestUtils, method getParameterFields:
/**
 * Creates one {@code ParameterField} per {@link DATA_TYPE} constant, in declaration order.
 * <p>
 * Each constant's name is mapped to the corresponding PMML {@code DataType}. The parameter field
 * is named after that DataType's value upper-cased with {@link java.util.Locale#ROOT}, so the
 * names do not depend on the JVM default locale (e.g. "integer" stays "INTEGER" under a Turkish
 * locale). The field also carries that DataType.
 *
 * @return a new mutable list with one ParameterField per DATA_TYPE value
 */
public static List<ParameterField> getParameterFields() {
    final List<ParameterField> toReturn = new ArrayList<>();
    for (DATA_TYPE kieDataType : DATA_TYPE.values()) {
        final DataType dataType = DataType.fromValue(kieDataType.getName());
        final String fieldName = dataType.value().toUpperCase(java.util.Locale.ROOT);
        toReturn.add(getParameterField(fieldName, dataType));
    }
    return toReturn;
}
Usage of org.kie.pmml.api.enums.DATA_TYPE in the drools project by kiegroup.
In class KiePMMLDefineFunctionInstanceFactory, method getKiePMMLDefineFunction:
/**
 * Converts a PMML {@code DefineFunction} into a {@link KiePMMLDefineFunction}.
 * <p>
 * Parameter fields, extensions and the function expression are converted by their dedicated
 * helpers. A missing data type or op type on the source is passed on as {@code null}.
 *
 * @param defineFunction the PMML function definition to convert
 * @return the equivalent KiePMMLDefineFunction
 */
static KiePMMLDefineFunction getKiePMMLDefineFunction(final DefineFunction defineFunction) {
    final List<KiePMMLParameterField> kiePMMLParameterFields = getKiePMMLParameterFields(defineFunction.getParameterFields());
    DATA_TYPE kieDataType = null;
    if (defineFunction.getDataType() != null) {
        kieDataType = DATA_TYPE.byName(defineFunction.getDataType().value());
    }
    OP_TYPE kieOpType = null;
    if (defineFunction.getOpType() != null) {
        kieOpType = OP_TYPE.byName(defineFunction.getOpType().value());
    }
    return new KiePMMLDefineFunction(defineFunction.getName(),
                                     getKiePMMLExtensions(defineFunction.getExtensions()),
                                     kieDataType,
                                     kieOpType,
                                     kiePMMLParameterFields,
                                     getKiePMMLExpression(defineFunction.getExpression()));
}
Aggregations