Search in sources :

Example 26 with Feature

use of org.jpmml.converter.Feature in project jpmml-r by jpmml.

In the class GBMConverter, the method encodeNode.

/**
 * Recursively fills in {@code node} from row {@code i} of a tree table.
 *
 * <p>The tree is read by position: elements 0-4 and 7 of {@code tree} supply the split
 * variable, the split code / threshold, the left, right and missing child rows, and the
 * prediction (roles taken from how each vector is used below). A split variable of
 * {@code -1} marks a leaf; a child row of {@code -1} marks an absent child.
 *
 * @param node the node to populate; its id and predicate are set by the caller
 * @param i the row of the tree table that describes {@code node}
 * @param tree the positional tree table
 * @param c_splits the categorical split definitions, indexed by the split code
 * @param schema the schema used to resolve the split variable to a feature
 */
private void encodeNode(Node node, int i, RGenericVector tree, RGenericVector c_splits, Schema schema) {
    RIntegerVector splitVar = (RIntegerVector) tree.getValue(0);
    RDoubleVector splitCodePred = (RDoubleVector) tree.getValue(1);
    RIntegerVector leftNode = (RIntegerVector) tree.getValue(2);
    RIntegerVector rightNode = (RIntegerVector) tree.getValue(3);
    RIntegerVector missingNode = (RIntegerVector) tree.getValue(4);
    RDoubleVector prediction = (RDoubleVector) tree.getValue(7);

    Integer splitVarIndex = splitVar.getValue(i);

    // Leaf: record the score and stop descending.
    if (splitVarIndex == -1) {
        Double score = prediction.getValue(i);
        node.setScore(ValueUtil.formatValue(score));
        return;
    }

    Feature feature = schema.getFeature(splitVarIndex);
    Predicate missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
    Double split = splitCodePred.getValue(i);

    Predicate leftPredicate;
    Predicate rightPredicate;
    if (feature instanceof CategoricalFeature) {
        // For a categorical split the split value is an index into c_splits.
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> categories = categoricalFeature.getValues();
        RIntegerVector categorySplit = (RIntegerVector) c_splits.getValue(ValueUtil.asInt(split));
        List<Integer> directions = categorySplit.getValues();
        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(categories, directions, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(categories, directions, false));
    } else {
        // For a continuous split the split value is the threshold: left is "<", right is ">=".
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String threshold = ValueUtil.formatValue(split);
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, threshold);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, threshold);
    }

    // Children are appended in the order missing, left, right.
    RIntegerVector[] childRows = {missingNode, leftNode, rightNode};
    Predicate[] childPredicates = {missingPredicate, leftPredicate, rightPredicate};
    for (int k = 0; k < childRows.length; k++) {
        Integer childRow = childRows[k].getValue(i);
        if (childRow != -1) {
            // Node ids are the row index plus one.
            Node child = new Node().setId(String.valueOf(childRow + 1)).setPredicate(childPredicates[k]);
            encodeNode(child, childRow, tree, c_splits, schema);
            node.addNodes(child);
        }
    }
}
Also used : Node(org.dmg.pmml.tree.Node) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousFeature(org.jpmml.converter.ContinuousFeature)

Example 27 with Feature

use of org.jpmml.converter.Feature in project jpmml-r by jpmml.

In the class BinaryTreeConverter, the method encodeNode.

/**
 * Recursively fills in {@code node} from a nested tree structure.
 *
 * <p>The node id is taken from the {@code nodeID} element. A node whose {@code terminal}
 * element is {@code TRUE} is scored via {@code encodeScore} and not descended into.
 * Otherwise the primary split ({@code psplit}) is turned into a pair of complementary
 * predicates, and the {@code left} and {@code right} subtrees are encoded as children.
 *
 * @param node the node to populate
 * @param tree the tree (or subtree) rooted at {@code node}
 * @param schema the schema used to resolve the split variable to a feature
 * @throws IllegalArgumentException if the {@code ssplits} element is not empty, or if the
 *         split variable has no entry in {@code featureIndexes}
 */
