Search in sources:

Example 1 with KiePMMLRegressionTable

Use of org.kie.pmml.models.regression.model.KiePMMLRegressionTable in project drools by kiegroup.

The class KiePMMLRegressionTableFactoryTest, method getRegressionTable.

/**
 * Converts a single PMML {@code RegressionTable} into a {@link KiePMMLRegressionTable}.
 * It then checks the result against the source table.
 * <p>
 * The fixture is a CAUCHIT-normalized {@code RegressionModel} with one categorical target field named
 * {@code "targetField"}, wrapped in a minimal {@code PMML} document.
 */
@Test
public void getRegressionTable() {
    regressionTable = getRegressionTable(3.5, "professional");

    // Model under test: one regression table, CAUCHIT normalization, generated name.
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical target field, declared in the data dictionary...
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // ...and referenced as TARGET by the model's mining schema.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, model, new HasClassLoaderMock());
    // The table is converted directly, so the DTO is created with an empty table list.
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO, new ArrayList<>(), model.getNormalizationMethod());

    final KiePMMLRegressionTable converted = KiePMMLRegressionTableFactory.getRegressionTable(regressionTable, regressionDTO);
    assertNotNull(converted);
    commonEvaluateRegressionTable(converted, regressionTable);
}
Also used : MiningField(org.dmg.pmml.MiningField) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionModel(org.dmg.pmml.regression.RegressionModel) Test(org.junit.Test)

Example 2 with KiePMMLRegressionTable

Use of org.kie.pmml.models.regression.model.KiePMMLRegressionTable in project drools by kiegroup.

The class KiePMMLRegressionTableFactory, method getRegressionTables.

// KiePMMLRegressionTable instantiation
/**
 * Converts every PMML regression table held by the given DTO into a {@link KiePMMLRegressionTable}.
 * <p>
 * Each result is keyed by its source table's target category, taken from {@code toString()}. A table
 * with no target category is stored under the empty string {@code ""}. The returned map keeps the
 * iteration order of {@code compilationDTO.getRegressionTables()}. If two source tables produce the
 * same key, the later one replaces the earlier one.
 *
 * @param compilationDTO the compilation context supplying the regression tables to convert
 * @return an insertion-ordered map from target-category key to converted table
 */
public static LinkedHashMap<String, KiePMMLRegressionTable> getRegressionTables(final RegressionCompilationDTO compilationDTO) {
    logger.trace("getRegressionTables {}", compilationDTO.getRegressionTables());
    final LinkedHashMap<String, KiePMMLRegressionTable> tablesByCategory = new LinkedHashMap<>();
    for (final RegressionTable sourceTable : compilationDTO.getRegressionTables()) {
        final KiePMMLRegressionTable converted = getRegressionTable(sourceTable, compilationDTO);
        final Object category = sourceTable.getTargetCategory();
        final String key;
        if (category == null) {
            // Tables without a target category share the empty-string key.
            key = "";
        } else {
            key = category.toString();
        }
        tablesByCategory.put(key, converted);
    }
    return tablesByCategory;
}
Also used : KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) RegressionTable(org.dmg.pmml.regression.RegressionTable) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) LinkedHashMap(java.util.LinkedHashMap)

Example 3 with KiePMMLRegressionTable

Use of org.kie.pmml.models.regression.model.KiePMMLRegressionTable in project drools by kiegroup.

The class KiePMMLRegressionModelFactoryTest, method evaluateCategoricalRegressionTable.

/**
 * Verifies that a {@link KiePMMLClassificationTable} matches the source PMML regression model.
 * <p>
 * The following are checked:
 * <ul>
 *   <li>the normalization method, mapped from {@code regressionModel};</li>
 *   <li>that the op type is {@code CATEGORICAL};</li>
 *   <li>that the category map has as many entries as {@code regressionTables};</li>
 *   <li>that every original table's target category is present and its entry matches that table.</li>
 * </ul>
 *
 * @param regressionTable the classification table produced by the factory under test
 */
