Search in sources:

Example 6 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in the drools project by kiegroup.

The class KiePMMLLocalTransformationsFactoryTest, method getKiePMMLTransformationDictionaryVariableDeclaration.

/**
 * Checks that the variable-declaration block generated for a LocalTransformations holding the
 * fixture's derived fields matches the reference snippet in TEST_01_SOURCE, and that the
 * generated code compiles once the listed classes are imported.
 */
@Test
public void getKiePMMLTransformationDictionaryVariableDeclaration() throws IOException {
    // Input: a LocalTransformations populated with the shared derived-field fixture.
    LocalTransformations source = new LocalTransformations();
    source.addDerivedFields(getDerivedFields());

    // Code under test: generate the declaration block.
    BlockStmt generated =
            KiePMMLLocalTransformationsFactory.getKiePMMLLocalTransformationsVariableDeclaration(source);

    // Structural comparison against the stored reference snippet.
    Statement reference = JavaParserUtils.parseBlock(getFileContent(TEST_01_SOURCE));
    assertTrue(JavaParserUtils.equalsNode(reference, generated));

    // The generated block must compile with exactly these imports available.
    List<Class<?>> requiredImports = Arrays.asList(KiePMMLConstant.class,
                                                   KiePMMLApply.class,
                                                   KiePMMLDerivedField.class,
                                                   KiePMMLLocalTransformations.class,
                                                   Arrays.class,
                                                   Collections.class);
    commonValidateCompilationWithImports(generated, requiredImports);
}
Also used : LocalTransformations(org.dmg.pmml.LocalTransformations) KiePMMLLocalTransformations(org.kie.pmml.commons.transformations.KiePMMLLocalTransformations) Statement(com.github.javaparser.ast.stmt.Statement) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) Test(org.junit.Test)

Example 7 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in the drools project by kiegroup.

The class KiePMMLLocalTransformationsInstanceFactoryTest, method getKiePMMLLocalTransformations.

/**
 * Checks that converting a random PMML LocalTransformations yields a KiePMMLLocalTransformations
 * with the same number of derived fields, and that every source DerivedField has a converted
 * counterpart (matched by name) that passes the common derived-field verification.
 */
@Test
public void getKiePMMLLocalTransformations() {
    final LocalTransformations source = getRandomLocalTransformations();
    KiePMMLLocalTransformations converted =
            KiePMMLLocalTransformationsInstanceFactory.getKiePMMLLocalTransformations(source, Collections.emptyList());
    assertNotNull(converted);

    final List<DerivedField> expectedFields = source.getDerivedFields();
    final List<KiePMMLDerivedField> actualFields = converted.getDerivedFields();
    assertEquals(expectedFields.size(), actualFields.size());

    // Look each source field up by name among the converted ones and verify it.
    for (DerivedField expected : expectedFields) {
        Optional<KiePMMLDerivedField> match = actualFields.stream()
                .filter(candidate -> candidate.getName().equals(expected.getName().getValue()))
                .findFirst();
        assertTrue(match.isPresent());
        commonVerifyKiePMMLDerivedField(match.get(), expected);
    }
}
Also used : InstanceFactoriesTestCommon.commonVerifyKiePMMLDerivedField(org.kie.pmml.compiler.commons.factories.InstanceFactoriesTestCommon.commonVerifyKiePMMLDerivedField) LocalTransformations(org.dmg.pmml.LocalTransformations) Assert.assertNotNull(org.junit.Assert.assertNotNull) Assert.assertTrue(org.junit.Assert.assertTrue) Test(org.junit.Test) KiePMMLLocalTransformations(org.kie.pmml.commons.transformations.KiePMMLLocalTransformations) List(java.util.List) DerivedField(org.dmg.pmml.DerivedField) PMMLModelTestUtils.getRandomLocalTransformations(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomLocalTransformations) Optional(java.util.Optional) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) KiePMMLDerivedField(org.kie.pmml.commons.transformations.KiePMMLDerivedField)

Example 8 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in the shifu project by ShifuML.

The class PMMLLRModelBuilder, method adaptMLModelToPMML.