private void encodeNode(Node node, RGenericVector tree, Schema schema) {
    RIntegerVector nodeId = (RIntegerVector) tree.getValue("nodeID");
    RBooleanVector terminal = (RBooleanVector) tree.getValue("terminal");
    RGenericVector psplit = (RGenericVector) tree.getValue("psplit");
    RGenericVector ssplits = (RGenericVector) tree.getValue("ssplits");
    RDoubleVector prediction = (RDoubleVector) tree.getValue("prediction");
    RGenericVector left = (RGenericVector) tree.getValue("left");
    RGenericVector right = (RGenericVector) tree.getValue("right");
    node.setId(String.valueOf(nodeId.asScalar()));
    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        // The return value is not needed: reassigning the parameter had no effect on the caller.
        encodeScore(node, prediction, schema);
        return;
    }
    RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
    RStringVector variableName = (RStringVector) psplit.getValue("variableName");
    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Expected an empty 'ssplits' element, got " + ssplits.size() + " element(s)");
    }
    Predicate leftPredicate;
    Predicate rightPredicate;
    FieldName name = FieldName.create(variableName.asScalar());
    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("No feature index registered for split variable '" + name.getValue() + "'");
    }
    Feature feature = schema.getFeature(index);
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> values = categoricalFeature.getValues();
        // The split point of a categorical split is treated as a list of integers;
        // the element type cannot be checked at runtime.
        @SuppressWarnings("unchecked")
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();
        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        // Continuous split: left is "<=", right is ">".
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String value = ValueUtil.formatValue((Double) splitpoint.asScalar());
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }
    Node leftChild = new Node().setPredicate(leftPredicate);
    encodeNode(leftChild, left, schema);
    Node rightChild = new Node().setPredicate(rightPredicate);
    encodeNode(rightChild, right, schema);
    node.addNodes(leftChild, rightChild);
}
Also used : Node(org.dmg.pmml.tree.Node) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousFeature(org.jpmml.converter.ContinuousFeature) ArrayList(java.util.ArrayList) List(java.util.List) FieldName(org.dmg.pmml.FieldName)

Example 28 with Feature

use of org.jpmml.converter.Feature in project jpmml-r by jpmml.

In the class EarthConverter, the method encodeSchema.

