Search in sources:

Example 1 with DataDictionary

Use of org.dmg.pmml.DataDictionary in project shifu by ShifuML.

Class PMMLConstructorFactory, method produce.

/**
 * Builds the PMML translator that matches the training algorithm configured in {@code modelConfig}.
 *
 * <p>For NN and LR models this assembles a {@link PMMLTranslator} from an algorithm-specific model
 * creator and specification creator, plus data-dictionary, mining-schema, model-stats and
 * local-transformations creators (the latter chosen from the configured normalization type). For GBT
 * and RF models a {@link TreeEnsemblePMMLTranslator} is returned instead.
 *
 * @param modelConfig       model configuration; its train algorithm and normalize type select the creators
 * @param columnConfigList  column configurations handed to every creator
 * @param isConcise         forwarded to the NN/LR creators
 * @param isOutBaggingToOne forwarded to the {@link PMMLTranslator} built for NN/LR
 * @return a translator for the configured algorithm
 * @throws IllegalArgumentException if the algorithm is none of NN, LR, GBT or RF
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise,
        boolean isOutBaggingToOne) {
    // Read once; every branch below compares against the same value (a null algorithm matches nothing).
    String algorithm = modelConfig.getTrain().getAlgorithm();

    AbstractPmmlElementCreator<Model> modelCreator;
    AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm)) {
        // Tree ensembles use their own translator and creators.
        // NOTE(review): isConcise and isOutBaggingToOne are not forwarded on this path - confirm intended.
        TreeEnsemblePmmlCreator gbtModelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> treeDataDictionaryCreator = new DataDictionaryCreator(
                modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> treeMiningSchemaCreator = new TreeModelMiningSchemaCreator(
                modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(gbtModelCreator, treeDataDictionaryCreator, treeMiningSchemaCreator);
    } else {
        // IllegalArgumentException extends RuntimeException, so existing catch sites are unaffected.
        throw new IllegalArgumentException("Model not supported: " + algorithm);
    }

    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig,
            columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig,
            columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList,
            isConcise);
    AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator = createLocalTransformationsCreator(
            modelConfig, columnConfigList, isConcise);
    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator,
            localTransformationsCreator, specifCreator, isOutBaggingToOne);
}

/**
 * Picks the LocalTransformations creator for the configured normalization type; any type not listed
 * explicitly falls back to {@code ZscoreLocalTransformCreator}.
 *
 * <p>A null normalization type throws {@link NullPointerException} (switch on a null enum).
 */
private static AbstractPmmlElementCreator<LocalTransformations> createLocalTransformationsCreator(
        ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise) {
    ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
    switch (normType) {
        case WOE:
        case WEIGHT_WOE:
            return new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            return new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            return new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
        case ZSCALE_ONEHOT:
            return new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
        default:
            return new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    }
}
Also used : AbstractSpecifCreator(ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator) TreeEnsemblePMMLTranslator(ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator) ModelNormalizeConf(ml.shifu.shifu.container.obj.ModelNormalizeConf) DataDictionary(org.dmg.pmml.DataDictionary) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) TreeEnsemblePMMLTranslator(ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator) LocalTransformations(org.dmg.pmml.LocalTransformations) MiningSchema(org.dmg.pmml.MiningSchema) ModelStats(org.dmg.pmml.ModelStats) Model(org.dmg.pmml.Model)

Example 2 with DataDictionary

Use of org.dmg.pmml.DataDictionary in project shifu by ShifuML.

Class DataDictionaryCreator, method build.

/**
 * Builds the PMML {@link DataDictionary} from the raw dataset columns.
 *
 * <p>Only column configs whose column number is below {@code datasetHeaders.length} (the raw columns)
 * are considered. In non-concise mode every raw column becomes a data field. In concise mode a raw
 * column is kept when it is the target, when it is finally selected (and, for a
 * {@link BasicFloatNetwork} with a non-empty feature set, also part of that set), or - in segment
 * expansion mode - when any of its segment-expansion copies is finally selected.
 *
 * @param basicML the trained model; only a {@link BasicFloatNetwork} contributes a feature-set filter
 * @return the data dictionary, with its field count set
 */
