Usage of `org.dmg.pmml.MiningField` in the drools project (kiegroup):
class `KiePMMLUtilTest`, method `populateMissingPredictedOutputFieldTarget`.
/**
 * Verifies that {@link KiePMMLUtil#populateMissingPredictedOutputFieldTarget} fills in the
 * missing target field of a {@code PREDICTED_VALUE} output field with the name of the
 * model's first target mining field.
 *
 * @throws Exception if the PMML fixture cannot be read or unmarshalled
 */
@Test
public void populateMissingPredictedOutputFieldTarget() throws Exception {
    final PMML pmml;
    // try-with-resources guarantees the fixture stream is closed even if unmarshalling fails;
    // the previous version never closed it.
    try (InputStream inputStream = getFileInputStream(NO_OUTPUT_FIELD_TARGET_NAME_SAMPLE)) {
        pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
    }
    final Model toPopulate = pmml.getModels().get(0);
    final OutputField outputField = toPopulate.getOutput().getOutputFields().get(0);
    // Precondition: the fixture's first output field is a predicted value with no target.
    assertEquals(ResultFeature.PREDICTED_VALUE, outputField.getResultFeature());
    assertNull(outputField.getTargetField());

    KiePMMLUtil.populateMissingPredictedOutputFieldTarget(toPopulate);

    final MiningField targetField = getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields()).get(0);
    assertNotNull(outputField.getTargetField());
    assertEquals(targetField.getName(), outputField.getTargetField());
}
Usage of `org.dmg.pmml.MiningField` in the drools project (kiegroup):
class `KiePMMLUtilTest`, method `populateMissingTargetFieldInSegment`.
/**
 * Verifies that {@link KiePMMLUtil#populateMissingTargetFieldInSegment} copies the parent
 * {@link MiningModel}'s target mining fields into a segment model that initially declares none.
 *
 * @throws Exception if the PMML fixture cannot be read or unmarshalled
 */
@Test
public void populateMissingTargetFieldInSegment() throws Exception {
    final PMML pmml;
    // Close the fixture stream deterministically; the previous version leaked it.
    try (InputStream inputStream = getFileInputStream(NO_MODELNAME_NO_SEGMENT_ID_NOSEGMENT_TARGET_FIELD_SAMPLE)) {
        pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
    }
    final Model retrieved = pmml.getModels().get(0);
    assertTrue(retrieved instanceof MiningModel);
    final MiningModel miningModel = (MiningModel) retrieved;
    final Model toPopulate = miningModel.getSegmentation().getSegments().get(0).getModel();
    // Precondition: the first segment's model has no target field of its own.
    assertTrue(getMiningTargetFields(toPopulate.getMiningSchema()).isEmpty());

    // Use the MiningModel reference (same object as 'retrieved') consistently with the check below.
    KiePMMLUtil.populateMissingTargetFieldInSegment(miningModel.getMiningSchema(), toPopulate);

    final List<MiningField> childrenTargetFields = getMiningTargetFields(toPopulate.getMiningSchema());
    assertFalse(childrenTargetFields.isEmpty());
    // Every target field of the parent must now be present in the segment.
    getMiningTargetFields(miningModel.getMiningSchema())
            .forEach(parentTargetField -> assertTrue(childrenTargetFields.contains(parentTargetField)));
}
Usage of `org.dmg.pmml.MiningField` in the drools project (kiegroup):
class `KiePMMLUtilTest`, method `correctTargetFields`.
/**
 * Checks {@link KiePMMLUtil#correctTargetFields}: a {@link Target} that already names a field
 * keeps that name, while a {@link Target} with no field receives the mining field's name.
 */
@Test
public void correctTargetFields() {
    final String expectedTargetName = "TARGET_NAME";
    final MiningField sourceMiningField = new MiningField(FieldName.create("FIELD_NAME"));

    final Targets allTargets = new Targets();
    final Target alreadyNamed = new Target();
    alreadyNamed.setField(FieldName.create(expectedTargetName));
    final Target missingName = new Target();
    allTargets.addTargets(alreadyNamed, missingName);

    KiePMMLUtil.correctTargetFields(sourceMiningField, allTargets);

    assertEquals(expectedTargetName, alreadyNamed.getField().getValue());
    assertEquals(sourceMiningField.getName(), missingName.getField());
}
Usage of `org.dmg.pmml.MiningField` in the drools project (kiegroup):
class `KiePMMLUtilTest`, method `getTargetMiningField`.
/**
 * Checks that {@link KiePMMLUtil#getTargetMiningField} builds a {@link MiningField} with the
 * same name as the given {@link DataField} and usage type {@code TARGET}.
 */
@Test
public void getTargetMiningField() {
    final DataField sourceDataField = new DataField();
    sourceDataField.setName(FieldName.create("FIELD_NAME"));

    final MiningField targetMiningField = KiePMMLUtil.getTargetMiningField(sourceDataField);

    final String expectedName = sourceDataField.getName().getValue();
    assertEquals(expectedName, targetMiningField.getName().getValue());
    assertEquals(MiningField.UsageType.TARGET, targetMiningField.getUsageType());
}
Usage of `org.dmg.pmml.MiningField` in the drools project (kiegroup):
class `KiePMMLClassificationTableFactoryTest`, method `getClassificationTableBuilders`.
/**
 * Builds a minimal PMML document and checks the builders returned for it by
 * {@link KiePMMLClassificationTableFactory#getClassificationTableBuilders}.
 *
 * <p>The document contains a CAUCHIT-normalized {@link RegressionModel} with:
 * <ul>
 *   <li>two regression tables labelled "professional" and "clerical";</li>
 *   <li>two {@code PROBABILITY} output fields and one {@code PREDICTED_VALUE} output field;</li>
 *   <li>a single categorical {@code TARGET} mining field.</li>
 * </ul>
 *
 * <p>The test expects exactly three entries. Each entry's source is passed to
 * {@code commonValidateKiePMMLRegressionTable}, and the full key-to-source map is then passed to
 * {@code commonValidateCompilation}.
 */
@Test
public void getClassificationTableBuilders() {
    // Regression tables and output fields, created in the same order as before.
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");
    final OutputField catOutputField = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    final OutputField numOutputField = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    final OutputField prevOutputField = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Categorical target field declared in the data dictionary.
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(professionalTable, clericalTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));
    final Output modelOutput = new Output();
    modelOutput.addOutputFields(catOutputField, numOutputField, prevOutputField);
    model.setOutput(modelOutput);

    // The same data field referenced as the model's TARGET mining field.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML document = new PMML();
    document.setDataDictionary(dictionary);
    document.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, document, model, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO,
                    model.getRegressionTables(), model.getNormalizationMethod());

    final Map<String, KiePMMLTableSourceCategory> builders =
            KiePMMLClassificationTableFactory.getClassificationTableBuilders(regressionDTO);
    assertNotNull(builders);
    assertEquals(3, builders.size());
    for (KiePMMLTableSourceCategory category : builders.values()) {
        commonValidateKiePMMLRegressionTable(category.getSource());
    }
    final Map<String, String> sourcesByKey = builders.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSource()));
    commonValidateCompilation(sourcesByKey);
}
Aggregations