Search in sources:

Example 51 with DataField

use of org.dmg.pmml.DataField in project drools by kiegroup.

the class KiePMMLClassificationTableFactoryTest method setStaticGetter.

/**
 * Checks that {@code KiePMMLClassificationTableFactory.setStaticGetter} fills a clone of the
 * static-getter template so that it is node-equal to the method parsed from
 * {@code TEST_02_SOURCE}.
 */
@Test
public void setStaticGetter() throws IOException {
    final String variableName = "variableName";
    final String targetField = "targetField";

    // One regression table per target category.
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");

    // Output section: two probability fields and one predicted-value field.
    final Output modelOutput = new Output();
    modelOutput.addOutputFields(
            getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1"),
            getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0"),
            getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null));

    // Categorical target field, registered in the data dictionary.
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create(targetField));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // Mining schema declaring that same field as the model target.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);

    // Assemble the regression model and wrap it in a PMML document.
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(professionalTable, clericalTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));
    model.setOutput(modelOutput);
    model.setMiningSchema(schema);
    final PMML document = new PMML();
    document.setDataDictionary(dictionary);
    document.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, document, model, new HasClassLoaderMock());
    final RegressionCompilationDTO regressionDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(
                    commonDTO, model.getRegressionTables(), model.getNormalizationMethod());

    // Map "defpack.<UPPER-CASED CATEGORY>" to a source/category tuple, in table order.
    final LinkedHashMap<String, KiePMMLTableSourceCategory> tablesByClassName = new LinkedHashMap<>();
    for (RegressionTable table : model.getRegressionTables()) {
        final String category = table.getTargetCategory().toString();
        tablesByClassName.put("defpack." + category.toUpperCase(), new KiePMMLTableSourceCategory("", category));
    }

    // Work on a clone so the shared template is left untouched.
    final MethodDeclaration populatedGetter = STATIC_GETTER_METHOD.clone();
    KiePMMLClassificationTableFactory.setStaticGetter(regressionDTO, tablesByClassName, populatedGetter, variableName);

    final MethodDeclaration expectedGetter = JavaParserUtils.parseMethod(getFileContent(TEST_02_SOURCE));
    assertTrue(JavaParserUtils.equalsNode(expectedGetter, populatedGetter));
}
Also used : MiningField(org.dmg.pmml.MiningField) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) DataDictionary(org.dmg.pmml.DataDictionary) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) RegressionTable(org.dmg.pmml.regression.RegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel) LinkedHashMap(java.util.LinkedHashMap) DataField(org.dmg.pmml.DataField) MiningSchema(org.dmg.pmml.MiningSchema) Output(org.dmg.pmml.Output) KiePMMLTableSourceCategory(org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory) OutputField(org.dmg.pmml.OutputField) PMML(org.dmg.pmml.PMML) RegressionCompilationDTO(org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO) Test(org.junit.Test)

Example 52 with DataField

use of org.dmg.pmml.DataField in project drools by kiegroup.

the class RegressionCompilationDTO method isRegression.

/**
 * Tells whether this DTO describes a regression.
 *
 * @return {@code true} when the mining function is {@code REGRESSION} and the target data field
 *         is either absent or has a {@code CONTINUOUS} op type; {@code false} otherwise
 */
public boolean isRegression() {
    final DataField target = getTargetDataField();
    // A missing target field is accepted; a present one must be continuous.
    final boolean targetAllowsRegression =
            target == null || Objects.equals(OpType.CONTINUOUS, target.getOpType());
    final boolean regressionFunction = Objects.equals(MiningFunction.REGRESSION, getMiningFunction());
    return regressionFunction && targetAllowsRegression;
}
Also used : DataField(org.dmg.pmml.DataField) OpType(org.dmg.pmml.OpType)

Example 53 with DataField

use of org.dmg.pmml.DataField in project drools by kiegroup.

