Search in sources:

Example 76 with DataField

Use of org.dmg.pmml.DataField in the drools project by kiegroup.

From class KiePMMLClassificationTableFactoryTest, method getClassificationTableBuilder.

/**
 * Verifies that {@link KiePMMLClassificationTableFactory#getClassificationTableBuilder} returns a
 * non-null builder entry for a CAUCHIT-normalized regression model with two regression tables
 * ("professional" and "clerical"), three output fields and a categorical target field.
 */
@Test
public void getClassificationTableBuilder() {
    RegressionTable regressionTableProf = getRegressionTable(3.5, "professional");
    RegressionTable regressionTableCler = getRegressionTable(27.4, "clerical");
    OutputField outputFieldCat = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    OutputField outputFieldNum = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    OutputField outputFieldPrev = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);
    String targetField = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetField));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);
    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTableProf, regressionTableCler);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    Output output = new Output();
    output.addOutputFields(outputFieldCat, outputFieldNum, outputFieldPrev);
    regressionModel.setOutput(output);
    // The data-dictionary field is declared as the model's TARGET in the mining schema.
    MiningField miningField = new MiningField();
    miningField.setUsageType(MiningField.UsageType.TARGET);
    miningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(miningField);
    regressionModel.setMiningSchema(miningSchema);
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);
    final CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());
    // One entry per regression table, keyed by "<package>.<TARGET_CATEGORY>", with an empty source.
    final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap = new LinkedHashMap<>();
    for (RegressionTable regressionTable : regressionModel.getRegressionTables()) {
        String targetCategory = regressionTable.getTargetCategory().toString();
        // Locale.ROOT keeps the key independent of the JVM default locale
        // (e.g. under a Turkish locale "clerical".toUpperCase() would yield "CLERİCAL").
        String key = compilationDTO.getPackageName() + "." + targetCategory.toUpperCase(java.util.Locale.ROOT);
        regressionTablesMap.put(key, new KiePMMLTableSourceCategory("", targetCategory));
    }
    Map.Entry<String, String> retrieved = KiePMMLClassificationTableFactory.getClassificationTableBuilder(compilationDTO, regressionTablesMap);
    assertNotNull(retrieved);
}
Also used : MiningField(org.dmg.pmml.MiningField) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) LinkedHashMap(java.util.LinkedHashMap) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) Output(org.dmg.pmml.Output) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) OutputField(org.dmg.pmml.OutputField) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map) Test(org.junit.Test)

Example 77 with DataField

Use of org.dmg.pmml.DataField in the drools project by kiegroup.

From class KiePMMLDroolsModelFactoryUtilsTest, method getKiePMMLModelCompilationUnit.

/**
 * Verifies that {@link KiePMMLDroolsModelFactoryUtils#getKiePMMLModelCompilationUnit} generates a
 * compilation unit declared in the DTO's package. The test also checks that the default constructor
 * of the generated model class assigns the expected target field, mining function, PMML model and
 * kModule package name, and that it evaluates against every entry of the field-type map.
 */
