Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class ModelUtilsTest, method getTargetFieldName:
/**
 * Verifies that no target field name is reported while the schema only holds an ACTIVE mining
 * field, and that the field name is reported once the schema holds a PREDICTED mining field.
 */
@Test
public void getTargetFieldName() {
    final String fieldName = "fieldName";

    // Schema with a single ACTIVE mining field: no target name expected
    final MiningField activeMiningField = getMiningField(fieldName, MiningField.UsageType.ACTIVE);
    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));
    final MiningSchema activeSchema = new MiningSchema();
    activeSchema.addMiningFields(activeMiningField);
    final Model model = new RegressionModel();
    model.setMiningSchema(activeSchema);
    final List<Field<?>> fields = getFieldsFromDataDictionary(dataDictionary);
    assertFalse(ModelUtils.getTargetFieldName(fields, model).isPresent());

    // Same field name, now PREDICTED: it must be reported as the target
    final MiningSchema predictedSchema = new MiningSchema();
    predictedSchema.addMiningFields(getMiningField(fieldName, MiningField.UsageType.PREDICTED));
    model.setMiningSchema(predictedSchema);
    final Optional<String> targetFieldName = ModelUtils.getTargetFieldName(fields, model);
    assertTrue(targetFieldName.isPresent());
    assertEquals(fieldName, targetFieldName.get());
}
Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class KiePMMLMiningFieldFactory, method getMiningFieldVariableDeclaration:
/**
 * Builds the statements that declare and initialise a mining-field variable described by the
 * given PMML {@link MiningField}.
 * <p>
 * The statements are cloned from the {@code GETKIEPMMLMININGFIELD} method of
 * {@code MININGFIELD_TEMPLATE}. The template variable {@code MININGFIELD} is renamed to
 * {@code variableName}. The arguments of the chained calls in its initializer are replaced with
 * expressions derived from {@code miningField} and from the fields mapped to its name. The chained
 * calls are {@code builder}, {@code withFieldUsageType}, {@code withOpType}, {@code withDataType},
 * {@code withMissingValueTreatmentMethod}, {@code withInvalidValueTreatmentMethod},
 * {@code withMissingValueReplacement}, {@code withInvalidValueReplacement},
 * {@code withAllowedValues} and {@code withIntervals}. Optional mining-field attributes that are
 * {@code null} are emitted as {@code null} literals.
 *
 * @param variableName name of the variable declared in the generated code
 * @param miningField  the PMML mining field to translate
 * @param fields       available fields, used to resolve the data type of {@code miningField} and,
 *                     when a {@code DataField} is mapped to it, its allowed values and intervals
 * @return a new {@link BlockStmt} holding the populated template statements
 * @throws KiePMMLException if the template method has no body, the {@code MININGFIELD} variable
 *                          is missing from it, or that variable has no initializer
 */