the class KiePMMLCharacteristicFactoryTest method getAttributeVariableDeclarationWithComplexPartialScore.

/**
 * Builds a characteristic with two attributes, generates its variable declaration through
 * {@code KiePMMLCharacteristicFactory.getCharacteristicVariableDeclaration}, and checks that the
 * result is node-equal to the formatted {@code TEST_01_SOURCE} template and compiles with the
 * expected imports.
 */
@Test
public void getAttributeVariableDeclarationWithComplexPartialScore() throws IOException {
    final String variableName = "variableName";
    final Array.Type arrayType = Array.Type.STRING;
    final List<String> firstValues = getStringObjects(arrayType, 4);
    final Attribute firstAttribute = getAttribute(firstValues, 1);
    final List<String> secondValues = getStringObjects(arrayType, 4);
    final Attribute secondAttribute = getAttribute(secondValues, 2);
    final CompoundPredicate firstCompound = (CompoundPredicate) firstAttribute.getPredicate();
    final CompoundPredicate secondCompound = (CompoundPredicate) secondAttribute.getPredicate();

    // Register a DOUBLE data field for every simple / simple-set predicate of both attributes,
    // first attribute first; any other predicate kind contributes nothing.
    final DataDictionary dataDictionary = new DataDictionary();
    for (CompoundPredicate compound : Arrays.asList(firstCompound, secondCompound)) {
        for (Predicate nested : compound.getPredicates()) {
            final boolean simple = nested instanceof SimplePredicate;
            if (!simple && !(nested instanceof SimpleSetPredicate)) {
                continue;
            }
            final DataField field = new DataField();
            if (simple) {
                field.setName(((SimplePredicate) nested).getField());
            } else {
                field.setName(((SimpleSetPredicate) nested).getField());
            }
            field.setDataType(DataType.DOUBLE);
            dataDictionary.addDataFields(field);
        }
    }

    // Comma-separated, double-quoted renderings of the value lists for the template.
    final String quotedFirstValues = firstValues.stream()
            .map(value -> "\"" + value + "\"")
            .collect(Collectors.joining(","));
    final String quotedSecondValues = secondValues.stream()
            .map(value -> "\"" + value + "\"")
            .collect(Collectors.joining(","));

    final Characteristic characteristic = new Characteristic();
    characteristic.addAttributes(firstAttribute, secondAttribute);
    characteristic.setBaselineScore(22);
    characteristic.setReasonCode(REASON_CODE);

    final BlockStmt retrieved = KiePMMLCharacteristicFactory.getCharacteristicVariableDeclaration(
            variableName, characteristic, getFieldsFromDataDictionary(dataDictionary));

    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(template, variableName, quotedFirstValues, quotedSecondValues,
            characteristic.getBaselineScore(), characteristic.getReasonCode());
    final Statement expected = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    final List<Class<?>> imports = Arrays.asList(
            KiePMMLAttribute.class,
            KiePMMLCharacteristic.class,
            KiePMMLComplexPartialScore.class,
            KiePMMLCompoundPredicate.class,
            KiePMMLConstant.class,
            KiePMMLSimplePredicate.class,
            KiePMMLSimpleSetPredicate.class,
            Arrays.class,
            Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Also used : KiePMMLConstant(org.kie.pmml.commons.model.expressions.KiePMMLConstant) Arrays(java.util.Arrays) Predicate(org.dmg.pmml.Predicate) PMMLModelTestUtils.getSimplePredicate(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getSimplePredicate) KiePMMLAttribute(org.kie.pmml.models.scorecard.model.KiePMMLAttribute) KiePMMLComplexPartialScore(org.kie.pmml.models.scorecard.model.KiePMMLComplexPartialScore) Characteristic(org.dmg.pmml.scorecard.Characteristic) ComplexPartialScore(org.dmg.pmml.scorecard.ComplexPartialScore) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicate) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) KiePMMLSimplePredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate) DataType(org.dmg.pmml.DataType) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) KiePMMLCompoundPredicate(org.kie.pmml.commons.model.predicates.KiePMMLCompoundPredicate) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) DataDictionary(org.dmg.pmml.DataDictionary) Test(org.junit.Test) Statement(com.github.javaparser.ast.stmt.Statement) Attribute(org.dmg.pmml.scorecard.Attribute) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) Collectors(java.util.stream.Collectors) Array(org.dmg.pmml.Array) FileUtils.getFileContent(org.kie.test.util.filesystem.FileUtils.getFileContent) DataField(org.dmg.pmml.DataField) List(java.util.List) SimplePredicate(org.dmg.pmml.SimplePredicate) PMMLModelTestUtils.getStringObjects(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getStringObjects) CodegenTestUtils.commonValidateCompilationWithImports(org.kie.pmml.compiler.commons.testutils.CodegenTestUtils.commonValidateCompilationWithImports) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) 
KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate) Constant(org.dmg.pmml.Constant) Collections(java.util.Collections) KiePMMLCharacteristic(org.kie.pmml.models.scorecard.model.KiePMMLCharacteristic) KiePMMLAttribute(org.kie.pmml.models.scorecard.model.KiePMMLAttribute) Attribute(org.dmg.pmml.scorecard.Attribute) Statement(com.github.javaparser.ast.stmt.Statement) Characteristic(org.dmg.pmml.scorecard.Characteristic) KiePMMLCharacteristic(org.kie.pmml.models.scorecard.model.KiePMMLCharacteristic) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) DataDictionary(org.dmg.pmml.DataDictionary) CommonTestingUtils.getFieldsFromDataDictionary(org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary) PMMLModelTestUtils.getSimplePredicate(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getSimplePredicate) KiePMMLSimplePredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate) SimplePredicate(org.dmg.pmml.SimplePredicate) Predicate(org.dmg.pmml.Predicate) PMMLModelTestUtils.getSimplePredicate(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getSimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) CompoundPredicate(org.dmg.pmml.CompoundPredicate) KiePMMLSimplePredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimplePredicate) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) KiePMMLCompoundPredicate(org.kie.pmml.commons.model.predicates.KiePMMLCompoundPredicate) SimplePredicate(org.dmg.pmml.SimplePredicate) KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) KiePMMLSimpleSetPredicate(org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate) 
KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLSimpleSetPredicateFactoryTest.getSimpleSetPredicate) Array(org.dmg.pmml.Array) DataField(org.dmg.pmml.DataField) CompoundPredicate(org.dmg.pmml.CompoundPredicate) KiePMMLCompoundPredicate(org.kie.pmml.commons.model.predicates.KiePMMLCompoundPredicate) Test(org.junit.Test)

Example 54 with DataField

use of org.dmg.pmml.DataField in project jpmml-r by jpmml.

the class FormulaUtil method createFormula.

/**
 * Builds a {@link Formula} from the {@code "factors"} and {@code "dataClasses"} attributes of an
 * R terms object, registering one field per row name of the factors matrix.
 *
 * <p>A variable that looks like a function call and is translated by one of the supported
 * encoders becomes a derived field; every other variable becomes a plain data field.
 * Afterwards, each key of the local {@code expressionFields} map that the encoder does not yet
 * know is created as a data field.
 *
 * @param terms the R object carrying the {@code "factors"} and {@code "dataClasses"} attributes
 * @param context source of per-variable categories and of the optional data used for type lookup
 * @param encoder encoder on which data fields and derived fields are created and renamed
 * @return the populated formula
 */
