Search in sources :

Example 11 with DerivedField

use of org.dmg.pmml.DerivedField in project shifu by ShifuML.

the class NeuralNetworkModelIntegrator method getLocalTranformations.

/**
 * Builds the local transformations for a neural network model with an extra
 * constant "bias" derived field appended.
 *
 * <p>The bias field is added to the list obtained from the model's existing
 * local transformations, and a new {@link LocalTransformations} holding every
 * entry of that list is returned.
 *
 * @param model the neural network whose derived fields are extended
 * @return a new {@link LocalTransformations} containing the model's derived
 *         fields followed by the bias field
 */
private LocalTransformations getLocalTranformations(NeuralNetwork model) {
    // NOTE(review): this is the list handed out by the model itself, so the
    // bias field is appended to it in place - confirm callers expect that.
    List<DerivedField> transformations = model.getLocalTransformations().getDerivedFields();

    // Constant-valued, continuous double field that carries the bias input.
    DerivedField biasField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
    biasField.setName(new FieldName(PluginConstants.biasValue));
    biasField.setExpression(new Constant(String.valueOf(PluginConstants.bias)));
    transformations.add(biasField);

    DerivedField[] asArray = transformations.toArray(new DerivedField[transformations.size()]);
    LocalTransformations result = new LocalTransformations();
    return result.addDerivedFields(asArray);
}
Also used : LocalTransformations(org.dmg.pmml.LocalTransformations) Constant(org.dmg.pmml.Constant) DerivedField(org.dmg.pmml.DerivedField) FieldName(org.dmg.pmml.FieldName)

Example 12 with DerivedField

use of org.dmg.pmml.DerivedField in project jpmml-r by jpmml.

the class FormulaUtil method createFormula.

/**
 * Builds a {@link Formula} from an R "terms" object.
 *
 * <p>Each row name of the "factors" attribute is treated as one variable. A
 * variable whose name parses as one of the recognized function calls becomes a
 * derived field; every other variable becomes a data field. Afterwards, a data
 * field is created for each entry of the expression-field map that the encoder
 * does not already know.
 *
 * @param terms the R terms object; its "factors" and "dataClasses" attributes are read
 * @param context supplies per-variable category levels and, optionally, the data
 * @param encoder the encoder on which data fields and derived fields are created
 * @return the populated formula
 */
