Search in sources :

Example 1 with BooleanFeature

use of org.jpmml.converter.BooleanFeature in project jpmml-r by jpmml.

the class RandomForestConverter method encodeNode.

/**
 * Recursively fills in {@code node} from entry {@code i} of the parallel per-node lists.
 *
 * <p>An entry whose {@code bestvar} value is zero is treated as a leaf: the node receives the
 * encoded {@code nodepred} value as its score and gets no children. Otherwise the one-based
 * {@code bestvar} value selects a schema feature, a left/right predicate pair is derived from
 * the feature type and the {@code xbestsplit} value, and every non-zero daughter reference is
 * encoded as a child node carrying the matching predicate.
 *
 * @param node the node to populate
 * @param i zero-based index into the per-node lists
 * @param scoreEncoder converts a leaf prediction into the node score
 * @param leftDaughter one-based index of each node's left child; zero means "no child"
 * @param rightDaughter one-based index of each node's right child; zero means "no child"
 * @param bestvar one-based index of each node's split feature; zero marks a leaf
 * @param xbestsplit split value of each node
 * @param nodepred prediction of each node
 * @param schema supplies the features referenced by {@code bestvar}
 * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
 */
private <P extends Number> void encodeNode(Node node, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, Schema schema) {
    int featureNumber = ValueUtil.asInt(bestvar.get(i));

    // Leaf: there is no split variable, so only the score is recorded.
    if (featureNumber == 0) {
        node.setScore(scoreEncoder.encode(nodepred.get(i)));
        return;
    }

    // bestvar is one-based, the schema is zero-based.
    Feature feature = schema.getFeature(featureNumber - 1);
    Double splitValue = xbestsplit.get(i);

    Predicate leftPredicate;
    Predicate rightPredicate;
    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;
        // The only split value accepted for a boolean feature is 0.5.
        if (splitValue != 0.5d) {
            throw new IllegalArgumentException();
        }
        // Value 0 of the feature goes left, value 1 goes right.
        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> categories = categoricalFeature.getValues();
        // selectValues (defined elsewhere) picks the categories for the requested side.
        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(categories, splitValue, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(categories, splitValue, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String threshold = ValueUtil.formatValue(splitValue);
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, threshold);
    }

    // Daughter references are one-based; zero means the child is absent.
    int leftId = ValueUtil.asInt(leftDaughter.get(i));
    if (leftId != 0) {
        Node leftChild = new Node()
            .setId(String.valueOf(leftId))
            .setPredicate(leftPredicate);
        encodeNode(leftChild, leftId - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);
        node.addNodes(leftChild);
    }

    int rightId = ValueUtil.asInt(rightDaughter.get(i));
    if (rightId != 0) {
        Node rightChild = new Node()
            .setId(String.valueOf(rightId))
            .setPredicate(rightPredicate);
        encodeNode(rightChild, rightId - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);
        node.addNodes(rightChild);
    }
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) Node(org.dmg.pmml.tree.Node) ArrayList(java.util.ArrayList) List(java.util.List) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate)

Example 2 with BooleanFeature

use of org.jpmml.converter.BooleanFeature in project jpmml-sparkml by jpmml.

the class TreeModelUtil method encodeNode.

