Search in sources:

Example 41 with MiningField

Use of org.dmg.pmml.MiningField in project drools by kiegroup.

From class ModelUtilsTest, method getOpTypeByTargets.

/**
 * Verifies that {@code ModelUtils.getOpType} resolves, for every {@code Target} declared in the
 * model, the same operational type that was set on that {@code Target}.
 * <p>
 * Three random {@code DataField}s are created. Each one is mirrored by a {@code MiningField}
 * and a {@code Target} that reference the same field name.
 */
@Test
public void getOpTypeByTargets() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    IntStream.range(0, 3).forEach(i -> {
        final DataField dataField = getRandomDataField();
        dataDictionary.addDataFields(dataField);
        // MiningField and Target share the DataField name so the lookup can resolve them
        final MiningField miningField = getRandomMiningField();
        miningField.setName(dataField.getName());
        miningSchema.addMiningFields(miningField);
        final Target targetField = getRandomTarget();
        targetField.setField(dataField.getName());
        targets.addTargets(targetField);
    });
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);
    // The expected OP_TYPE is derived from the Target's own OpType
    targets.getTargets().forEach(target -> {
        OP_TYPE retrieved = ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, target.getField().getValue());
        assertNotNull(retrieved);
        OP_TYPE expected = OP_TYPE.byName(target.getOpType().value());
        assertEquals(expected, retrieved);
    });
}
Also used: PMMLModelTestUtils.getMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField) MiningField(org.dmg.pmml.MiningField) PMMLModelTestUtils.getRandomMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField) Target(org.dmg.pmml.Target) PMMLModelTestUtils.getRandomTarget(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomTarget) PMMLModelTestUtils.getTarget(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getTarget) MiningSchema(org.dmg.pmml.MiningSchema) DataField(org.dmg.pmml.DataField) PMMLModelTestUtils.getDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField) PMMLModelTestUtils.getRandomDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField) Model(org.dmg.pmml.Model) RegressionModel(org.dmg.pmml.regression.RegressionModel) Targets(org.dmg.pmml.Targets) OP_TYPE(org.kie.pmml.api.enums.OP_TYPE) DataDictionary(org.dmg.pmml.DataDictionary) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) Test(org.junit.Test)

Example 42 with MiningField

Use of org.dmg.pmml.MiningField in project drools by kiegroup.

From class KiePMMLUtilTest, method getMiningTargetFieldsFromMiningSchema.

/**
 * Verifies that {@code KiePMMLUtil.getMiningTargetFields}, given a model's {@code MiningSchema},
 * returns exactly one field. That field must be named {@code "car_location"} and have usage type
 * {@code "target"}.
 */
@Test
public void getMiningTargetFieldsFromMiningSchema() throws Exception {
    final Model model;
    // try-with-resources: the sample stream is closed even if unmarshalling throws
    try (InputStream inputStream = getFileInputStream(NO_MODELNAME_SAMPLE_NAME)) {
        final PMML toPopulate = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        model = toPopulate.getModels().get(0);
    }
    List<MiningField> retrieved = KiePMMLUtil.getMiningTargetFields(model.getMiningSchema());
    assertNotNull(retrieved);
    assertEquals(1, retrieved.size());
    MiningField targetField = retrieved.get(0);
    assertEquals("car_location", targetField.getName().getValue());
    assertEquals("target", targetField.getUsageType().value());
}
Also used: MiningField(org.dmg.pmml.MiningField) FileUtils.getFileInputStream(org.kie.test.util.filesystem.FileUtils.getFileInputStream) InputStream(java.io.InputStream) Model(org.dmg.pmml.Model) MiningModel(org.dmg.pmml.mining.MiningModel) PMML(org.dmg.pmml.PMML) Test(org.junit.Test)

Example 43 with MiningField

Use of org.dmg.pmml.MiningField in project drools by kiegroup.

From class KiePMMLUtilTest, method populateMissingMiningTargetField.

/**
 * Verifies that {@code KiePMMLUtil.populateMissingMiningTargetField} adds a target
 * {@code MiningField} to a model that lacks one.
 * <p>
 * The added field must match a {@code DataField} name in the dictionary. The first
 * {@code Target} must end up bound to that field.
 */
