Search in sources:

Example 1 with AbstractKiePMMLTable

Usage of org.kie.pmml.models.regression.model.AbstractKiePMMLTable in the drools project by kiegroup.

Class KiePMMLRegressionModelFactory, method getRegressionTables.

/**
 * Compiles the PMML regression tables of the model into kie tables, keyed by table name.
 * <p>
 * For a regression model only the first declared {@code RegressionTable} is compiled, and the
 * model's own normalization method is passed down to {@code KiePMMLRegressionTableFactory}.
 * Otherwise every declared table is handed to {@code KiePMMLClassificationTableFactory} with
 * {@code NormalizationMethod.NONE}, producing a single {@link KiePMMLClassificationTable}.
 * <p>
 * Package-private: intended for {@code KiePMMLRegressionModel} instantiation inside this factory.
 *
 * @param compilationDTO compilation context wrapping the PMML regression model
 * @return a new, mutable map from table name to compiled table
 */
static Map<String, AbstractKiePMMLTable> getRegressionTables(final RegressionCompilationDTO compilationDTO) {
    final Map<String, AbstractKiePMMLTable> tablesByName = new HashMap<>();
    if (!compilationDTO.isRegression()) {
        // Classification: all tables go into one classification table; no normalization passed down
        final RegressionCompilationDTO classificationDTO =
                RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                        compilationDTO,
                        compilationDTO.getModel().getRegressionTables(),
                        RegressionModel.NormalizationMethod.NONE);
        final KiePMMLClassificationTable classificationTable =
                KiePMMLClassificationTableFactory.getClassificationTable(classificationDTO);
        tablesByName.put(classificationTable.getName(), classificationTable);
        return tablesByName;
    }
    // Regression: only the first declared table is compiled (fails if the model declares none)
    final RegressionTable firstTable = compilationDTO.getModel().getRegressionTables().get(0);
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                    compilationDTO,
                    Collections.singletonList(firstTable),
                    compilationDTO.getModel().getNormalizationMethod());
    tablesByName.putAll(KiePMMLRegressionTableFactory.getRegressionTables(regressionDTO));
    return tablesByName;
}
Also used: HashMap (java.util.HashMap), AbstractKiePMMLTable (org.kie.pmml.models.regression.model.AbstractKiePMMLTable), RegressionCompilationDTO (org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO), KiePMMLClassificationTable (org.kie.pmml.models.regression.model.KiePMMLClassificationTable), RegressionTable (org.dmg.pmml.regression.RegressionTable)

Example 2 with AbstractKiePMMLTable

Usage of org.kie.pmml.models.regression.model.AbstractKiePMMLTable in the drools project by kiegroup.

Class KiePMMLRegressionModelFactory, method getKiePMMLRegressionModelClasses.

/**
 * Instantiates the {@link KiePMMLRegressionModel} described by the given compilation DTO.
 * <p>
 * The nested table is chosen from {@code getRegressionTables(compilationDTO)}. When that map holds
 * exactly one table, that table is used. Otherwise the first {@link KiePMMLClassificationTable}
 * found in the map is used. Name, mining function, target field, mining/output fields, targets
 * and transformations are all copied from the DTO.
 *
 * @param compilationDTO compilation context wrapping the PMML regression model
 * @return the built regression model
 * @throws KiePMMLException if several tables are present but none is a
 *         {@code KiePMMLClassificationTable}, or wrapping any other exception raised while
 *         selecting the table or building the model
 * @throws IOException never propagated by this body; kept for source compatibility with callers
 * @throws IllegalAccessException never propagated by this body; kept for source compatibility
 * @throws InstantiationException never propagated by this body; kept for source compatibility
 */
