Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class KiePMMLMiningFieldFactoryTest, method getMiningFieldVariableDeclarationNoAllowedValuesNoIntervals.
/**
 * Checks the variable declaration generated for a TARGET mining field whose backing
 * data field has had both its values and its intervals cleared.
 * <p>
 * The generated block is compared against the TEST_01_SOURCE template, filled with the
 * variable name, the field name and the fully qualified DATA_TYPE constant, and is then
 * handed to commonValidateCompilationWithImports together with a list of imports.
 *
 * @throws IOException propagated from the helpers used here (presumably getFileContent - TODO confirm)
 */
@Test
public void getMiningFieldVariableDeclarationNoAllowedValuesNoIntervals() throws IOException {
    // Data field stripped of allowed values and intervals
    DataField sourceField = getRandomDataField();
    sourceField.getValues().clear();
    sourceField.getIntervals().clear();

    // Mining field pointing at that data field, flagged as target
    MiningField targetField = new MiningField();
    targetField.setName(sourceField.getName());
    targetField.setUsageType(MiningField.UsageType.TARGET);

    BlockStmt actualBlock = KiePMMLMiningFieldFactory.getMiningFieldVariableDeclaration(
            VARIABLE_NAME, targetField, Collections.singletonList(sourceField));

    // Build "<DATA_TYPE class name>.<CONSTANT>" for the template
    DATA_TYPE resolvedType = DATA_TYPE.byName(sourceField.getDataType().value());
    String qualifiedDataType = String.join(".", DATA_TYPE.class.getName(), resolvedType.name());

    String template = getFileContent(TEST_01_SOURCE);
    String expectedSource = String.format(template,
                                          VARIABLE_NAME,
                                          targetField.getName().getValue(),
                                          qualifiedDataType);
    Statement expectedBlock = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expectedBlock, actualBlock));

    List<Class<?>> requiredImports = Arrays.asList(Arrays.class,
                                                   Collections.class,
                                                   KiePMMLInterval.class,
                                                   KiePMMLMiningField.class,
                                                   DATA_TYPE.class);
    commonValidateCompilationWithImports(actualBlock, requiredImports);
}
Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class KiePMMLMiningFieldFactoryTest, method getMiningFieldVariableDeclarationWithAllowedValuesNoIntervals.
/**
 * Checks the variable declaration generated for a TARGET mining field whose backing
 * data field keeps its values but has had its intervals cleared.
 * <p>
 * The TEST_02_SOURCE template is filled with the variable name, the field name, the fully
 * qualified DATA_TYPE constant and the first three values of the data field; the parsed
 * result must equal the generated block, which is then handed to
 * commonValidateCompilationWithImports together with a list of imports.
 *
 * @throws IOException propagated from the helpers used here (presumably getFileContent - TODO confirm)
 */
@Test
public void getMiningFieldVariableDeclarationWithAllowedValuesNoIntervals() throws IOException {
    // Data field with values only
    DataField sourceField = getRandomDataField();
    sourceField.getIntervals().clear();

    MiningField targetField = new MiningField();
    targetField.setName(sourceField.getName());
    targetField.setUsageType(MiningField.UsageType.TARGET);

    BlockStmt actualBlock = KiePMMLMiningFieldFactory.getMiningFieldVariableDeclaration(
            VARIABLE_NAME, targetField, Collections.singletonList(sourceField));

    DATA_TYPE resolvedType = DATA_TYPE.byName(sourceField.getDataType().value());
    String qualifiedDataType = String.join(".", DATA_TYPE.class.getName(), resolvedType.name());
    String template = getFileContent(TEST_02_SOURCE);

    // Template arguments: 3 fixed entries followed by the first 3 values, in order
    Object[] templateArgs = new Object[6];
    templateArgs[0] = VARIABLE_NAME;
    templateArgs[1] = targetField.getName().getValue();
    templateArgs[2] = qualifiedDataType;
    for (int i = 0; i < 3; i++) {
        templateArgs[3 + i] = sourceField.getValues().get(i).getValue();
    }

    Statement expectedBlock = JavaParserUtils.parseBlock(String.format(template, templateArgs));
    assertTrue(JavaParserUtils.equalsNode(expectedBlock, actualBlock));

    List<Class<?>> requiredImports = Arrays.asList(Arrays.class,
                                                   Collections.class,
                                                   KiePMMLInterval.class,
                                                   KiePMMLMiningField.class,
                                                   DATA_TYPE.class);
    commonValidateCompilationWithImports(actualBlock, requiredImports);
}
Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class KiePMMLMiningFieldFactoryTest, method getMiningFieldVariableDeclarationWithAllowedValuesAndIntervals.
/**
 * Checks the variable declaration generated for a TARGET mining field whose backing
 * data field keeps both its values and its intervals.
 * <p>
 * The TEST_03_SOURCE template is filled with the variable name, the field name, the fully
 * qualified DATA_TYPE constant, the first three values and, for each of the first three
 * intervals, its left margin, right margin and closure name. The parsed result must equal
 * the generated block, which is then handed to commonValidateCompilationWithImports
 * together with a list of imports.
 *
 * @throws IOException propagated from the helpers used here (presumably getFileContent - TODO confirm)
 */