public static Formula createFormula(RExp terms, FormulaContext context, RExpEncoder encoder) {
    Formula formula = new Formula(encoder);
    RIntegerVector factors = (RIntegerVector) terms.getAttributeValue("factors");
    RStringVector dataClasses = (RStringVector) terms.getAttributeValue("dataClasses");
    RStringVector variableRows = factors.dimnames(0);
    // NOTE(review): termColumns is read here but not referenced again in this method.
    RStringVector termColumns = factors.dimnames(1);
    // Passed to the encode*Expression helpers below and iterated after the main loop;
    // presumably those helpers populate it - confirm against their implementations.
    VariableMap expressionFields = new VariableMap();
    for (int i = 0; i < variableRows.size(); i++) {
        String variable = variableRows.getDequotedValue(i);
        FieldName name = FieldName.create(variable);
        // Continuous unless the context supplies a non-empty category list for this variable.
        OpType opType = OpType.CONTINUOUS;
        DataType dataType = RExpUtil.getDataType(dataClasses.getValue(variable));
        List<String> categories = context.getCategories(variable);
        if (categories != null && categories.size() > 0) {
            opType = OpType.CATEGORICAL;
        }
        Expression expression = null;
        FieldName shortName = name;
        // Labeled block: every "break expression" leaves it with expression still null, which
        // sends the variable down the plain data-field branch further below.
        expression: if (variable.indexOf('(') > -1 && variable.indexOf(')') > -1) {
            FunctionExpression functionExpression;
            try {
                functionExpression = (FunctionExpression) ExpressionTranslator.translateExpression(variable);
            } catch (Exception e) {
                // Any failure (including a failed cast) is treated as "not a supported function call".
                break expression;
            }
            if (functionExpression.hasId("base", "cut")) {
                expression = encodeCutExpression(functionExpression, categories, expressionFields, encoder);
            } else if (functionExpression.hasId("base", "I")) {
                expression = encodeIdentityExpression(functionExpression, expressionFields, encoder);
            } else if (functionExpression.hasId("base", "ifelse")) {
                expression = encodeIfElseExpression(functionExpression, expressionFields, encoder);
            } else if (functionExpression.hasId("plyr", "mapvalues")) {
                expression = encodeMapValuesExpression(functionExpression, categories, expressionFields, encoder);
            } else if (functionExpression.hasId("plyr", "revalue")) {
                expression = encodeReValueExpression(functionExpression, categories, expressionFields, encoder);
            } else {
                // Unsupported function: fall back to a plain data field.
                break expression;
            }
            // Short name: the trimmed "x" argument text for base::I, otherwise
            // "<function>(<x argument text>)".
            FunctionExpression.Argument xArgument = functionExpression.getArgument("x", 0);
            String value = (xArgument.formatExpression()).trim();
            shortName = FieldName.create(functionExpression.hasId("base", "I") ? value : (functionExpression.getFunction() + "(" + value + ")"));
        }
        if (expression != null) {
            // Translated function call: register a derived field under the full variable name.
            DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, expression).addExtensions(createExtension(variable));
            if (categories != null && categories.size() > 0) {
                formula.addField(derivedField, categories);
            } else {
                formula.addField(derivedField);
            }
            // Rename only when the short name actually differs from the full name.
            if (!(name).equals(shortName)) {
                encoder.renameField(name, shortName);
            }
        } else {
            // Boolean variables always get the two categories "false" and "true",
            // replacing whatever the context returned.
            if ((DataType.BOOLEAN).equals(dataType)) {
                categories = Arrays.asList("false", "true");
            }
            if (categories != null && categories.size() > 0) {
                DataField dataField = encoder.createDataField(name, OpType.CATEGORICAL, dataType, categories);
                List<String> categoryNames;
                List<String> categoryValues;
                switch(dataType) {
                    case BOOLEAN:
                        // Upper-case names paired with lower-case values.
                        categoryNames = Arrays.asList("FALSE", "TRUE");
                        categoryValues = Arrays.asList("false", "true");
                        break;
                    default:
                        categoryNames = categories;
                        categoryValues = categories;
                        break;
                }
                formula.addField(dataField, categoryNames, categoryValues);
            } else {
                DataField dataField = encoder.createDataField(name, OpType.CONTINUOUS, dataType);
                formula.addField(dataField);
            }
        }
    }
    // Second pass: create a data field for every expressionFields key the encoder does not
    // already have. These fields are created on the encoder only; they are not added to the formula.
    Collection<Map.Entry<FieldName, List<String>>> entries = expressionFields.entrySet();
    for (Map.Entry<FieldName, List<String>> entry : entries) {
        FieldName name = entry.getKey();
        List<String> categories = entry.getValue();
        DataField dataField = encoder.getDataField(name);
        if (dataField == null) {
            // Defaults: continuous double, overridden below where more is known.
            OpType opType = OpType.CONTINUOUS;
            DataType dataType = DataType.DOUBLE;
            if (categories != null && categories.size() > 0) {
                opType = OpType.CATEGORICAL;
            }
            // Prefer the data type of the same-named column in the context data, when present.
            RGenericVector data = context.getData();
            if (data != null && data.hasValue(name.getValue())) {
                RVector<?> column = (RVector<?>) data.getValue(name.getValue());
                dataType = column.getDataType();
            }
            // The value assigned to dataField is not read afterwards.
            dataField = encoder.createDataField(name, opType, dataType, categories);
        }
    }
    return formula;
}
Also used : DataField(org.dmg.pmml.DataField) Expression(org.dmg.pmml.Expression) DataType(org.dmg.pmml.DataType) OpType(org.dmg.pmml.OpType) ArrayList(java.util.ArrayList) List(java.util.List) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map)