@Test
public void getKiePMMLModelCompilationUnit() {
    DataDictionary dataDictionary = new DataDictionary();
    String targetFieldString = "target.field";
    FieldName targetFieldName = FieldName.create(targetFieldString);
    dataDictionary.addDataFields(new DataField(targetFieldName, OpType.CONTINUOUS, DataType.DOUBLE));
    String modelName = "ModelName";
    TreeModel model = new TreeModel();
    model.setModelName(modelName);
    model.setMiningFunction(MiningFunction.CLASSIFICATION);
    MiningField targetMiningField = new MiningField(targetFieldName);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(targetMiningField);
    model.setMiningSchema(miningSchema);
    Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    fieldTypeMap.put(targetFieldString, new KiePMMLOriginalTypeGeneratedType(targetFieldString, getSanitizedClassName(targetFieldString)));
    String packageName = "net.test";
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(model);
    final CommonCompilationDTO<TreeModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(packageName, pmml, model, new HasClassLoaderMock());
    final DroolsCompilationDTO<TreeModel> droolsCompilationDTO = DroolsCompilationDTO.fromCompilationDTO(source, fieldTypeMap);
    CompilationUnit retrieved = KiePMMLDroolsModelFactoryUtils.getKiePMMLModelCompilationUnit(droolsCompilationDTO, TEMPLATE_SOURCE, TEMPLATE_CLASS_NAME);
    // orElseThrow(AssertionError) reports *which* expected element is missing,
    // instead of an anonymous NoSuchElementException from Optional.get().
    String retrievedPackageName = retrieved.getPackageDeclaration()
            .orElseThrow(() -> new AssertionError("Generated compilation unit has no package declaration"))
            .getNameAsString();
    assertEquals(droolsCompilationDTO.getPackageName(), retrievedPackageName);
    ConstructorDeclaration constructorDeclaration = retrieved.getClassByName(modelName)
            .orElseThrow(() -> new AssertionError("Class " + modelName + " not found in generated compilation unit"))
            .getDefaultConstructor()
            .orElseThrow(() -> new AssertionError("Class " + modelName + " has no default constructor"));
    MINING_FUNCTION miningFunction = MINING_FUNCTION.CLASSIFICATION;
    PMML_MODEL pmmlModel = PMML_MODEL.byName(model.getClass().getSimpleName());
    // Expected constructor assignments: field name -> expected assigned expression.
    Map<String, Expression> assignExpressionMap = new HashMap<>();
    assignExpressionMap.put("targetField", new StringLiteralExpr(targetFieldString));
    assignExpressionMap.put("miningFunction", new NameExpr(miningFunction.getClass().getName() + "." + miningFunction.name()));
    assignExpressionMap.put("pmmlMODEL", new NameExpr(pmmlModel.getClass().getName() + "." + pmmlModel.name()));
    String expectedKModulePackageName = getSanitizedPackageName(packageName + "." + modelName);
    assignExpressionMap.put("kModulePackageName", new StringLiteralExpr(expectedKModulePackageName));
    assertTrue(commonEvaluateAssignExpr(constructorDeclaration.getBody(), assignExpressionMap));
    // The trailing "+ 1" accounts for the super invocation.
    int expectedMethodCallExprs = assignExpressionMap.size() + fieldTypeMap.size() + 1;
    commonEvaluateFieldTypeMap(constructorDeclaration.getBody(), fieldTypeMap, expectedMethodCallExprs);
}
Also used : CompilationUnit(com.github.javaparser.ast.CompilationUnit) MiningField(org.dmg.pmml.MiningField) HashMap(java.util.HashMap) StringLiteralExpr(com.github.javaparser.ast.expr.StringLiteralExpr) NameExpr(com.github.javaparser.ast.expr.NameExpr) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) KiePMMLOriginalTypeGeneratedType(org.kie.pmml.models.drools.tuples.KiePMMLOriginalTypeGeneratedType) TreeModel(org.dmg.pmml.tree.TreeModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) Expression(com.github.javaparser.ast.expr.Expression) ConstructorDeclaration(com.github.javaparser.ast.body.ConstructorDeclaration) PMML(org.dmg.pmml.PMML) FieldName(org.dmg.pmml.FieldName) PMML_MODEL(org.kie.pmml.api.enums.PMML_MODEL) MINING_FUNCTION(org.kie.pmml.api.enums.MINING_FUNCTION) Test(org.junit.Test)

Example 78 with DataField

Use of org.dmg.pmml.DataField in the drools project by kiegroup.

From class KiePMMLRegressionTableFactoryTest, method getRegressionTableBuilders.

/**
 * Verifies that {@link KiePMMLRegressionTableFactory#getRegressionTableBuilders} returns one
 * source/category entry per regression table of the compilation DTO, and that each generated
 * source passes {@code commonValidateKiePMMLRegressionTable}.
 */
@Test
public void getRegressionTableBuilders() {
    regressionTable = getRegressionTable(3.5, "professional");
    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(regressionTable);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    String targetField = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetField));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);
    MiningField miningField = new MiningField();
    miningField.setUsageType(MiningField.UsageType.TARGET);
    miningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(miningField);
    regressionModel.setMiningSchema(miningSchema);
    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);
    final CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    // Pass the model's real regression tables: with an empty list the factory has nothing to
    // build, and the per-entry validation below would run zero times.
    final RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());
    Map<String, KiePMMLTableSourceCategory> retrieved = KiePMMLRegressionTableFactory.getRegressionTableBuilders(compilationDTO);
    assertNotNull(retrieved);
    int expectedSize = regressionModel.getRegressionTables().size();
    if (retrieved.size() != expectedSize) {
        throw new AssertionError("Expected " + expectedSize + " regression table builders but got " + retrieved.size());
    }
    retrieved.values().forEach(kiePMMLTableSourceCategory -> commonValidateKiePMMLRegressionTable(kiePMMLTableSourceCategory.getSource()));
}
Also used : MiningField(org.dmg.pmml.MiningField) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) Test(org.junit.Test)