public static Formula createFormula(RExp terms, FormulaContext context, RExpEncoder encoder) {
    Formula formula = new Formula(encoder);
    RIntegerVector factors = (RIntegerVector) terms.getAttributeValue("factors");
    RStringVector dataClasses = (RStringVector) terms.getAttributeValue("dataClasses");
    RStringVector variableRows = factors.dimnames(0);
    // NOTE(review): termColumns is not used anywhere else in this method.
    RStringVector termColumns = factors.dimnames(1);
    // Passed to the encode*Expression helpers and iterated after the main loop;
    // presumably the helpers record the fields their expressions refer to - confirm.
    VariableMap expressionFields = new VariableMap();
    for (int i = 0; i < variableRows.size(); i++) {
        String variable = variableRows.getDequotedValue(i);
        FieldName name = FieldName.create(variable);
        // Continuous unless the context reports at least one category.
        OpType opType = OpType.CONTINUOUS;
        DataType dataType = RExpUtil.getDataType(dataClasses.getValue(variable));
        List<String> categories = context.getCategories(variable);
        if (categories != null && categories.size() > 0) {
            opType = OpType.CATEGORICAL;
        }
        // Stays null when the variable is not a recognized function call.
        Expression expression = null;
        FieldName shortName = name;
        // Labeled block: "break expression" leaves it early with expression == null,
        // so the variable is then handled as a plain data field below.
        expression: if (variable.indexOf('(') > -1 && variable.indexOf(')') > -1) {
            FunctionExpression functionExpression;
            try {
                functionExpression = (FunctionExpression) ExpressionTranslator.translateExpression(variable);
            } catch (Exception e) {
                // Any failure here (including a failed cast) means "not a function call".
                break expression;
            }
            // Dispatch on the function identity; unrecognized functions are skipped.
            if (functionExpression.hasId("base", "cut")) {
                expression = encodeCutExpression(functionExpression, categories, expressionFields, encoder);
            } else if (functionExpression.hasId("base", "I")) {
                expression = encodeIdentityExpression(functionExpression, expressionFields, encoder);
            } else if (functionExpression.hasId("base", "ifelse")) {
                expression = encodeIfElseExpression(functionExpression, expressionFields, encoder);
            } else if (functionExpression.hasId("plyr", "mapvalues")) {
                expression = encodeMapValuesExpression(functionExpression, categories, expressionFields, encoder);
            } else if (functionExpression.hasId("plyr", "revalue")) {
                expression = encodeReValueExpression(functionExpression, categories, expressionFields, encoder);
            } else {
                break expression;
            }
            // Short name: the bare "x" argument for I(), otherwise function(x).
            FunctionExpression.Argument xArgument = functionExpression.getArgument("x", 0);
            String value = (xArgument.formatExpression()).trim();
            shortName = FieldName.create(functionExpression.hasId("base", "I") ? value : (functionExpression.getFunction() + "(" + value + ")"));
        }
        if (expression != null) {
            // Recognized function call: create the derived field under the full name.
            DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, expression).addExtensions(createExtension(variable));
            if (categories != null && categories.size() > 0) {
                formula.addField(derivedField, categories);
            } else {
                formula.addField(derivedField);
            }
            // Rename only after the field has been added to the formula under its full name.
            if (!(name).equals(shortName)) {
                encoder.renameField(name, shortName);
            }
        } else {
            // Plain variable. Boolean variables always get the two categories below,
            // replacing whatever the context returned.
            if ((DataType.BOOLEAN).equals(dataType)) {
                categories = Arrays.asList("false", "true");
            }
            if (categories != null && categories.size() > 0) {
                DataField dataField = encoder.createDataField(name, OpType.CATEGORICAL, dataType, categories);
                List<String> categoryNames;
                List<String> categoryValues;
                switch(dataType) {
                    case BOOLEAN:
                        // Upper-case names paired with lower-case values.
                        categoryNames = Arrays.asList("FALSE", "TRUE");
                        categoryValues = Arrays.asList("false", "true");
                        break;
                    default:
                        // Names and values are the same list.
                        categoryNames = categories;
                        categoryValues = categories;
                        break;
                }
                formula.addField(dataField, categoryNames, categoryValues);
            } else {
                DataField dataField = encoder.createDataField(name, OpType.CONTINUOUS, dataType);
                formula.addField(dataField);
            }
        }
    }
    // Create a data field for every expression-field entry the encoder does not have yet.
    Collection<Map.Entry<FieldName, List<String>>> entries = expressionFields.entrySet();
    for (Map.Entry<FieldName, List<String>> entry : entries) {
        FieldName name = entry.getKey();
        List<String> categories = entry.getValue();
        DataField dataField = encoder.getDataField(name);
        if (dataField == null) {
            // Defaults: continuous double; categorical when categories are present.
            OpType opType = OpType.CONTINUOUS;
            DataType dataType = DataType.DOUBLE;
            if (categories != null && categories.size() > 0) {
                opType = OpType.CATEGORICAL;
            }
            // When the context exposes data with a matching column, use that column's data type.
            RGenericVector data = context.getData();
            if (data != null && data.hasValue(name.getValue())) {
                RVector<?> column = (RVector<?>) data.getValue(name.getValue());
                dataType = column.getDataType();
            }
            // The call registers the field on the encoder; the value assigned to
            // dataField is not read again within this method.
            dataField = encoder.createDataField(name, opType, dataType, categories);
        }
    }
    return formula;
}
Also used : DataField(org.dmg.pmml.DataField) Expression(org.dmg.pmml.Expression) DataType(org.dmg.pmml.DataType) OpType(org.dmg.pmml.OpType) ArrayList(java.util.ArrayList) List(java.util.List) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map)

Example 13 with DerivedField

use of org.dmg.pmml.DerivedField in project jpmml-r by jpmml.

the class EarthConverter method encodeSchema.