Example 55 with DataField

use of org.dmg.pmml.DataField in project jpmml-r by jpmml.

the class GLMConverter method encodeSchema.

/**
 * Encodes the schema through the superclass and then, for classification models only, replaces
 * the encoder's label with a categorical data field whose categories are the factor levels of
 * the label column found in the GLM object's {@code "model"} element.
 *
 * @param encoder the encoder whose label is inspected and possibly replaced
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    final RGenericVector glmObject = getObject();
    final RGenericVector familyVector = (RGenericVector) glmObject.getValue("family");
    final RGenericVector modelVector = (RGenericVector) glmObject.getValue("model");
    final RStringVector familyName = (RStringVector) familyVector.getValue("family");

    super.encodeSchema(encoder);

    final MiningFunction function = getMiningFunction(familyName.asScalar());
    switch (function) {
        case CLASSIFICATION: {
            final Label currentLabel = encoder.getLabel();
            // The label's name is also the key of its column inside the "model" element.
            final String labelColumn = currentLabel.getName().getValue();
            final RIntegerVector labelValues = (RIntegerVector) modelVector.getValue(labelColumn);
            final DataField categoricalField =
                    (DataField) encoder.toCategorical(currentLabel.getName(), RExpUtil.getFactorLevels(labelValues));
            encoder.setLabel(categoricalField);
            break;
        }
        default:
            // Every other mining function keeps the label set by the superclass.
            break;
    }
}
Also used : DataField(org.dmg.pmml.DataField) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) MiningFunction(org.dmg.pmml.MiningFunction)

Aggregations

DataField (org.dmg.pmml.DataField)101 Test (org.junit.Test)51 DataDictionary (org.dmg.pmml.DataDictionary)42 MiningField (org.dmg.pmml.MiningField)42 MiningSchema (org.dmg.pmml.MiningSchema)30 PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField)28 RegressionModel (org.dmg.pmml.regression.RegressionModel)27 CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary)27 FieldName (org.dmg.pmml.FieldName)24 Model (org.dmg.pmml.Model)24 PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField)22 DataType (org.dmg.pmml.DataType)19 OutputField (org.dmg.pmml.OutputField)19 PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField)19 PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField)18 ArrayList (java.util.ArrayList)17 List (java.util.List)17 PMML (org.dmg.pmml.PMML)17 Collectors (java.util.stream.Collectors)16 OpType (org.dmg.pmml.OpType)15