@Test
public void getMiningFieldVariableDeclarationWithAllowedValuesAndIntervals() throws IOException {
    DataField sourceField = getRandomDataField();

    MiningField targetField = new MiningField();
    targetField.setName(sourceField.getName());
    targetField.setUsageType(MiningField.UsageType.TARGET);

    BlockStmt actualBlock = KiePMMLMiningFieldFactory.getMiningFieldVariableDeclaration(
            VARIABLE_NAME, targetField, Collections.singletonList(sourceField));

    DATA_TYPE resolvedType = DATA_TYPE.byName(sourceField.getDataType().value());
    String qualifiedDataType = String.join(".", DATA_TYPE.class.getName(), resolvedType.name());
    String template = getFileContent(TEST_03_SOURCE);

    // Layout: [0..2] fixed entries, [3..5] values, [6..14] (left, right, closure) per interval
    Object[] templateArgs = new Object[15];
    templateArgs[0] = VARIABLE_NAME;
    templateArgs[1] = targetField.getName().getValue();
    templateArgs[2] = qualifiedDataType;
    for (int i = 0; i < 3; i++) {
        templateArgs[3 + i] = sourceField.getValues().get(i).getValue();
    }
    for (int i = 0; i < 3; i++) {
        int offset = 6 + 3 * i;
        templateArgs[offset] = sourceField.getIntervals().get(i).getLeftMargin();
        templateArgs[offset + 1] = sourceField.getIntervals().get(i).getRightMargin();
        templateArgs[offset + 2] = sourceField.getIntervals().get(i).getClosure().name();
    }

    Statement expectedBlock = JavaParserUtils.parseBlock(String.format(template, templateArgs));
    assertTrue(JavaParserUtils.equalsNode(expectedBlock, actualBlock));

    List<Class<?>> requiredImports = Arrays.asList(Arrays.class,
                                                   Collections.class,
                                                   KiePMMLInterval.class,
                                                   KiePMMLMiningField.class,
                                                   DATA_TYPE.class);
    commonValidateCompilationWithImports(actualBlock, requiredImports);
}
Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class ModelUtilsTest, method getOpTypeByTargetsNotFound.
/**
 * Expects ModelUtils.getOpType to fail with a KiePMMLInternalException when asked for
 * the name "NOT_EXISTING".
 * <p>
 * The model is a RegressionModel populated with three random data/mining/target field
 * triples named "field0" to "field2", so the requested name matches none of them.
 */
@Test(expected = KiePMMLInternalException.class)
public void getOpTypeByTargetsNotFound() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();

    for (int index = 0; index < 3; index++) {
        final String fieldName = "field" + index;

        // Data field registered in the dictionary under the generated name
        final DataField dataField = getRandomDataField();
        dataField.setName(FieldName.create(fieldName));
        dataDictionary.addDataFields(dataField);

        // Mining field and target both refer to the same field name
        final MiningField miningField = getRandomMiningField();
        miningField.setName(dataField.getName());
        miningSchema.addMiningFields(miningField);

        final Target target = getRandomTarget();
        target.setField(dataField.getName());
        targets.addTargets(target);
    }

    model.setMiningSchema(miningSchema);
    model.setTargets(targets);

    // Must throw: no field is called "NOT_EXISTING"
    ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, "NOT_EXISTING");
}
Use of org.dmg.pmml.MiningField in the drools project by kiegroup.
Class ModelUtilsTest, method getTargetFieldsWithTargetFieldsWithoutOptType.
/**
 * Checks ModelUtils.getTargetFields when the PREDICTED mining fields carry no op type.
 * <p>
 * Three CATEGORICAL/STRING data fields are registered together with mining fields of the
 * same name whose op type is set to null. Every returned entry must match a mining field
 * by name and must report the op type of the data field with that name.
 */
@Test
public void getTargetFieldsWithTargetFieldsWithoutOptType() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();

    for (int index = 0; index < 3; index++) {
        final String fieldName = "fieldName-" + index;
        final DataField dataField = getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING);
        dataDictionary.addDataFields(dataField);
        final MiningField miningField = getMiningField(fieldName, MiningField.UsageType.PREDICTED);
        // Op type deliberately removed from the mining field
        miningField.setOpType(null);
        miningSchema.addMiningFields(miningField);
    }
    model.setMiningSchema(miningSchema);

    List<KiePMMLNameOpType> retrieved =
            ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);

    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());

    for (KiePMMLNameOpType nameOpType : retrieved) {
        // The entry must correspond to one of the mining fields
        boolean declaredInSchema = miningSchema.getMiningFields().stream()
                .anyMatch(candidate -> nameOpType.getName().equals(candidate.getName().getValue()));
        assertTrue(declaredInSchema);

        // The op type is compared with the one of the same-named data field
        Optional<DataField> matchingDataField = dataDictionary.getDataFields().stream()
                .filter(candidate -> nameOpType.getName().equals(candidate.getName().getValue()))
                .findFirst();
        assertTrue(matchingDataField.isPresent());
        OP_TYPE expectedOpType = OP_TYPE.byName(matchingDataField.get().getOpType().value());
        assertEquals(expectedOpType, nameOpType.getOpType());
    }
}
Aggregations