/**
 * Fills a PMML regression model with the bias and weights of a trained Shifu LR model.
 *
 * <p>The model is switched to LOGIT normalization and REGRESSION mining function. One
 * RegressionTable is added, with the LR bias as intercept and one NumericPredictor per ACTIVE
 * mining field, in mining-schema order. Each predictor is named after the field obtained by
 * following the model's derived-field chain from the raw mining field: NormContinuous,
 * MapValues (first field/column pair) and Discretize expressions are followed. The predictor's
 * coefficient is the next LR weight.
 *
 * @param lr        the trained LR model supplying bias and weights
 * @param pmmlModel the PMML model to populate; it is mutated and returned
 * @return the same {@code pmmlModel} instance
 * @throws IllegalStateException if the derived fields form a cycle (a chain that never ends)
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    RegressionTable table = new RegressionTable();
    table.setIntercept(lr.getBias());

    // source field -> derived field that transforms it
    HashMap<FieldName, FieldName> miningTransformMap = new HashMap<FieldName, FieldName>();
    LocalTransformations lt = pmmlModel.getLocalTransformations();
    // A model without <LocalTransformations> is valid PMML; predictors then use raw field names.
    if (lt != null) {
        for (DerivedField dField : lt.getDerivedFields()) {
            Object expression = dField.getExpression();
            if (expression instanceof NormContinuous) {
                // Apply z-scale normalization on numerical variables
                miningTransformMap.put(((NormContinuous) expression).getField(), dField.getName());
            } else if (expression instanceof MapValues) {
                // Apply bin map on categorical variables
                miningTransformMap.put(((MapValues) expression).getFieldColumnPairs().get(0).getField(),
                        dField.getName());
            } else if (expression instanceof Discretize) {
                miningTransformMap.put(((Discretize) expression).getField(), dField.getName());
            }
        }
    }

    int index = 0;
    for (MiningField mField : pmmlModel.getMiningSchema().getMiningFields()) {
        if (mField.getUsageType() != UsageType.ACTIVE) {
            continue;
        }
        // Follow the chain of derived fields to the last transformation of this input.
        FieldName fName = mField.getName();
        int hops = 0;
        while (miningTransformMap.containsKey(fName)) {
            // An acyclic chain cannot be longer than the map itself. Exceeding that means a cycle,
            // which would otherwise spin forever.
            if (++hops > miningTransformMap.size()) {
                throw new IllegalStateException("Cyclic derived-field chain for mining field "
                        + mField.getName().getValue());
            }
            fName = miningTransformMap.get(fName);
        }
        NumericPredictor np = new NumericPredictor();
        np.setName(fName);
        np.setCoefficient(lr.getWeights()[index++]);
        table.addNumericPredictors(np);
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
Also used : NormContinuous(org.dmg.pmml.NormContinuous) MiningField(org.dmg.pmml.MiningField) HashMap(java.util.HashMap) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) LocalTransformations(org.dmg.pmml.LocalTransformations) MapValues(org.dmg.pmml.MapValues) Discretize(org.dmg.pmml.Discretize) DerivedField(org.dmg.pmml.DerivedField) FieldName(org.dmg.pmml.FieldName)

Example 9 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in the shifu project by ShifuML.

The class ZscoreLocalTransformCreator, method build.

/**
 * Builds the LocalTransformations holding the normalization derived fields of every
 * final-selected column.
 *
 * <p>If {@code basicML} is a BasicFloatNetwork whose feature set is non-empty, only columns
 * whose column number is in that set are included. Otherwise every final-selected column is
 * included. Categorical columns use createCategoricalDerivedField; all others use
 * createNumericalDerivedField. Both receive the model's std-dev cutoff and normalize type.
 *
 * @param basicML the trained model; only its feature set is consulted, and only when it is a
 *                BasicFloatNetwork
 * @return a new LocalTransformations containing the derived fields of the selected columns
 */
@Override
public LocalTransformations build(BasicML basicML) {
    LocalTransformations localTransformations = new LocalTransformations();

    // Only a BasicFloatNetwork carries its own feature subset. A null or empty set means
    // "no restriction", matching the previous per-branch behaviour.
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    final boolean restrictToFeatures = featureSet != null && !CollectionUtils.isEmpty(featureSet);

    for (ColumnConfig config : columnConfigList) {
        if (!config.isFinalSelect()) {
            continue;
        }
        if (restrictToFeatures && !featureSet.contains(config.getColumnNum())) {
            continue;
        }
        double cutoff = modelConfig.getNormalizeStdDevCutOff();
        List<DerivedField> derivedFields = config.isCategorical()
                ? createCategoricalDerivedField(config, cutoff, modelConfig.getNormalizeType())
                : createNumericalDerivedField(config, cutoff, modelConfig.getNormalizeType());
        localTransformations.addDerivedFields(derivedFields.toArray(new DerivedField[derivedFields.size()]));
    }
    return localTransformations;
}
Also used : LocalTransformations(org.dmg.pmml.LocalTransformations) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

LocalTransformations (org.dmg.pmml.LocalTransformations)9 DerivedField (org.dmg.pmml.DerivedField)5 Field (org.dmg.pmml.Field)2 FieldName (org.dmg.pmml.FieldName)2 MiningField (org.dmg.pmml.MiningField)2 Output (org.dmg.pmml.Output)2 Test (org.junit.Test)2 KiePMMLLocalTransformations (org.kie.pmml.commons.transformations.KiePMMLLocalTransformations)2 BlockStmt (com.github.javaparser.ast.stmt.BlockStmt)1 Statement (com.github.javaparser.ast.stmt.Statement)1 ArrayList (java.util.ArrayList)1 Collections (java.util.Collections)1 HashMap (java.util.HashMap)1 List (java.util.List)1 Optional (java.util.Optional)1 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)1 ModelNormalizeConf (ml.shifu.shifu.container.obj.ModelNormalizeConf)1 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)1 PMMLTranslator (ml.shifu.shifu.core.pmml.PMMLTranslator)1 TreeEnsemblePMMLTranslator (ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator)1