/**
 * Converts a Spark ML decision tree node, and recursively its descendants, into a PMML {@link Node}.
 *
 * <p>For an {@code InternalNode} a left/right predicate pair is built from the split and the
 * schema feature at the split's feature index; both children are encoded recursively and
 * attached to a fresh result node. For a {@code LeafNode} the result node carries the score
 * (and, for classification, the record count and one score distribution per impurity statistic).
 *
 * @param node the Spark ML node to convert
 * @param predicateManager factory used for all predicates
 * @param parentFieldValues per field, the category values that are still selectable on the path
 *        to this node; a field without an entry is unrestricted
 * @param miningFunction selects regression or classification leaf encoding
 * @param schema supplies the features and, for classification, the categorical label
 * @return the PMML node; its predicate is set by the caller
 * @throws IllegalArgumentException if the node, split, feature type or split layout is not supported
 * @throws UnsupportedOperationException if the mining function is neither regression nor classification
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {
    if (node instanceof InternalNode) {
        InternalNode internalNode = (InternalNode) node;
        // Children inherit the parent's restrictions unless a categorical split narrows them below.
        Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
        Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;
        Predicate leftPredicate;
        Predicate rightPredicate;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());
        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;
            double threshold = continuousSplit.threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                // The only threshold accepted for a boolean feature is 0.5.
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Expected threshold 0.5 for a boolean feature, got " + threshold);
                }
                // Value 0 of the feature goes left, value 1 goes right.
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                String value = ValueUtil.formatValue(threshold);
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                // TRUE and FALSE are class-level category arrays; only the two exact layouts are accepted.
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Unsupported binary split: left categories " + Arrays.toString(leftCategories) + ", right categories " + Arrays.toString(rightCategories));
                }
                String value = ValueUtil.formatValue(binaryFeature.getValue());
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName name = categoricalFeature.getName();
                List<String> values = categoricalFeature.getValues();
                // The split must account for every category value of the feature.
                if (values.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException("Field " + name + " has " + values.size() + " category values, but the split covers " + (leftCategories.length + rightCategories.length));
                }
                final Set<String> parentValues = parentFieldValues.get(name);
                // Keeps only values that an ancestor split on the same field has not already excluded.
                com.google.common.base.Predicate<String> valueFilter = new com.google.common.base.Predicate<String>() {

                    @Override
                    public boolean apply(String value) {
                        if (parentValues != null) {
                            return parentValues.contains(value);
                        }
                        return true;
                    }
                };
                List<String> leftValues = selectValues(values, leftCategories, valueFilter);
                List<String> rightValues = selectValues(values, rightCategories, valueFilter);
                // Each subtree gets its own copy of the restrictions, narrowed to its side of the split.
                leftFieldValues = new HashMap<>(parentFieldValues);
                leftFieldValues.put(name, new HashSet<>(leftValues));
                rightFieldValues = new HashMap<>(parentFieldValues);
                rightFieldValues.put(name, new HashSet<>(rightValues));
                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Categorical split on unsupported feature type " + (feature != null ? feature.getClass().getName() : "null"));
            }
        } else {
            throw new IllegalArgumentException("Unsupported split type " + split.getClass().getName());
        }
        Node result = new Node();
        Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftFieldValues, miningFunction, schema).setPredicate(leftPredicate);
        Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightFieldValues, miningFunction, schema).setPredicate(rightPredicate);
        result.addNodes(leftChild, rightChild);
        return result;
    } else if (node instanceof LeafNode) {
        Node result = new Node();
        switch(miningFunction) {
            case REGRESSION:
                {
                    String score = ValueUtil.formatValue(node.prediction());
                    result.setScore(score);
                }
                break;
            case CLASSIFICATION:
                {
                    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
                    // The prediction is used as an index into the label's values.
                    int index = ValueUtil.asInt(node.prediction());
                    result.setScore(categoricalLabel.getValue(index));
                    ImpurityCalculator impurityCalculator = node.impurityStats();
                    result.setRecordCount((double) impurityCalculator.count());
                    // One score distribution per statistic, paired with the label value at the same index.
                    double[] stats = impurityCalculator.stats();
                    for (int i = 0; i < stats.length; i++) {
                        ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), stats[i]);
                        result.addScoreDistributions(scoreDistribution);
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException("Mining function " + miningFunction + " is not supported");
        }
        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node type " + (node != null ? node.getClass().getName() : "null"));
    }
}
Also used : HashSet(java.util.HashSet) Set(java.util.Set) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) LeafNode(org.apache.spark.ml.tree.LeafNode) FieldName(org.dmg.pmml.FieldName) BinaryFeature(org.jpmml.converter.BinaryFeature) ScoreDistribution(org.dmg.pmml.ScoreDistribution) ContinuousFeature(org.jpmml.converter.ContinuousFeature) ImpurityCalculator(org.apache.spark.mllib.tree.impurity.ImpurityCalculator) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit) CategoricalLabel(org.jpmml.converter.CategoricalLabel) InternalNode(org.apache.spark.ml.tree.InternalNode) Split(org.apache.spark.ml.tree.Split) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit)

Example 3 with BooleanFeature

use of org.jpmml.converter.BooleanFeature in project jpmml-sparkml by jpmml.

the class SparkMLEncoder method getFeatures.

/**
 * Returns the features registered for a column, or a single fallback feature when none are registered.
 *
 * <p>The fallback is built from the column's data field (created on demand if it does not exist
 * yet) and its type depends on the field's data type: wildcard for strings, continuous for
 * integers and doubles, boolean for booleans. This method does not store the fallback in
 * {@code columnFeatures}.
 *
 * @param column the column name
 * @return the registered features, or a singleton list holding the fallback feature
 * @throws IllegalArgumentException if the data field has any other data type
 */
