Usage of org.dmg.pmml.MiningField in project drools (kiegroup).
Class ModelUtilsTest, method getOpTypeByTargets.
/**
 * Verifies that {@code ModelUtils.getOpType} returns, for every {@link Target} of a
 * {@link RegressionModel}, an {@code OP_TYPE} equal to the one derived from that target's own
 * op type.
 * <p>
 * Three random data fields are generated; each one gets a mining field and a target that share
 * its name, so every target resolves to a known data field.
 */
@Test
public void getOpTypeByTargets() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    IntStream.range(0, 3).forEach(i -> {
        final DataField dataField = getRandomDataField();
        dataDictionary.addDataFields(dataField);
        final MiningField miningField = getRandomMiningField();
        miningField.setName(dataField.getName());
        miningSchema.addMiningFields(miningField);
        final Target targetField = getRandomTarget();
        targetField.setField(dataField.getName());
        targets.addTargets(targetField);
    });
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);
    targets.getTargets().forEach(target -> {
        OP_TYPE retrieved = ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model,
                                                 target.getField().getValue());
        assertNotNull(retrieved);
        OP_TYPE expected = OP_TYPE.byName(target.getOpType().value());
        assertEquals(expected, retrieved);
    });
}
Usage of org.dmg.pmml.MiningField in project drools (kiegroup).
Class KiePMMLUtilTest, method getMiningTargetFieldsFromMiningSchema.
/**
 * Verifies that {@code KiePMMLUtil.getMiningTargetFields(MiningSchema)} returns exactly one
 * mining field for the {@code NO_MODELNAME_SAMPLE_NAME} resource: {@code car_location}, whose
 * usage type value is {@code "target"}.
 */
@Test
public void getMiningTargetFieldsFromMiningSchema() throws Exception {
    final PMML toPopulate;
    // Close the resource stream once it has been parsed instead of leaking it.
    try (InputStream inputStream = getFileInputStream(NO_MODELNAME_SAMPLE_NAME)) {
        toPopulate = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
    }
    final Model model = toPopulate.getModels().get(0);
    List<MiningField> retrieved = KiePMMLUtil.getMiningTargetFields(model.getMiningSchema());
    assertNotNull(retrieved);
    assertEquals(1, retrieved.size());
    MiningField targetField = retrieved.get(0);
    assertEquals("car_location", targetField.getName().getValue());
    assertEquals("target", targetField.getUsageType().value());
}
Usage of org.dmg.pmml.MiningField in project drools (kiegroup).
Class KiePMMLUtilTest, method populateMissingMiningTargetField.
/**
 * Verifies {@code KiePMMLUtil.populateMissingMiningTargetField} on the
 * {@code NO_TARGET_FIELD_SAMPLE} resource. Before the call, the first model has no target mining
 * field and its first target has no field. After the call, exactly one target mining field
 * exists, its name matches a data field of the dictionary, and the first target references it.
 */
@Test
public void populateMissingMiningTargetField() throws Exception {
    final PMML pmml;
    // Close the resource stream once it has been parsed instead of leaking it.
    try (InputStream inputStream = getFileInputStream(NO_TARGET_FIELD_SAMPLE)) {
        pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
    }
    final Model toPopulate = pmml.getModels().get(0);
    List<MiningField> miningTargetFields = getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
    assertTrue(miningTargetFields.isEmpty());
    assertNull(toPopulate.getTargets().getTargets().get(0).getField());
    KiePMMLUtil.populateMissingMiningTargetField(toPopulate, pmml.getDataDictionary().getDataFields());
    miningTargetFields = getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
    assertEquals(1, miningTargetFields.size());
    final MiningField targetField = miningTargetFields.get(0);
    assertTrue(pmml.getDataDictionary().getDataFields().stream()
                       .anyMatch(dataField -> dataField.getName().equals(targetField.getName())));
    assertEquals(targetField.getName(), toPopulate.getTargets().getTargets().get(0).getField());
}
Usage of org.dmg.pmml.MiningField in project drools (kiegroup).
Class KiePMMLUtilTest, method getMiningTargetFieldsFromMiningFields.
/**
 * Verifies that {@code KiePMMLUtil.getMiningTargetFields} called with a list of mining fields
 * returns exactly one mining field for the {@code NO_MODELNAME_SAMPLE_NAME} resource:
 * {@code car_location}, whose usage type value is {@code "target"}.
 */