static BlockStmt getMiningFieldVariableDeclaration(final String variableName, final MiningField miningField, final List<Field<?>> fields) {
    final MethodDeclaration methodDeclaration = MININGFIELD_TEMPLATE.getMethodsByName(GETKIEPMMLMININGFIELD).get(0).clone();
    final BlockStmt miningFieldBody = methodDeclaration.getBody()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, methodDeclaration)));
    final VariableDeclarator variableDeclarator = getVariableDeclarator(miningFieldBody, MININGFIELD)
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, MININGFIELD, miningFieldBody)));
    variableDeclarator.setName(variableName);
    // Report the template body, as the MISSING_VARIABLE_IN_BODY check above does. The result block
    // is still empty at this point, so formatting it would make the message useless for diagnosis.
    final MethodCallExpr initializer = variableDeclarator.getInitializer()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, MININGFIELD, miningFieldBody)))
            .asMethodCallExpr();
    final MethodCallExpr builder = getChainedMethodCallExprFrom("builder", initializer);
    final String fieldName = miningField.getName().getValue();
    final StringLiteralExpr nameExpr = new StringLiteralExpr(fieldName);

    // Each optional PMML enum attribute is mapped to the KiE enum constant found by byName(value()),
    // and referenced in the generated code as <EnumClass.getName()>.<constant name>.
    final Expression fieldUsageTypeExpr;
    if (miningField.getUsageType() != null) {
        final FIELD_USAGE_TYPE fieldUsageType = FIELD_USAGE_TYPE.byName(miningField.getUsageType().value());
        fieldUsageTypeExpr = new NameExpr(FIELD_USAGE_TYPE.class.getName() + "." + fieldUsageType.name());
    } else {
        fieldUsageTypeExpr = new NullLiteralExpr();
    }
    final Expression opTypeExpr;
    if (miningField.getOpType() != null) {
        final OP_TYPE opType = OP_TYPE.byName(miningField.getOpType().value());
        opTypeExpr = new NameExpr(OP_TYPE.class.getName() + "." + opType.name());
    } else {
        opTypeExpr = new NullLiteralExpr();
    }

    // The data type is not read from the mining field itself but resolved from the mapped fields
    final List<Field<?>> mappedFields = getMappedFields(fields, fieldName);
    final DataType dataType = getDataType(mappedFields, fieldName);
    final DATA_TYPE kiePMMLDataType = DATA_TYPE.byName(dataType.value());
    final Expression dataTypeExpr = new NameExpr(DATA_TYPE.class.getName() + "." + kiePMMLDataType.name());

    final Expression missingValueTreatmentMethodExpr;
    if (miningField.getMissingValueTreatment() != null) {
        final MISSING_VALUE_TREATMENT_METHOD missingValueTreatmentMethod = MISSING_VALUE_TREATMENT_METHOD.byName(miningField.getMissingValueTreatment().value());
        missingValueTreatmentMethodExpr = new NameExpr(MISSING_VALUE_TREATMENT_METHOD.class.getName() + "." + missingValueTreatmentMethod.name());
    } else {
        missingValueTreatmentMethodExpr = new NullLiteralExpr();
    }
    final Expression invalidValueTreatmentMethodExpr;
    if (miningField.getInvalidValueTreatment() != null) {
        final INVALID_VALUE_TREATMENT_METHOD invalidValueTreatmentMethod = INVALID_VALUE_TREATMENT_METHOD.byName(miningField.getInvalidValueTreatment().value());
        invalidValueTreatmentMethodExpr = new NameExpr(INVALID_VALUE_TREATMENT_METHOD.class.getName() + "." + invalidValueTreatmentMethod.name());
    } else {
        invalidValueTreatmentMethodExpr = new NullLiteralExpr();
    }

    // Replacement values are emitted as string literals of their toString() form
    final Expression missingValueReplacementExpr;
    if (miningField.getMissingValueReplacement() != null) {
        missingValueReplacementExpr = new StringLiteralExpr(miningField.getMissingValueReplacement().toString());
    } else {
        missingValueReplacementExpr = new NullLiteralExpr();
    }
    final Expression invalidValueReplacementExpr;
    if (miningField.getInvalidValueReplacement() != null) {
        invalidValueReplacementExpr = new StringLiteralExpr(miningField.getInvalidValueReplacement().toString());
    } else {
        invalidValueReplacementExpr = new NullLiteralExpr();
    }

    // Without a mapped DataField, allowed values and intervals are left as empty argument lists
    final DataField dataField = getMappedDataField(mappedFields);
    final NodeList<Expression> allowedValuesExpressions = dataField != null ? getAllowedValuesExpressions(dataField) : new NodeList<>();
    final NodeList<Expression> intervalsExpressions = dataField != null ? getIntervalsExpressions(dataField) : new NodeList<>();

    builder.setArgument(0, nameExpr);
    getChainedMethodCallExprFrom("withFieldUsageType", initializer).setArgument(0, fieldUsageTypeExpr);
    getChainedMethodCallExprFrom("withOpType", initializer).setArgument(0, opTypeExpr);
    getChainedMethodCallExprFrom("withDataType", initializer).setArgument(0, dataTypeExpr);
    getChainedMethodCallExprFrom("withMissingValueTreatmentMethod", initializer).setArgument(0, missingValueTreatmentMethodExpr);
    getChainedMethodCallExprFrom("withInvalidValueTreatmentMethod", initializer).setArgument(0, invalidValueTreatmentMethodExpr);
    getChainedMethodCallExprFrom("withMissingValueReplacement", initializer).setArgument(0, missingValueReplacementExpr);
    getChainedMethodCallExprFrom("withInvalidValueReplacement", initializer).setArgument(0, invalidValueReplacementExpr);
    // The first argument of these two calls is itself a method call whose arguments are replaced
    getChainedMethodCallExprFrom("withAllowedValues", initializer).getArgument(0).asMethodCallExpr().setArguments(allowedValuesExpressions);
    getChainedMethodCallExprFrom("withIntervals", initializer).getArgument(0).asMethodCallExpr().setArguments(intervalsExpressions);

    final BlockStmt toReturn = new BlockStmt();
    miningFieldBody.getStatements().forEach(toReturn::addStatement);
    return toReturn;
}
Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class PMMLModelTestUtils, method getMiningField:
/**
 * Creates a random mining field (via {@code getRandomMiningField()}) and overrides its name and
 * usage type with the given values.
 *
 * @param fieldName name to assign to the mining field
 * @param usageType usage type to assign to the mining field
 * @return the configured mining field
 */