public List<Feature> getFeatures(String column) {
    List<Feature> registered = this.columnFeatures.get(column);
    if (registered != null) {
        return registered;
    }

    FieldName fieldName = FieldName.create(column);

    // Reuse the existing data field when there is one, otherwise create it.
    DataField field = getDataField(fieldName);
    if (field == null) {
        field = createDataField(fieldName);
    }

    DataType dataType = field.getDataType();
    switch (dataType) {
        case STRING:
            return Collections.<Feature>singletonList(new WildcardFeature(this, field));
        case INTEGER:
        case DOUBLE:
            return Collections.<Feature>singletonList(new ContinuousFeature(this, field));
        case BOOLEAN:
            return Collections.<Feature>singletonList(new BooleanFeature(this, field));
        default:
            throw new IllegalArgumentException("Data type " + dataType + " is not supported");
    }
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) DataType(org.dmg.pmml.DataType) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) WildcardFeature(org.jpmml.converter.WildcardFeature) FieldName(org.dmg.pmml.FieldName) BooleanFeature(org.jpmml.converter.BooleanFeature) WildcardFeature(org.jpmml.converter.WildcardFeature)

Example 4 with BooleanFeature

use of org.jpmml.converter.BooleanFeature in project jpmml-r by jpmml.

the class Formula method addField.

/**
 * Registers a categorical field together with one binary feature per category.
 *
 * <p>The field itself is registered under its own name, as a {@link BooleanFeature} when it is
 * a boolean field whose category values are exactly {@code "false"} followed by {@code "true"},
 * and as a plain {@link CategoricalFeature} otherwise. Each category additionally gets a
 * {@link BinaryFeature} registered under the field name concatenated with the category name.
 * Finally the field is appended to {@code this.fields}.
 *
 * @param field the field to register
 * @param categoryNames name suffixes of the per-category binary features
 * @param categoryValues category values, parallel to {@code categoryNames}
 * @throws IllegalArgumentException if the two lists differ in size
 */
public void addField(Field<?> field, List<String> categoryNames, List<String> categoryValues) {
    RExpEncoder encoder = getEncoder();

    if (categoryValues.size() != categoryNames.size()) {
        throw new IllegalArgumentException();
    }

    // A two-level "false"/"true" boolean field gets the dedicated boolean feature type.
    boolean booleanLevels = (DataType.BOOLEAN).equals(field.getDataType())
        && (categoryValues.size() == 2)
        && ("false").equals(categoryValues.get(0))
        && ("true").equals(categoryValues.get(1));

    CategoricalFeature fieldFeature = booleanLevels
        ? new BooleanFeature(encoder, field)
        : new CategoricalFeature(encoder, field, categoryValues);
    putFeature(field.getName(), fieldFeature);

    // One binary feature per category, keyed by field name + category name.
    for (int index = 0; index < categoryNames.size(); index++) {
        String suffix = categoryNames.get(index);
        String level = categoryValues.get(index);
        BinaryFeature levelFeature = new BinaryFeature(encoder, field, level);
        FieldName levelName = FieldName.create((field.getName()).getValue() + suffix);
        putFeature(levelName, levelFeature);
    }

    this.fields.add(field);
}
Also used : BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature)

Aggregations

BooleanFeature (org.jpmml.converter.BooleanFeature) 4, CategoricalFeature (org.jpmml.converter.CategoricalFeature) 3, ContinuousFeature (org.jpmml.converter.ContinuousFeature) 3, Feature (org.jpmml.converter.Feature) 3, FieldName (org.dmg.pmml.FieldName) 2, Predicate (org.dmg.pmml.Predicate) 2, SimplePredicate (org.dmg.pmml.SimplePredicate) 2, Node (org.dmg.pmml.tree.Node) 2, BinaryFeature (org.jpmml.converter.BinaryFeature) 2, ArrayList (java.util.ArrayList) 1, HashSet (java.util.HashSet) 1, List (java.util.List) 1, Set (java.util.Set) 1, CategoricalSplit (org.apache.spark.ml.tree.CategoricalSplit) 1, ContinuousSplit (org.apache.spark.ml.tree.ContinuousSplit) 1, InternalNode (org.apache.spark.ml.tree.InternalNode) 1, LeafNode (org.apache.spark.ml.tree.LeafNode) 1, Split (org.apache.spark.ml.tree.Split) 1, ImpurityCalculator (org.apache.spark.mllib.tree.impurity.ImpurityCalculator) 1, DataField (org.dmg.pmml.DataField) 1