private void evaluateCategoricalRegressionTable(KiePMMLClassificationTable regressionTable) {
    assertEquals(REGRESSION_NORMALIZATION_METHOD.byName(regressionModel.getNormalizationMethod().value()), regressionTable.getRegressionNormalizationMethod());
    assertEquals(OP_TYPE.CATEGORICAL, regressionTable.getOpType());
    final Map<String, KiePMMLRegressionTable> categoryTableMap = regressionTable.getCategoryTableMap();
    // The containsKey loop alone would not detect extra, unexpected categories; the size check does.
    assertEquals(regressionTables.size(), categoryTableMap.size());
    for (RegressionTable originalRegressionTable : regressionTables) {
        final String category = originalRegressionTable.getTargetCategory().toString();
        assertTrue(categoryTableMap.containsKey(category));
        evaluateRegressionTable(categoryTableMap.get(category), originalRegressionTable);
    }
}
Also used : KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) RegressionTable(org.dmg.pmml.regression.RegressionTable) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) PMMLModelTestUtils.getRegressionTable(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRegressionTable)

Example 4 with KiePMMLRegressionTable

Use of org.kie.pmml.models.regression.model.KiePMMLRegressionTable in project drools by kiegroup.

The class KiePMMLRegressionTableFactoryTest, method getRegressionTables.

/**
 * Converts the two PMML regression tables of a model ("professional" and "hobby") in one call.
 * It then checks that each appears in the result under its target category and matches its source.
 * <p>
 * The fixture is a CAUCHIT-normalized {@code RegressionModel} with one categorical target field named
 * {@code "targetField"}, wrapped in a minimal {@code PMML} document.
 */
@Test
public void getRegressionTables() {
    regressionTable = getRegressionTable(3.5, "professional");
    final RegressionTable hobbyTable = getRegressionTable(3.9, "hobby");

    // Model under test: two regression tables, CAUCHIT normalization, generated name.
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(regressionTable, hobbyTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));

    // Categorical target field, declared in the data dictionary...
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // ...and referenced as TARGET by the model's mining schema.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);
    model.setMiningSchema(schema);

    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, model, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO, model.getRegressionTables(), model.getNormalizationMethod());

    final Map<String, KiePMMLRegressionTable> tablesByCategory = KiePMMLRegressionTableFactory.getRegressionTables(regressionDTO);
    assertNotNull(tablesByCategory);
    assertEquals(model.getRegressionTables().size(), tablesByCategory.size());
    for (RegressionTable expected : model.getRegressionTables()) {
        final String category = expected.getTargetCategory().toString();
        assertTrue(tablesByCategory.containsKey(category));
        commonEvaluateRegressionTable(tablesByCategory.get(category), expected);
    }
}
Also used : MiningField(org.dmg.pmml.MiningField) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) KiePMMLRegressionTable(org.kie.pmml.models.regression.model.KiePMMLRegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) Test(org.junit.Test)

Aggregations

- KiePMMLRegressionTable (org.kie.pmml.models.regression.model.KiePMMLRegressionTable): 4
- RegressionTable (org.dmg.pmml.regression.RegressionTable): 3
- DataDictionary (org.dmg.pmml.DataDictionary): 2
- DataField (org.dmg.pmml.DataField): 2
- MiningField (org.dmg.pmml.MiningField): 2
- MiningSchema (org.dmg.pmml.MiningSchema): 2
- PMML (org.dmg.pmml.PMML): 2
- RegressionModel (org.dmg.pmml.regression.RegressionModel): 2
- Test (org.junit.Test): 2
- HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock): 2
- RegressionCompilationDTO (org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO): 2
- LinkedHashMap (java.util.LinkedHashMap): 1
- PMMLModelTestUtils.getRegressionTable (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRegressionTable): 1