@Override
public DataDictionary build(BasicML basicML) {
    // Only a BasicFloatNetwork carries a feature subset. For any other model (or null) pass null, which
    // CollectionUtils.isEmpty treats as "no subset filter".
    Set<Integer> featureSet = (basicML instanceof BasicFloatNetwork)
            ? ((BasicFloatNetwork) basicML).getFeatureSet()
            : null;
    // More column configs than raw headers means segment-expansion columns follow the raw ones.
    boolean isSegExpansionMode = columnConfigList.size() > datasetHeaders.length;

    List<DataField> fields = new ArrayList<DataField>();
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.getColumnNum() >= datasetHeaders.length) {
            // Configs are assumed ordered by column number: everything from here on is an expansion column.
            break;
        }
        if (!isConcise || isSelectedOrTarget(columnConfig, featureSet)
                || (isSegExpansionMode && hasSelectedSegmentColumn(columnConfig))) {
            fields.add(convertColumnToDataField(columnConfig));
        }
    }
    dict(fields);
    return dictOf(fields);
}
Also used : DataField(org.dmg.pmml.DataField) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) DataDictionary(org.dmg.pmml.DataDictionary)

Example 3 with DataDictionary

Use of org.dmg.pmml.DataDictionary in project drools by kiegroup.

Class ModelUtilsTest, method getTargetFieldsTypeMapWithTargetFieldsWithoutTargets.

/**
 * Three data fields, each mined as PREDICTED, on a model with no Targets element.
 * {@code getTargetFieldsTypeMap} must return a LinkedHashMap with one entry per mining field. The
 * values, in mining-schema order, must be the DATA_TYPE of the matching data field.
 */
@Test
public void getTargetFieldsTypeMapWithTargetFieldsWithoutTargets() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    for (int i = 0; i < 3; i++) {
        DataField randomField = getRandomDataField();
        dataDictionary.addDataFields(randomField);
        miningSchema.addMiningFields(
                getMiningField(randomField.getName().getValue(), MiningField.UsageType.PREDICTED));
    }
    model.setMiningSchema(miningSchema);

    Map<String, DATA_TYPE> retrieved =
            ModelUtils.getTargetFieldsTypeMap(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    assertTrue(retrieved instanceof LinkedHashMap);

    // The test expects the map's iteration order to follow the mining-schema order.
    final Iterator<DATA_TYPE> actualTypes = retrieved.values().iterator();
    for (MiningField miningField : miningSchema.getMiningFields()) {
        DataField matchingField = dataDictionary.getDataFields()
                .stream()
                .filter(candidate -> candidate.getName().equals(miningField.getName()))
                .findFirst()
                .get();
        assertEquals(DATA_TYPE.byName(matchingField.getDataType().value()), actualTypes.next());
    }
}
Also used : PMMLModelTestUtils.getMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField) MiningField(org.dmg.pmml.MiningField) PMMLModelTestUtils.getRandomMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField) DataDictionary(org.dmg.pmml.DataDictionary) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) RegressionModel(org.dmg.pmml.regression.RegressionModel) LinkedHashMap(java.util.LinkedHashMap) MiningSchema(org.dmg.pmml.MiningSchema) DataField(org.dmg.pmml.DataField) PMMLModelTestUtils.getDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField) PMMLModelTestUtils.getRandomDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField) Model(org.dmg.pmml.Model) RegressionModel(org.dmg.pmml.regression.RegressionModel) DATA_TYPE(org.kie.pmml.api.enums.DATA_TYPE) Map(java.util.Map) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Test(org.junit.Test)

Example 4 with DataDictionary

Use of org.dmg.pmml.DataDictionary in project drools by kiegroup.

Class ModelUtilsTest, method getTargetFieldTypeWithoutTargetField.