public static KiePMMLRegressionModel getKiePMMLRegressionModelClasses(final RegressionCompilationDTO compilationDTO) throws IOException, IllegalAccessException, InstantiationException {
    logger.trace("getKiePMMLRegressionModelClasses {} {}", compilationDTO.getFields(), compilationDTO.getModel());
    Map<String, AbstractKiePMMLTable> regressionTablesMap = getRegressionTables(compilationDTO);
    try {
        // A single table is used as-is; otherwise look for the classification table
        AbstractKiePMMLTable nestedTable = regressionTablesMap.size() == 1
                ? regressionTablesMap.values().iterator().next()
                : regressionTablesMap.values().stream()
                        .filter(KiePMMLClassificationTable.class::isInstance)
                        .findFirst()
                        .orElseThrow(() -> new KiePMMLException("Failed to find expected " + KiePMMLClassificationTable.class.getSimpleName()));
        return KiePMMLRegressionModel.builder(compilationDTO.getModelName(), compilationDTO.getMINING_FUNCTION())
                .withAbstractKiePMMLTable(nestedTable)
                .withTargetField(compilationDTO.getTargetFieldName())
                .withMiningFields(compilationDTO.getKieMiningFields())
                .withOutputFields(compilationDTO.getKieOutputFields())
                .withKiePMMLMiningFields(compilationDTO.getKiePMMLMiningFields())
                .withKiePMMLOutputFields(compilationDTO.getKiePMMLOutputFields())
                .withKiePMMLTargets(compilationDTO.getKiePMMLTargetFields())
                .withKiePMMLTransformationDictionary(compilationDTO.getKiePMMLTransformationDictionary())
                .withKiePMMLLocalTransformations(compilationDTO.getKiePMMLLocalTransformations())
                .build();
    } catch (KiePMMLException e) {
        // Already a domain exception (e.g. missing classification table): rethrow unchanged
        // instead of wrapping it again and burying its message in the cause chain
        throw e;
    } catch (Exception e) {
        throw new KiePMMLException(e);
    }
}
Also used: AbstractKiePMMLTable (org.kie.pmml.models.regression.model.AbstractKiePMMLTable), KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException), IOException (java.io.IOException)

Example 3 with AbstractKiePMMLTable

Usage of org.kie.pmml.models.regression.model.AbstractKiePMMLTable in the drools project by kiegroup.

Class KiePMMLRegressionModelFactoryTest, method getKiePMMLRegressionModelClasses.

/**
 * Checks that the factory turns the shared fixture into a {@code KiePMMLRegressionModel} whose
 * name, mining function and target field mirror the source PMML model. It also checks that the
 * nested table is a {@link KiePMMLClassificationTable}, which is then evaluated in detail.
 * NOTE(review): this implies the fixture is a classification model (not {@code isRegression()}).
 */
@Test
public void getKiePMMLRegressionModelClasses() throws IOException, IllegalAccessException, InstantiationException {
    // Build the generic compilation context, then narrow it to the regression-specific DTO
    final CompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO = RegressionCompilationDTO.fromCompilationDTO(commonDTO);
    final KiePMMLRegressionModel model = KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses(regressionDTO);
    assertNotNull(model);

    // Top-level attributes are copied from the source model
    assertEquals(regressionModel.getModelName(), model.getName());
    assertEquals(MINING_FUNCTION.byName(regressionModel.getMiningFunction().value()), model.getMiningFunction());
    assertEquals(miningFields.get(0).getName().getValue(), model.getTargetField());

    // The nested table must be the classification table
    final AbstractKiePMMLTable nestedTable = model.getRegressionTable();
    assertNotNull(nestedTable);
    assertTrue(nestedTable instanceof KiePMMLClassificationTable);
    evaluateCategoricalRegressionTable((KiePMMLClassificationTable) nestedTable);
}
Also used: KiePMMLRegressionModel (org.kie.pmml.models.regression.model.KiePMMLRegressionModel), AbstractKiePMMLTable (org.kie.pmml.models.regression.model.AbstractKiePMMLTable), HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock), KiePMMLClassificationTable (org.kie.pmml.models.regression.model.KiePMMLClassificationTable), PMMLModelTestUtils.getRegressionModel (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRegressionModel), RegressionModel (org.dmg.pmml.regression.RegressionModel), Test (org.junit.Test)

Aggregations

- AbstractKiePMMLTable (org.kie.pmml.models.regression.model.AbstractKiePMMLTable): 3
- KiePMMLClassificationTable (org.kie.pmml.models.regression.model.KiePMMLClassificationTable): 2
- IOException (java.io.IOException): 1
- HashMap (java.util.HashMap): 1
- RegressionModel (org.dmg.pmml.regression.RegressionModel): 1
- RegressionTable (org.dmg.pmml.regression.RegressionTable): 1
- Test (org.junit.Test): 1
- KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException): 1
- PMMLModelTestUtils.getRegressionModel (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRegressionModel): 1
- HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock): 1
- RegressionCompilationDTO (org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO): 1
- KiePMMLRegressionModel (org.kie.pmml.models.regression.model.KiePMMLRegressionModel): 1