@Test
public void populateMissingMiningTargetField() throws Exception {
    final PMML pmml;
    // try-with-resources: the sample stream is closed even if unmarshalling throws
    try (InputStream inputStream = getFileInputStream(NO_TARGET_FIELD_SAMPLE)) {
        pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
    }
    final Model toPopulate = pmml.getModels().get(0);
    // Precondition: no target MiningField, and the first Target has no field set
    List<MiningField> miningTargetFields = getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
    assertTrue(miningTargetFields.isEmpty());
    assertNull(toPopulate.getTargets().getTargets().get(0).getField());
    KiePMMLUtil.populateMissingMiningTargetField(toPopulate, pmml.getDataDictionary().getDataFields());
    // Postcondition: exactly one target MiningField, named after an existing DataField and bound to the Target
    miningTargetFields = getMiningTargetFields(toPopulate.getMiningSchema().getMiningFields());
    assertEquals(1, miningTargetFields.size());
    final MiningField targetField = miningTargetFields.get(0);
    assertTrue(pmml.getDataDictionary().getDataFields().stream().anyMatch(dataField -> dataField.getName().equals(targetField.getName())));
    assertEquals(targetField.getName(), toPopulate.getTargets().getTargets().get(0).getField());
}
Also used: IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Model(org.dmg.pmml.Model) OutputField(org.dmg.pmml.OutputField) Random(java.util.Random) ResultFeature(org.dmg.pmml.ResultFeature) FieldName(org.dmg.pmml.FieldName) MODELNAME_TEMPLATE(org.kie.pmml.compiler.commons.utils.KiePMMLUtil.MODELNAME_TEMPLATE) OpType(org.dmg.pmml.OpType) Charset(java.nio.charset.Charset) MiningFunction(org.dmg.pmml.MiningFunction) MiningField(org.dmg.pmml.MiningField) MiningModel(org.dmg.pmml.mining.MiningModel) PMML(org.dmg.pmml.PMML) Targets(org.dmg.pmml.Targets) Assert.assertNotNull(org.junit.Assert.assertNotNull) DataType(org.dmg.pmml.DataType) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) Test(org.junit.Test) Reader(java.io.Reader) InputStreamReader(java.io.InputStreamReader) Collectors(java.util.stream.Collectors) JAXBException(javax.xml.bind.JAXBException) Target(org.dmg.pmml.Target) StandardCharsets(java.nio.charset.StandardCharsets) SEGMENTID_TEMPLATE(org.kie.pmml.compiler.commons.utils.KiePMMLUtil.SEGMENTID_TEMPLATE) DataField(org.dmg.pmml.DataField) SEGMENTMODELNAME_TEMPLATE(org.kie.pmml.compiler.commons.utils.KiePMMLUtil.SEGMENTMODELNAME_TEMPLATE) List(java.util.List) Segment(org.dmg.pmml.mining.Segment) TARGETFIELD_TEMPLATE(org.kie.pmml.compiler.commons.utils.KiePMMLUtil.TARGETFIELD_TEMPLATE) Assert.assertNull(org.junit.Assert.assertNull) FileUtils.getFileInputStream(org.kie.test.util.filesystem.FileUtils.getFileInputStream) Assert.assertFalse(org.junit.Assert.assertFalse) SAXException(org.xml.sax.SAXException) KiePMMLUtil.getMiningTargetFields(org.kie.pmml.compiler.commons.utils.KiePMMLUtil.getMiningTargetFields) Optional(java.util.Optional) RandomStringUtils(org.apache.commons.lang3.RandomStringUtils) BufferedReader(java.io.BufferedReader) MathContext(org.dmg.pmml.MathContext) Assert.assertEquals(org.junit.Assert.assertEquals) InputStream(java.io.InputStream)

Example 44 with MiningField

Use of org.dmg.pmml.MiningField in project drools by kiegroup.

From class KiePMMLUtilTest, method getMiningTargetFieldsFromMiningFields.

/**
 * Verifies that {@code KiePMMLUtil.getMiningTargetFields}, given the list of a model's
 * {@code MiningField}s, returns exactly one field. That field must be named
 * {@code "car_location"} and have usage type {@code "target"}.
 */
@Test
public void getMiningTargetFieldsFromMiningFields() throws Exception {
    final Model model;
    // try-with-resources: the sample stream is closed even if unmarshalling throws
    try (InputStream inputStream = getFileInputStream(NO_MODELNAME_SAMPLE_NAME)) {
        final PMML toPopulate = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        model = toPopulate.getModels().get(0);
    }
    List<MiningField> retrieved = KiePMMLUtil.getMiningTargetFields(model.getMiningSchema().getMiningFields());
    assertNotNull(retrieved);
    assertEquals(1, retrieved.size());
    MiningField targetField = retrieved.get(0);
    assertEquals("car_location", targetField.getName().getValue());
    assertEquals("target", targetField.getUsageType().value());
}
Also used: MiningField(org.dmg.pmml.MiningField) FileUtils.getFileInputStream(org.kie.test.util.filesystem.FileUtils.getFileInputStream) InputStream(java.io.InputStream) Model(org.dmg.pmml.Model) MiningModel(org.dmg.pmml.mining.MiningModel) PMML(org.dmg.pmml.PMML) Test(org.junit.Test)

Example 45 with MiningField

Use of org.dmg.pmml.MiningField in project drools by kiegroup.

From class KiePMMLDroolsModelFactoryUtilsTest, method getKiePMMLModelCompilationUnit.

/**
 * Verifies the {@code CompilationUnit} produced by
 * {@code KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit} for a minimal
 * classification {@code TreeModel} with a single target field.
 * <p>
 * Checks that the unit has the DTO's package name. Checks that the default constructor of the
 * class named after the model assigns {@code targetField}, {@code miningFunction},
 * {@code pmmlMODEL} and {@code kModulePackageName}. Checks that it populates the field-type map.
 */