/**
 * The model's only mining field is ACTIVE, so there is no predicted field.
 * {@code getTargetFieldType} is expected to throw.
 *
 * <p>NOTE(review): {@code expected = Exception.class} accepts any exception, including one raised by the
 * fixture setup itself. A narrower expected type would make this test more precise.
 */
@Test(expected = Exception.class)
public void getTargetFieldTypeWithoutTargetField() {
    final String fieldName = "fieldName";
    final MiningField activeField = getMiningField(fieldName, MiningField.UsageType.ACTIVE);

    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));

    final MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(activeField);

    final Model model = new RegressionModel();
    model.setMiningSchema(miningSchema);

    ModelUtils.getTargetFieldType(getFieldsFromDataDictionary(dataDictionary), model);
}
Also used : PMMLModelTestUtils.getMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField) MiningField(org.dmg.pmml.MiningField) PMMLModelTestUtils.getRandomMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField) DataField(org.dmg.pmml.DataField) PMMLModelTestUtils.getDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField) PMMLModelTestUtils.getRandomDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField) MiningSchema(org.dmg.pmml.MiningSchema) Model(org.dmg.pmml.Model) RegressionModel(org.dmg.pmml.regression.RegressionModel) DataDictionary(org.dmg.pmml.DataDictionary) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) RegressionModel(org.dmg.pmml.regression.RegressionModel) Test(org.junit.Test)

Example 5 with DataDictionary

Use of org.dmg.pmml.DataDictionary in project drools by kiegroup.

Class ModelUtilsTest, method getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType.

/**
 * Three CATEGORICAL data fields, each mined as PREDICTED with a CONTINUOUS mining-level op type.
 * Each has a Target built by {@code getTarget(fieldName, null)}; the null argument presumably means
 * "no op type on the target" - confirm against PMMLModelTestUtils.
 * Every returned name/op-type pair must carry the op type of its mining field.
 */