/**
 * Registers the label and the features of the model object with {@code encoder}.
 *
 * <p>The label is the field named by the column name of {@code coefficients}. One feature
 * is added per entry of {@code selected.terms}, skipping the first entry (NOTE(review):
 * presumably the intercept term; confirm against the model format). For each term, the
 * matching rows of the {@code dirs} and {@code cuts} matrices decide how every predictor
 * takes part:
 * <ul>
 *   <li>{@code 0}: the predictor is not part of the term;</li>
 *   <li>{@code -1} or {@code 1}: the predictor enters through a hinge function, encoded
 *       as a derived field that is reused if one with the same name already exists;</li>
 *   <li>{@code 2}: the predictor's feature is used as is.</li>
 * </ul>
 * A term with several participating predictors becomes an {@code InteractionFeature}.
 *
 * @param encoder the encoder that receives the label and the features
 * @throws IllegalArgumentException if {@code xlevels} cannot be read, if {@code dirs} and
 *         {@code cuts} have different row or column names, if a direction value is not
 *         one of {@code -1, 0, 1, 2}, or if a selected term has no participating predictor
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector earth = getObject();
    RDoubleVector dirs = (RDoubleVector) earth.getValue("dirs");
    RDoubleVector cuts = (RDoubleVector) earth.getValue("cuts");
    RDoubleVector selectedTerms = (RDoubleVector) earth.getValue("selected.terms");
    RDoubleVector coefficients = (RDoubleVector) earth.getValue("coefficients");
    RExp terms = earth.getValue("terms");
    final RGenericVector xlevels;
    try {
        xlevels = (RGenericVector) earth.getValue("xlevels");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No variable levels information. Please initialize the \'xlevels\' element", iae);
    }
    RStringVector dirsRows = dirs.dimnames(0);
    RStringVector dirsColumns = dirs.dimnames(1);
    RStringVector cutsRows = cuts.dimnames(0);
    RStringVector cutsColumns = cuts.dimnames(1);
    // Both matrices are read with the same row and column indexes, so their names must agree.
    if (!(dirsRows.getValues()).equals(cutsRows.getValues()) || !(dirsColumns.getValues()).equals(cutsColumns.getValues())) {
        throw new IllegalArgumentException("The row and column names of the 'dirs' and 'cuts' elements do not match");
    }
    int rows = dirsRows.size();
    int columns = dirsColumns.size();
    List<String> predictorNames = dirsColumns.getValues();
    FormulaContext context = new FormulaContext() {

        @Override
        public List<String> getCategories(String variable) {
            if (xlevels.hasValue(variable)) {
                RStringVector levels = (RStringVector) xlevels.getValue(variable);
                return levels.getValues();
            }
            return null;
        }

        @Override
        public RGenericVector getData() {
            return null;
        }
    };
    Formula formula = FormulaUtil.createFormula(terms, context, encoder);
    // Dependent variable
    {
        RStringVector yNames = coefficients.dimnames(1);
        FieldName name = FieldName.create(yNames.asScalar());
        DataField dataField = (DataField) encoder.getField(name);
        encoder.setLabel(dataField);
    }
    // Independent variables
    for (int i = 1; i < selectedTerms.size(); i++) {
        // selected.terms holds one-based row numbers into dirs and cuts.
        int termIndex = ValueUtil.asInt(selectedTerms.getValue(i)) - 1;
        List<Double> dirsRow = FortranMatrixUtil.getRow(dirs.getValues(), rows, columns, termIndex);
        List<Double> cutsRow = FortranMatrixUtil.getRow(cuts.getValues(), rows, columns, termIndex);
        List<Feature> features = new ArrayList<>();
        for (int j = 0; j < predictorNames.size(); j++) {
            String predictorName = predictorNames.get(j);
            int dir = ValueUtil.asInt(dirsRow.get(j));
            double cut = cutsRow.get(j);
            if (dir == 0) {
                continue;
            }
            Feature feature = formula.resolveFeature(predictorName);
            switch(dir) {
                case -1:
                case 1:
                    {
                        feature = feature.toContinuousFeature();
                        FieldName name = FieldName.create(formatHingeFunction(dir, feature, cut));
                        // Reuse the hinge if an earlier term already declared the same one.
                        DerivedField derivedField = encoder.getDerivedField(name);
                        if (derivedField == null) {
                            Apply apply = createHingeFunction(dir, feature, cut);
                            derivedField = encoder.createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, apply);
                        }
                        feature = new ContinuousFeature(encoder, derivedField);
                    }
                    break;
                case 2:
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported direction " + dir + " for predictor '" + predictorName + "'");
            }
            features.add(feature);
        }
        Feature feature;
        if (features.size() == 1) {
            feature = features.get(0);
        } else if (features.size() > 1) {
            // Name the interaction after the row the data was read from (termIndex), not
            // after the position in selected.terms: the two differ whenever selected.terms
            // is not the identity sequence.
            feature = new InteractionFeature(encoder, FieldName.create(dirsRows.getValue(termIndex)), DataType.DOUBLE, features);
        } else {
            throw new IllegalArgumentException("Term '" + dirsRows.getValue(termIndex) + "' does not reference any predictor");
        }
        encoder.addFeature(feature);
    }
}
Also used : InteractionFeature(org.jpmml.converter.InteractionFeature) Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) InteractionFeature(org.jpmml.converter.InteractionFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField)

Example 29 with Feature

use of org.jpmml.converter.Feature in project jpmml-r by jpmml.

In the class Formula, the method addField.

/**
 * Registers {@code field} with this formula, together with a feature stored under the
 * field's name.
 *
 * <p>The feature is a {@code ContinuousFeature} by default. A derived field whose
 * expression is a {@code pow} application of a field reference to a constant with an
 * integer value is registered as a {@code PowerFeature} over the referenced field instead.
 *
 * @param field the field to register
 */
public void addField(Field<?> field) {
    RExpEncoder encoder = getEncoder();
    Feature feature = new ContinuousFeature(encoder, field);
    Feature powerFeature = toPowerFeature(encoder, field);
    if (powerFeature != null) {
        feature = powerFeature;
    }
    putFeature(field.getName(), feature);
    this.fields.add(field);
}

/**
 * Returns a {@code PowerFeature} for a derived field of the form {@code pow(fieldRef, constant)}
 * with an integer constant, or {@code null} if {@code field} does not have that shape.
 */