Example 79 with DataField

Use of org.dmg.pmml.DataField in the drools project by kiegroup.

From class KiePMMLRegressionTableFactoryTest, method getRegressionTables.

/**
 * Checks that {@link KiePMMLRegressionTableFactory#getRegressionTables} yields exactly one
 * {@link KiePMMLRegressionTable} per source table, keyed by the table's target category, and
 * that each produced table matches its source.
 */
@Test
public void getRegressionTables() {
    regressionTable = getRegressionTable(3.5, "professional");
    final RegressionTable hobbyTable = getRegressionTable(3.9, "hobby");

    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable, hobbyTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical field that the mining schema below marks as the model target.
    final DataField target = new DataField();
    target.setName(FieldName.create("targetField"));
    target.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(target);

    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(target.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML document = new PMML();
    document.setDataDictionary(dictionary);
    document.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, document, model, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO, model.getRegressionTables(), model.getNormalizationMethod());

    final Map<String, KiePMMLRegressionTable> tablesByCategory = KiePMMLRegressionTableFactory.getRegressionTables(regressionDTO);
    assertNotNull(tablesByCategory);
    assertEquals(model.getRegressionTables().size(), tablesByCategory.size());
    for (RegressionTable sourceTable : model.getRegressionTables()) {
        final String category = sourceTable.getTargetCategory().toString();
        assertTrue(tablesByCategory.containsKey(category));
        commonEvaluateRegressionTable(tablesByCategory.get(category), sourceTable);
    }
}
Also used : MiningField(org.dmg.pmml.MiningField) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) Test(org.junit.Test)

Example 80 with DataField

Use of org.dmg.pmml.DataField in the drools project by kiegroup.

From class KiePMMLRegressionTableFactoryTest, method setStaticGetter.

/**
 * Checks that {@link KiePMMLRegressionTableFactory#setStaticGetter} fills a clone of
 * {@code STATIC_GETTER_METHOD} so that it matches the method parsed from {@code TEST_06_SOURCE},
 * both textually and structurally. It also checks that the populated method compiles with the
 * listed imports.
 *
 * @throws IOException if the expected-source fixture cannot be read
 */
@Test
public void setStaticGetter() throws IOException {
    regressionTable = getRegressionTable(3.5, "professional");

    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical field that the mining schema below marks as the model target.
    final DataField target = new DataField();
    target.setName(FieldName.create("targetField"));
    target.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(target);

    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(target.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML document = new PMML();
    document.setDataDictionary(dictionary);
    document.addModels(model);

    final String variableName = "variableName";
    final CommonCompilationDTO<RegressionModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, document, model, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO, new ArrayList<>(), model.getNormalizationMethod());

    // Work on a clone so the shared template method is left untouched.
    final MethodDeclaration getter = STATIC_GETTER_METHOD.clone();
    KiePMMLRegressionTableFactory.setStaticGetter(regressionTable, regressionDTO, getter, variableName);

    final MethodDeclaration expected = JavaParserUtils.parseMethod(getFileContent(TEST_06_SOURCE));
    assertEquals(expected.toString(), getter.toString());
    assertTrue(JavaParserUtils.equalsNode(expected, getter));

    final List<Class<?>> requiredImports = Arrays.asList(AtomicReference.class, Collections.class, Arrays.class, List.class, Map.class, KiePMMLRegressionTable.class, SerializableFunction.class);
    commonValidateCompilationWithImports(getter, requiredImports);
}
Also used : MiningField(org.dmg.pmml.MiningField) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) BeforeClass(org.junit.BeforeClass) Test(org.junit.Test)

Aggregations

DataField (org.dmg.pmml.DataField)101 Test (org.junit.Test)51 DataDictionary (org.dmg.pmml.DataDictionary)42 MiningField (org.dmg.pmml.MiningField)42 MiningSchema (org.dmg.pmml.MiningSchema)30 PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField)28 RegressionModel (org.dmg.pmml.regression.RegressionModel)27 CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary)27 FieldName (org.dmg.pmml.FieldName)24 Model (org.dmg.pmml.Model)24 PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField)22 DataType (org.dmg.pmml.DataType)19 OutputField (org.dmg.pmml.OutputField)19 PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField)19 PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField)18 ArrayList (java.util.ArrayList)17 List (java.util.List)17 PMML (org.dmg.pmml.PMML)17 Collectors (java.util.stream.Collectors)16 OpType (org.dmg.pmml.OpType)15