@Test
public void getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    for (int i = 0; i < 3; i++) {
        String name = "fieldName-" + i;
        dataDictionary.addDataFields(getDataField(name, OpType.CATEGORICAL, DataType.STRING));
        MiningField predictedField = getMiningField(name, MiningField.UsageType.PREDICTED);
        // Deliberately differs from the data field's CATEGORICAL op type.
        predictedField.setOpType(OpType.CONTINUOUS);
        miningSchema.addMiningFields(predictedField);
        targets.addTargets(getTarget(name, null));
    }
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);

    List<KiePMMLNameOpType> retrieved =
            ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    for (KiePMMLNameOpType nameOpType : retrieved) {
        Optional<MiningField> match = miningSchema.getMiningFields()
                .stream()
                .filter(candidate -> nameOpType.getName().equals(candidate.getName().getValue()))
                .findFirst();
        assertTrue(match.isPresent());
        OP_TYPE expectedOpType = OP_TYPE.byName(match.get().getOpType().value());
        assertEquals(expectedOpType, nameOpType.getOpType());
    }
}
Also used : RESULT_FEATURE(org.kie.pmml.api.enums.RESULT_FEATURE) ModelUtils.getPrefixedName(org.kie.pmml.compiler.api.utils.ModelUtils.getPrefixedName) Arrays(java.util.Arrays) Date(java.util.Date) Model(org.dmg.pmml.Model) MiningSchema(org.dmg.pmml.MiningSchema) OP_TYPE(org.kie.pmml.api.enums.OP_TYPE) Row(org.dmg.pmml.Row) FieldName(org.dmg.pmml.FieldName) FIELD_USAGE_TYPE(org.kie.pmml.api.enums.FIELD_USAGE_TYPE) OpType(org.dmg.pmml.OpType) Map(java.util.Map) InputCell(org.jpmml.model.inlinetable.InputCell) PMMLModelTestUtils.getDataTypes(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataTypes) PMMLModelTestUtils.getMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField) RegressionModel(org.dmg.pmml.regression.RegressionModel) Targets(org.dmg.pmml.Targets) DataType(org.dmg.pmml.DataType) PMMLModelTestUtils.getRandomRow(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomRow) Collectors(java.util.stream.Collectors) PMMLModelTestUtils.getParameterFields(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getParameterFields) DataField(org.dmg.pmml.DataField) List(java.util.List) Assert.assertFalse(org.junit.Assert.assertFalse) Optional(java.util.Optional) PMMLModelTestUtils.getArray(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getArray) PMMLModelTestUtils.getRandomDataType(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataType) ParameterField(org.dmg.pmml.ParameterField) IntStream(java.util.stream.IntStream) OutputField(org.dmg.pmml.OutputField) PMMLModelTestUtils.getDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Field(org.dmg.pmml.Field) DerivedField(org.dmg.pmml.DerivedField) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) KiePMMLNameOpType(org.kie.pmml.commons.model.tuples.KiePMMLNameOpType) 
OutputCell(org.jpmml.model.inlinetable.OutputCell) MiningField(org.dmg.pmml.MiningField) Iterator(java.util.Iterator) Assert.assertNotNull(org.junit.Assert.assertNotNull) Assert.assertTrue(org.junit.Assert.assertTrue) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) PMMLModelTestUtils.getRandomDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) Target(org.dmg.pmml.Target) DATA_TYPE(org.kie.pmml.api.enums.DATA_TYPE) Array(org.dmg.pmml.Array) PMMLModelTestUtils.getRandomOutputField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomOutputField) PMMLModelTestUtils.getRandomTarget(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomTarget) Assert.assertNull(org.junit.Assert.assertNull) PMMLModelTestUtils.getRandomMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField) PMMLModelTestUtils.getTarget(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getTarget) PMMLModelTestUtils.getRandomRowWithCells(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomRowWithCells) Assert.assertEquals(org.junit.Assert.assertEquals) PMMLModelTestUtils.getMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField) MiningField(org.dmg.pmml.MiningField) PMMLModelTestUtils.getRandomMiningField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField) Targets(org.dmg.pmml.Targets) OP_TYPE(org.kie.pmml.api.enums.OP_TYPE) DataDictionary(org.dmg.pmml.DataDictionary) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) RegressionModel(org.dmg.pmml.regression.RegressionModel) KiePMMLNameOpType(org.kie.pmml.commons.model.tuples.KiePMMLNameOpType) Target(org.dmg.pmml.Target) 
PMMLModelTestUtils.getRandomTarget(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomTarget) PMMLModelTestUtils.getTarget(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getTarget) MiningSchema(org.dmg.pmml.MiningSchema) DataField(org.dmg.pmml.DataField) PMMLModelTestUtils.getDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField) PMMLModelTestUtils.getRandomDataField(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField) Model(org.dmg.pmml.Model) RegressionModel(org.dmg.pmml.regression.RegressionModel) Test(org.junit.Test)

Aggregations

DataDictionary (org.dmg.pmml.DataDictionary): 48; DataField (org.dmg.pmml.DataField): 41; Test (org.junit.Test): 41; CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary): 30; MiningSchema (org.dmg.pmml.MiningSchema): 28; MiningField (org.dmg.pmml.MiningField): 27; RegressionModel (org.dmg.pmml.regression.RegressionModel): 27; PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField): 21; PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField): 21; Model (org.dmg.pmml.Model): 19; PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField): 17; PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField): 17; PMML (org.dmg.pmml.PMML): 12; OutputField (org.dmg.pmml.OutputField): 11; Collectors (java.util.stream.Collectors): 10; Assert.assertTrue (org.junit.Assert.assertTrue): 10; DATA_TYPE (org.kie.pmml.api.enums.DATA_TYPE): 10; OP_TYPE (org.kie.pmml.api.enums.OP_TYPE): 10; HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock): 10; Arrays (java.util.Arrays): 9