private Feature toPowerFeature(RExpEncoder encoder, Field<?> field) {
    if (!(field instanceof DerivedField)) {
        return null;
    }
    Expression expression = ((DerivedField) field).getExpression();
    if (!(expression instanceof Apply)) {
        return null;
    }
    Apply apply = (Apply) expression;
    if (!checkApply(apply, "pow", FieldRef.class, Constant.class)) {
        return null;
    }
    List<Expression> arguments = apply.getExpressions();
    FieldRef base = (FieldRef) arguments.get(0);
    Constant exponent = (Constant) arguments.get(1);
    try {
        int power = Integer.parseInt(exponent.getValue());
        return new PowerFeature(encoder, base.getField(), DataType.DOUBLE, power);
    } catch (NumberFormatException ignored) {
        // A non-integer exponent is not an error: the caller keeps the default feature.
        return null;
    }
}
Also used : PowerFeature(org.jpmml.converter.PowerFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) FieldRef(org.dmg.pmml.FieldRef) Expression(org.dmg.pmml.Expression) Apply(org.dmg.pmml.Apply) Constant(org.dmg.pmml.Constant) Feature(org.jpmml.converter.Feature) PowerFeature(org.jpmml.converter.PowerFeature) BinaryFeature(org.jpmml.converter.BinaryFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) InteractionFeature(org.jpmml.converter.InteractionFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) DerivedField(org.dmg.pmml.DerivedField)

Example 30 with Feature

use of org.jpmml.converter.Feature in project jpmml-r by jpmml.

In the class Formula, the method getCoefficient.

/**
 * Looks up the coefficient of {@code feature} in {@code coefficients} by name.
 *
 * <p>The lookup key is normally the feature's own name. For a feature that implements
 * {@code HasDerivedName}, the key is instead the name under which the feature is stored
 * in this formula's feature map.
 *
 * @param feature the feature whose coefficient is requested
 * @param coefficients the named coefficient vector
 * @return the value stored in {@code coefficients} under the lookup key
 * @throws IllegalArgumentException if {@code feature} implements {@code HasDerivedName}
 *         but is not present in this formula's feature map
 */
public Double getCoefficient(Feature feature, RDoubleVector coefficients) {
    FieldName featureName = feature.getName();
    FieldName name = featureName;
    if (feature instanceof HasDerivedName) {
        BiMap<Feature, FieldName> inverseFeatures = this.features.inverse();
        name = inverseFeatures.get(feature);
        // The inverse lookup returns null for an unknown feature; fail with context here
        // rather than with a bare NullPointerException on the dereference below.
        if (name == null) {
            throw new IllegalArgumentException("Feature " + featureName + " is not registered with this formula");
        }
    }
    return coefficients.getValue(name.getValue());
}
Also used : FieldName(org.dmg.pmml.FieldName) Feature(org.jpmml.converter.Feature) PowerFeature(org.jpmml.converter.PowerFeature) BinaryFeature(org.jpmml.converter.BinaryFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) InteractionFeature(org.jpmml.converter.InteractionFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) HasDerivedName(org.jpmml.converter.HasDerivedName)

Aggregations

Feature (org.jpmml.converter.Feature)53 ContinuousFeature (org.jpmml.converter.ContinuousFeature)30 ArrayList (java.util.ArrayList)27 CategoricalFeature (org.jpmml.converter.CategoricalFeature)19 DerivedField (org.dmg.pmml.DerivedField)14 DataField (org.dmg.pmml.DataField)13 FieldName (org.dmg.pmml.FieldName)10 Apply (org.dmg.pmml.Apply)9 BooleanFeature (org.jpmml.converter.BooleanFeature)9 BinaryFeature (org.jpmml.converter.BinaryFeature)7 List (java.util.List)6 Expression (org.dmg.pmml.Expression)6 SimplePredicate (org.dmg.pmml.SimplePredicate)6 Vector (org.apache.spark.ml.linalg.Vector)5 Predicate (org.dmg.pmml.Predicate)5 Node (org.dmg.pmml.tree.Node)5 DocumentFeature (org.jpmml.sparkml.DocumentFeature)5 InteractionFeature (org.jpmml.converter.InteractionFeature)4 DocumentBuilder (javax.xml.parsers.DocumentBuilder)3 Transformer (org.apache.spark.ml.Transformer)3