Search in sources :

Example 46 with DataDictionary

Use of `org.dmg.pmml.DataDictionary` in project drools by kiegroup.

Class `KiePMMLRegressionTableFactoryTest`, method `getRegressionTables`.

/**
 * Builds a CAUCHIT-normalized regression model with two categorical tables ("professional"
 * and "hobby") and checks that {@code KiePMMLRegressionTableFactory.getRegressionTables}
 * returns one entry per source table, keyed by the table's target category.
 */
@Test
public void getRegressionTables() {
    // The first table is stored in the shared test field; the second one is local to this test.
    regressionTable = getRegressionTable(3.5, "professional");
    final RegressionTable hobbyTable = getRegressionTable(3.9, "hobby");

    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable, hobbyTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // One categorical target field, declared in the data dictionary...
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // ...and registered as the TARGET in the model's mining schema.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML pmmlDocument = new PMML();
    pmmlDocument.setDataDictionary(dictionary);
    pmmlDocument.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmmlDocument, model,
                    new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO,
                    model.getRegressionTables(), model.getNormalizationMethod());

    final Map<String, KiePMMLRegressionTable> retrieved =
            KiePMMLRegressionTableFactory.getRegressionTables(regressionDTO);
    assertNotNull(retrieved);
    assertEquals(model.getRegressionTables().size(), retrieved.size());
    // Each source table must appear under its target category; the pair is then checked
    // by commonEvaluateRegressionTable.
    for (RegressionTable sourceTable : model.getRegressionTables()) {
        final String categoryKey = sourceTable.getTargetCategory().toString();
        assertTrue(retrieved.containsKey(categoryKey));
        commonEvaluateRegressionTable(retrieved.get(categoryKey), sourceTable);
    }
}
Also used: MiningField(org.dmg.pmml.MiningField) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) Test(org.junit.Test)

Example 47 with DataDictionary

Use of `org.dmg.pmml.DataDictionary` in project drools by kiegroup.

Class `KiePMMLRegressionTableFactoryTest`, method `setStaticGetter`.

/**
 * Checks that {@code KiePMMLRegressionTableFactory.setStaticGetter} turns a clone of the
 * static-getter template into the method source stored in {@code TEST_06_SOURCE}, and that
 * the generated method compiles against the listed imports.
 *
 * @throws IOException if the expected source file cannot be read
 */
@Test
public void setStaticGetter() throws IOException {
    regressionTable = getRegressionTable(3.5, "professional");

    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical target field: present in the dictionary and marked TARGET in the mining schema.
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML pmmlDocument = new PMML();
    pmmlDocument.setDataDictionary(dictionary);
    pmmlDocument.addModels(model);

    final String variableName = "variableName";
    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmmlDocument, model,
                    new HasClassLoaderMock());
    // The DTO gets an empty table list; the table under test is passed to setStaticGetter directly.
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO,
                    new ArrayList<>(), model.getNormalizationMethod());

    // setStaticGetter populates the declaration it is given, so work on a clone of the template.
    final MethodDeclaration getterMethod = STATIC_GETTER_METHOD.clone();
    KiePMMLRegressionTableFactory.setStaticGetter(regressionTable, regressionDTO, getterMethod, variableName);

    final MethodDeclaration expectedMethod = JavaParserUtils.parseMethod(getFileContent(TEST_06_SOURCE));
    assertEquals(expectedMethod.toString(), getterMethod.toString());
    assertTrue(JavaParserUtils.equalsNode(expectedMethod, getterMethod));

    final List<Class<?>> requiredImports = Arrays.asList(AtomicReference.class, Collections.class,
            Arrays.class, List.class, Map.class, KiePMMLRegressionTable.class, SerializableFunction.class);
    commonValidateCompilationWithImports(getterMethod, requiredImports);
}
Also used: MiningField(org.dmg.pmml.MiningField) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) BeforeClass(org.junit.BeforeClass) Test(org.junit.Test)

Example 48 with DataDictionary

Use of `org.dmg.pmml.DataDictionary` in project shifu by ShifuML.

Class `PMMLAdapterCommonUtil`, method `getDataDicHeaders`.

/**
 * Collects the names of every field declared in the PMML data dictionary.
 *
 * @param pmml
 *            the pmml model; its data dictionary must be present, otherwise a
 *            {@link NullPointerException} is thrown
 * @return one header name per {@link DataField}, in the same order as the dictionary
 */
public static String[] getDataDicHeaders(final PMML pmml) {
    final DataDictionary dictionary = pmml.getDataDictionary();
    // Map each declared field to its plain string name, keeping dictionary order.
    return dictionary.getDataFields()
            .stream()
            .map(field -> field.getName().getValue())
            .toArray(String[]::new);
}
Also used: DataField(org.dmg.pmml.DataField) DataDictionary(org.dmg.pmml.DataDictionary)

Aggregations

DataDictionary (org.dmg.pmml.DataDictionary)48 DataField (org.dmg.pmml.DataField)41 Test (org.junit.Test)41 CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary)30 MiningSchema (org.dmg.pmml.MiningSchema)28 MiningField (org.dmg.pmml.MiningField)27 RegressionModel (org.dmg.pmml.regression.RegressionModel)27 PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField)21 PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField)21 Model (org.dmg.pmml.Model)19 PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField)17 PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField)17 PMML (org.dmg.pmml.PMML)12 OutputField (org.dmg.pmml.OutputField)11 Collectors (java.util.stream.Collectors)10 Assert.assertTrue (org.junit.Assert.assertTrue)10 DATA_TYPE (org.kie.pmml.api.enums.DATA_TYPE)10 OP_TYPE (org.kie.pmml.api.enums.OP_TYPE)10 HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock)10 Arrays (java.util.Arrays)9