@Test
public void getMiningTargetFieldsFromMiningFields() throws Exception {
    final PMML toPopulate;
    // Close the resource stream once it has been parsed instead of leaking it.
    try (InputStream inputStream = getFileInputStream(NO_MODELNAME_SAMPLE_NAME)) {
        toPopulate = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
    }
    final Model model = toPopulate.getModels().get(0);
    List<MiningField> retrieved = KiePMMLUtil.getMiningTargetFields(model.getMiningSchema().getMiningFields());
    assertNotNull(retrieved);
    assertEquals(1, retrieved.size());
    MiningField targetField = retrieved.get(0);
    assertEquals("car_location", targetField.getName().getValue());
    assertEquals("target", targetField.getUsageType().value());
}
Usage of org.dmg.pmml.MiningField in project drools (kiegroup).
Class KiePMMLDroolsModelFactoryUtilsTest, method getKiePMMLModelCompilationUnit.
/**
 * Verifies {@code KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit} for a
 * classification {@link TreeModel} whose only mining field is the target {@code "target.field"}.
 * The checks are:
 * <ul>
 *   <li>the generated compilation unit is declared in the DTO's package;</li>
 *   <li>the class named after the model has a default constructor;</li>
 *   <li>that constructor assigns the expected {@code targetField}, {@code miningFunction},
 *       {@code pmmlMODEL} and {@code kModulePackageName} values;</li>
 *   <li>the constructor body passes {@code commonEvaluateFieldTypeMap} for the given field-type
 *       map and expected method-call count.</li>
 * </ul>
 */
@Test
public void getKiePMMLModelCompilationUnit() {
    DataDictionary dataDictionary = new DataDictionary();
    String targetFieldString = "target.field";
    FieldName targetFieldName = FieldName.create(targetFieldString);
    dataDictionary.addDataFields(new DataField(targetFieldName, OpType.CONTINUOUS, DataType.DOUBLE));
    String modelName = "ModelName";
    TreeModel model = new TreeModel();
    model.setModelName(modelName);
    model.setMiningFunction(MiningFunction.CLASSIFICATION);
    MiningField targetMiningField = new MiningField(targetFieldName);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(targetMiningField);
    model.setMiningSchema(miningSchema);
    Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    fieldTypeMap.put(targetFieldString,
                     new KiePMMLOriginalTypeGeneratedType(targetFieldString, getSanitizedClassName(targetFieldString)));
    String packageName = "net.test";
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(model);
    final CommonCompilationDTO<TreeModel> source =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(packageName, pmml, model, new HasClassLoaderMock());
    final DroolsCompilationDTO<TreeModel> droolsCompilationDTO =
            DroolsCompilationDTO.fromCompilationDTO(source, fieldTypeMap);
    CompilationUnit retrieved = KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit(droolsCompilationDTO,
                                                                                              TEMPLATE_SOURCE,
                                                                                              TEMPLATE_CLASS_NAME);
    assertEquals(droolsCompilationDTO.getPackageName(), retrieved.getPackageDeclaration().get().getNameAsString());
    ConstructorDeclaration constructorDeclaration =
            retrieved.getClassByName(modelName).get().getDefaultConstructor().get();
    // The model above was configured as CLASSIFICATION, so that is the expected mining function.
    MINING_FUNCTION miningFunction = MINING_FUNCTION.CLASSIFICATION;
    PMML_MODEL pmmlModel = PMML_MODEL.byName(model.getClass().getSimpleName());
    Map<String, Expression> assignExpressionMap = new HashMap<>();
    assignExpressionMap.put("targetField", new StringLiteralExpr(targetFieldString));
    assignExpressionMap.put("miningFunction",
                            new NameExpr(miningFunction.getClass().getName() + "." + miningFunction.name()));
    assignExpressionMap.put("pmmlMODEL", new NameExpr(pmmlModel.getClass().getName() + "." + pmmlModel.name()));
    String expectedKModulePackageName = getSanitizedPackageName(packageName + "." + modelName);
    assignExpressionMap.put("kModulePackageName", new StringLiteralExpr(expectedKModulePackageName));
    assertTrue(commonEvaluateAssignExpr(constructorDeclaration.getBody(), assignExpressionMap));
    // The trailing "+ 1" accounts for the super invocation.
    int expectedMethodCallExprs = assignExpressionMap.size() + fieldTypeMap.size() + 1;
    commonEvaluateFieldTypeMap(constructorDeclaration.getBody(), fieldTypeMap, expectedMethodCallExprs);
}
Aggregations