public static MiningField getMiningField(String fieldName, MiningField.UsageType usageType) {
    final MiningField randomMiningField = getRandomMiningField();
    randomMiningField.setName(FieldName.create(fieldName));
    randomMiningField.setUsageType(usageType);
    return randomMiningField;
}
Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class ModelUtilsTest, method getOpTypeByMiningFieldsNotFound:
/**
 * Verifies that looking up the op type of a field name absent from both the data dictionary and
 * the mining schema raises {@link KiePMMLInternalException}.
 */
@Test(expected = KiePMMLInternalException.class)
public void getOpTypeByMiningFieldsNotFound() {
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    // Register field0..field2, each as a data field and as a mining field with the same name
    for (int index = 0; index < 3; index++) {
        final DataField dataField = getRandomDataField();
        dataField.setName(FieldName.create("field" + index));
        dataDictionary.addDataFields(dataField);
        final MiningField miningField = getRandomMiningField();
        miningField.setName(dataField.getName());
        miningSchema.addMiningFields(miningField);
    }
    final Model model = new RegressionModel();
    model.setMiningSchema(miningSchema);
    ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, "NOT_EXISTING");
}
Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class KiePMMLUtilTest, method populateMissingOutputFieldDataType:
/**
 * Verifies that {@code KiePMMLUtil.populateMissingOutputFieldDataType} leaves every output field
 * with a non-null data type. The inputs are output fields without a data type (resolved via
 * PROBABILITY, PREDICTED_VALUE or a target-field reference) plus output fields that already
 * declare one.
 */
@Test
public void populateMissingOutputFieldDataType() {
    Random random = new Random();
    List<String> fieldNames = IntStream.range(0, 6).mapToObj(i -> RandomStringUtils.random(6, true, false)).collect(Collectors.toList());
    List<DataField> dataFields = fieldNames.stream().map(fieldName -> {
        DataField toReturn = new DataField();
        toReturn.setName(FieldName.create(fieldName));
        DataType dataType = DataType.values()[random.nextInt(DataType.values().length)];
        toReturn.setDataType(dataType);
        return toReturn;
    }).collect(Collectors.toList());
    // Collectors.toList() gives no mutability guarantee, but this list is appended to below,
    // so an explicitly mutable ArrayList is requested.
    List<MiningField> miningFields = IntStream.range(0, dataFields.size() - 1).mapToObj(dataFields::get).map(dataField -> {
        MiningField toReturn = new MiningField();
        toReturn.setName(FieldName.create(dataField.getName().getValue()));
        toReturn.setUsageType(MiningField.UsageType.ACTIVE);
        return toReturn;
    }).collect(Collectors.toCollection(java.util.ArrayList::new));
    DataField lastDataField = dataFields.get(dataFields.size() - 1);
    MiningField targetMiningField = new MiningField();
    targetMiningField.setName(FieldName.create(lastDataField.getName().getValue()));
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    miningFields.add(targetMiningField);
    // Following OutputFields should be populated based on "ResultFeature.PROBABILITY"
    // (mutable collection for the same reason as miningFields: more fields are added below)
    List<OutputField> outputFields = IntStream.range(0, 3).mapToObj(i -> {
        OutputField toReturn = new OutputField();
        toReturn.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
        toReturn.setResultFeature(ResultFeature.PROBABILITY);
        return toReturn;
    }).collect(Collectors.toCollection(java.util.ArrayList::new));
    // Following OutputField should be populated based on "ResultFeature.PREDICTED_VALUE"
    OutputField targetOutputField = new OutputField();
    targetOutputField.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
    targetOutputField.setResultFeature(ResultFeature.PREDICTED_VALUE);
    outputFields.add(targetOutputField);
    // Following OutputField should be populated based on "TargetField" property
    OutputField targetingOutputField = new OutputField();
    targetingOutputField.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
    targetingOutputField.setTargetField(FieldName.create(targetMiningField.getName().getValue()));
    outputFields.add(targetingOutputField);
    outputFields.forEach(outputField -> assertNull(outputField.getDataType()));
    // Output fields that already declare a data type
    IntStream.range(0, 2).forEach(i -> {
        OutputField toAdd = new OutputField();
        toAdd.setName(FieldName.create(RandomStringUtils.random(6, true, false)));
        DataType dataType = DataType.values()[random.nextInt(DataType.values().length)];
        toAdd.setDataType(dataType);
        outputFields.add(toAdd);
    });
    KiePMMLUtil.populateMissingOutputFieldDataType(outputFields, miningFields, dataFields);
    outputFields.forEach(outputField -> assertNotNull(outputField.getDataType()));
}
Aggregations