/**
 * Encodes the label and the features of an "earth" model object.
 *
 * <p>Reads the "dirs", "cuts", "selected.terms", "coefficients", "terms" and
 * "xlevels" elements of the model object, builds a {@link Formula} from the
 * terms, sets the label from the single column name of the coefficients, and
 * then adds one feature per selected term (the first selected term is skipped).
 * A term that references several predictors becomes an
 * {@link InteractionFeature}.
 *
 * @param encoder the encoder that receives the label, derived fields and features
 * @throws IllegalArgumentException if the "xlevels" element cannot be obtained,
 *         if the dimension names of "dirs" and "cuts" differ, if a direction
 *         value is not one of -1, 0, 1 or 2, or if a selected term references
 *         no predictor
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector earth = getObject();
    RDoubleVector dirs = (RDoubleVector) earth.getValue("dirs");
    RDoubleVector cuts = (RDoubleVector) earth.getValue("cuts");
    RDoubleVector selectedTerms = (RDoubleVector) earth.getValue("selected.terms");
    RDoubleVector coefficients = (RDoubleVector) earth.getValue("coefficients");
    RExp terms = earth.getValue("terms");
    final RGenericVector xlevels;
    try {
        xlevels = (RGenericVector) earth.getValue("xlevels");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No variable levels information. Please initialize the \'xlevels\' element", iae);
    }
    RStringVector dirsRows = dirs.dimnames(0);
    RStringVector dirsColumns = dirs.dimnames(1);
    RStringVector cutsRows = cuts.dimnames(0);
    RStringVector cutsColumns = cuts.dimnames(1);
    // Both matrices are read with the same row/column layout below, so their
    // dimension names must agree.
    if (!(dirsRows.getValues()).equals(cutsRows.getValues()) || !(dirsColumns.getValues()).equals(cutsColumns.getValues())) {
        throw new IllegalArgumentException("The dimension names of the 'dirs' and 'cuts' elements do not match");
    }
    int rows = dirsRows.size();
    int columns = dirsColumns.size();
    List<String> predictorNames = dirsColumns.getValues();
    FormulaContext context = new FormulaContext() {

        @Override
        public List<String> getCategories(String variable) {
            // Category levels come from "xlevels"; null means "no categories".
            if (xlevels.hasValue(variable)) {
                RStringVector levels = (RStringVector) xlevels.getValue(variable);
                return levels.getValues();
            }
            return null;
        }

        @Override
        public RGenericVector getData() {
            return null;
        }
    };
    Formula formula = FormulaUtil.createFormula(terms, context, encoder);
    // Dependent variable
    {
        RStringVector yNames = coefficients.dimnames(1);
        FieldName name = FieldName.create(yNames.asScalar());
        DataField dataField = (DataField) encoder.getField(name);
        encoder.setLabel(dataField);
    }
    // Independent variables. Index 0 is skipped (presumably the intercept term - confirm).
    for (int i = 1; i < selectedTerms.size(); i++) {
        // Selected terms are stored one-based; convert to a zero-based row index.
        int termIndex = ValueUtil.asInt(selectedTerms.getValue(i)) - 1;
        List<Double> dirsRow = FortranMatrixUtil.getRow(dirs.getValues(), rows, columns, termIndex);
        List<Double> cutsRow = FortranMatrixUtil.getRow(cuts.getValues(), rows, columns, termIndex);
        List<Feature> features = new ArrayList<>();
        for (int j = 0; j < predictorNames.size(); j++) {
            String predictorName = predictorNames.get(j);
            int dir = ValueUtil.asInt(dirsRow.get(j));
            double cut = cutsRow.get(j);
            // A zero direction means the predictor does not take part in this term.
            if (dir == 0) {
                continue;
            }
            Feature feature = formula.resolveFeature(predictorName);
            switch(dir) {
                case -1:
                case 1:
                    {
                        // Hinge term: reuse the derived field if one with the same
                        // formatted name has already been created.
                        feature = feature.toContinuousFeature();
                        FieldName name = FieldName.create(formatHingeFunction(dir, feature, cut));
                        DerivedField derivedField = encoder.getDerivedField(name);
                        if (derivedField == null) {
                            Apply apply = createHingeFunction(dir, feature, cut);
                            derivedField = encoder.createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, apply);
                        }
                        feature = new ContinuousFeature(encoder, derivedField);
                    }
                    break;
                case 2:
                    // The resolved feature is used as is.
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported direction " + dir + " for predictor '" + predictorName + "' in term row " + termIndex);
            }
            features.add(feature);
        }
        Feature feature;
        if (features.size() == 1) {
            feature = features.get(0);
        } else if (features.size() > 1) {
            // NOTE(review): the name is taken from dirsRows at position i, while the
            // matrix rows above are read at termIndex - confirm this is intended.
            feature = new InteractionFeature(encoder, FieldName.create(dirsRows.getValue(i)), DataType.DOUBLE, features);
        } else {
            throw new IllegalArgumentException("Selected term row " + termIndex + " does not reference any predictor");
        }
        encoder.addFeature(feature);
    }
}
Also used : InteractionFeature(org.jpmml.converter.InteractionFeature) Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) InteractionFeature(org.jpmml.converter.InteractionFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField)

Example 14 with DerivedField

use of org.dmg.pmml.DerivedField in project jpmml-r by jpmml.

the class Formula method addField.

/**
 * Registers a field with this formula under the field's own name.
 *
 * <p>The field is represented by a {@link ContinuousFeature}, except when it
 * is a {@link DerivedField} whose expression passes the
 * {@code checkApply(apply, "pow", FieldRef.class, Constant.class)} check and
 * whose constant parses as an {@code int}; in that case a
 * {@link PowerFeature} over the referenced field is registered instead.
 *
 * @param field the field to add
 */