@Test
public void getKiePMMLModelCompilationUnit() {
    // Minimal PMML: one DOUBLE data field used as the model's target mining field
    DataDictionary dataDictionary = new DataDictionary();
    String targetFieldString = "target.field";
    FieldName targetFieldName = FieldName.create(targetFieldString);
    dataDictionary.addDataFields(new DataField(targetFieldName, OpType.CONTINUOUS, DataType.DOUBLE));
    String modelName = "ModelName";
    TreeModel model = new TreeModel();
    model.setModelName(modelName);
    model.setMiningFunction(MiningFunction.CLASSIFICATION);
    MiningField targetMiningField = new MiningField(targetFieldName);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(targetMiningField);
    model.setMiningSchema(miningSchema);
    Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    fieldTypeMap.put(targetFieldString, new KiePMMLOriginalTypeGeneratedType(targetFieldString, getSanitizedClassName(targetFieldString)));
    String packageName = "net.test";
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(model);
    final CommonCompilationDTO<TreeModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(packageName, pmml, model, new HasClassLoaderMock());
    final DroolsCompilationDTO<TreeModel> droolsCompilationDTO = DroolsCompilationDTO.fromCompilationDTO(source, fieldTypeMap);
    CompilationUnit retrieved = KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit(droolsCompilationDTO, TEMPLATE_SOURCE, TEMPLATE_CLASS_NAME);
    assertEquals(droolsCompilationDTO.getPackageName(), retrieved.getPackageDeclaration().get().getNameAsString());
    ConstructorDeclaration constructorDeclaration = retrieved.getClassByName(modelName).get().getDefaultConstructor().get();
    // Expected constructor assignments, keyed by the assigned field name
    MINING_FUNCTION miningFunction = MINING_FUNCTION.CLASSIFICATION;
    PMML_MODEL pmmlModel = PMML_MODEL.byName(model.getClass().getSimpleName());
    Map<String, Expression> assignExpressionMap = new HashMap<>();
    assignExpressionMap.put("targetField", new StringLiteralExpr(targetFieldString));
    assignExpressionMap.put("miningFunction", new NameExpr(miningFunction.getClass().getName() + "." + miningFunction.name()));
    assignExpressionMap.put("pmmlMODEL", new NameExpr(pmmlModel.getClass().getName() + "." + pmmlModel.name()));
    String expectedKModulePackageName = getSanitizedPackageName(packageName + "." + modelName);
    assignExpressionMap.put("kModulePackageName", new StringLiteralExpr(expectedKModulePackageName));
    assertTrue(commonEvaluateAssignExpr(constructorDeclaration.getBody(), assignExpressionMap));
    // The last "1" is for the super invocation
    int expectedMethodCallExprs = assignExpressionMap.size() + fieldTypeMap.size() + 1;
    commonEvaluateFieldTypeMap(constructorDeclaration.getBody(), fieldTypeMap, expectedMethodCallExprs);
}
Also used: CompilationUnit(com.github.javaparser.ast.CompilationUnit) MiningField(org.dmg.pmml.MiningField) HashMap(java.util.HashMap) StringLiteralExpr(com.github.javaparser.ast.expr.StringLiteralExpr) NameExpr(com.github.javaparser.ast.expr.NameExpr) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) KiePMMLOriginalTypeGeneratedType(org.kie.pmml.models.drools.tuples.KiePMMLOriginalTypeGeneratedType) TreeModel(org.dmg.pmml.tree.TreeModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) Expression(com.github.javaparser.ast.expr.Expression) ConstructorDeclaration(com.github.javaparser.ast.body.ConstructorDeclaration) PMML(org.dmg.pmml.PMML) FieldName(org.dmg.pmml.FieldName) PMML_MODEL(org.kie.pmml.api.enums.PMML_MODEL) MINING_FUNCTION(org.kie.pmml.api.enums.MINING_FUNCTION) Test(org.junit.Test)

Aggregations

MiningField (org.dmg.pmml.MiningField)59 DataField (org.dmg.pmml.DataField)40 Test (org.junit.Test)39 MiningSchema (org.dmg.pmml.MiningSchema)33 DataDictionary (org.dmg.pmml.DataDictionary)25 RegressionModel (org.dmg.pmml.regression.RegressionModel)24 Model (org.dmg.pmml.Model)22 PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField)20 PMML (org.dmg.pmml.PMML)18 PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField)18 PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField)17 PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField)16 OutputField (org.dmg.pmml.OutputField)15 CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary)15 FieldName (org.dmg.pmml.FieldName)12 Target (org.dmg.pmml.Target)11 Targets (org.dmg.pmml.Targets)11 OP_TYPE (org.kie.pmml.api.enums.OP_TYPE)11 HashMap (java.util.HashMap)10 List (java.util.List)10