public void addField(Field<?> field) {
    RExpEncoder encoder = getEncoder();
    // Default representation; replaced below only for an integer power.
    Feature feature = new ContinuousFeature(encoder, field);
    Expression expression = (field instanceof DerivedField) ? ((DerivedField) field).getExpression() : null;
    if ((expression instanceof Apply) && checkApply((Apply) expression, "pow", FieldRef.class, Constant.class)) {
        List<Expression> arguments = ((Apply) expression).getExpressions();
        FieldRef base = (FieldRef) arguments.get(0);
        Constant exponent = (Constant) arguments.get(1);
        try {
            int power = Integer.parseInt(exponent.getValue());
            feature = new PowerFeature(encoder, base.getField(), DataType.DOUBLE, power);
        } catch (NumberFormatException ignored) {
            // The exponent is not an integer: keep the continuous feature.
        }
    }
    putFeature(field.getName(), feature);
    this.fields.add(field);
}
Also used : PowerFeature(org.jpmml.converter.PowerFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) FieldRef(org.dmg.pmml.FieldRef) Expression(org.dmg.pmml.Expression) Apply(org.dmg.pmml.Apply) Constant(org.dmg.pmml.Constant) Feature(org.jpmml.converter.Feature) PowerFeature(org.jpmml.converter.PowerFeature) BinaryFeature(org.jpmml.converter.BinaryFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) InteractionFeature(org.jpmml.converter.InteractionFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) DerivedField(org.dmg.pmml.DerivedField)

Example 15 with DerivedField

use of org.dmg.pmml.DerivedField in project jpmml-sparkml by jpmml.

the class MaxAbsScalerModelConverter method encodeFeatures.

/**
 * Encodes the scaled features of a {@code MaxAbsScalerModel}.
 *
 * <p>Each input feature is divided by the matching entry of the model's
 * {@code maxAbs} vector. An entry of zero is treated as one, and a divisor
 * that {@code ValueUtil.isOne} accepts leaves the input feature unchanged, so
 * no derived field is created for it.
 *
 * @param encoder the encoder that supplies the input features and receives
 *        the derived fields
 * @return one feature per input feature, in input order
 * @throws IllegalArgumentException if the number of {@code maxAbs} values
 *         differs from the number of input features
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    MaxAbsScalerModel transformer = getTransformer();
    List<Feature> features = encoder.getFeatures(transformer.getInputCol());
    Vector maxAbs = transformer.maxAbs();
    if (maxAbs.size() != features.size()) {
        throw new IllegalArgumentException("Expected " + features.size() + " maxAbs value(s), got " + maxAbs.size());
    }
    List<Feature> result = new ArrayList<>(features.size());
    for (int i = 0; i < features.size(); i++) {
        Feature feature = features.get(i);
        double maxAbsUnzero = maxAbs.apply(i);
        // Avoid dividing by zero: a zero maximum is replaced with one.
        if (maxAbsUnzero == 0d) {
            maxAbsUnzero = 1d;
        }
        // Dividing by one is a no-op, so the input feature is passed through.
        if (!ValueUtil.isOne(maxAbsUnzero)) {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            Expression expression = PMMLUtil.createApply("/", continuousFeature.ref(), PMMLUtil.createConstant(maxAbsUnzero));
            DerivedField derivedField = encoder.createDerivedField(formatName(transformer, i), OpType.CONTINUOUS, DataType.DOUBLE, expression);
            feature = new ContinuousFeature(encoder, derivedField);
        }
        result.add(feature);
    }
    return result;
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) Expression(org.dmg.pmml.Expression) MaxAbsScalerModel(org.apache.spark.ml.feature.MaxAbsScalerModel) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Vector(org.apache.spark.ml.linalg.Vector) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

DerivedField (org.dmg.pmml.DerivedField): 27; ArrayList (java.util.ArrayList): 16; Feature (org.jpmml.converter.Feature): 14; ContinuousFeature (org.jpmml.converter.ContinuousFeature): 13; FieldName (org.dmg.pmml.FieldName): 10; Apply (org.dmg.pmml.Apply): 9; Expression (org.dmg.pmml.Expression): 8; NormContinuous (org.dmg.pmml.NormContinuous): 6; CategoricalFeature (org.jpmml.converter.CategoricalFeature): 5; DataField (org.dmg.pmml.DataField): 4; Discretize (org.dmg.pmml.Discretize): 4; MapValues (org.dmg.pmml.MapValues): 4; Vector (org.apache.spark.ml.linalg.Vector): 3; FieldRef (org.dmg.pmml.FieldRef): 3; LocalTransformations (org.dmg.pmml.LocalTransformations): 3; HashMap (java.util.HashMap): 2; List (java.util.List): 2; Map (java.util.Map): 2; Constant (org.dmg.pmml.Constant): 2; DiscretizeBin (org.dmg.pmml.DiscretizeBin): 2