Search in sources :

Example 1 with BinaryFeature

use of org.jpmml.converter.BinaryFeature in project jpmml-sparkml by jpmml.

the class TreeModelUtil method encodeNode.

/**
 * Recursively converts a Spark ML decision tree node into a PMML {@link Node}.
 *
 * <p>An internal node becomes a PMML node with exactly two children; each child carries the
 * predicate derived from the split (left branch first). A leaf node becomes a PMML node with a
 * score and, for classification, a record count and one score distribution per class statistic.
 * The predicate of the returned node itself is not set here; during recursion it is set by the
 * parent call.
 *
 * @param node the Spark ML tree node to convert
 * @param predicateManager factory used to create the child predicates
 * @param parentFieldValues for every categorical field already split on along the path to this
 *        node, the category values that can still reach it; used to filter the values listed in
 *        set predicates. This map is never modified, copies are passed to the children
 * @param miningFunction selects how leaf nodes are scored ({@code REGRESSION} or
 *        {@code CLASSIFICATION})
 * @param schema supplies the feature for each split's feature index, and the label
 * @return the PMML node for {@code node}
 * @throws IllegalArgumentException if the node, split or feature type is not supported, or if
 *         the split is inconsistent with the feature it applies to
 * @throws UnsupportedOperationException if a leaf is encountered with any other mining function
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {
    if (node instanceof InternalNode) {
        InternalNode internalNode = (InternalNode) node;
        // Unless a categorical set split narrows them below, both children inherit the parent's view.
        Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
        Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;
        Predicate leftPredicate;
        Predicate rightPredicate;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());
        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;
            double threshold = continuousSplit.threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                // The only threshold accepted for a boolean feature is 0.5: value #0 goes left, value #1 goes right.
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Expected threshold 0.5 for a split on a boolean feature, got " + threshold);
                }
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                String value = ValueUtil.formatValue(threshold);
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                // Both predicates test the same value; which side gets EQUAL depends on where the TRUE category went.
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Unexpected categories for a split on a binary feature: left=" + Arrays.toString(leftCategories) + ", right=" + Arrays.toString(rightCategories));
                }
                String value = ValueUtil.formatValue(binaryFeature.getValue());
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName name = categoricalFeature.getName();
                List<String> values = categoricalFeature.getValues();
                // The two sides of the split must account for every category of the feature.
                if (values.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException("Expected " + values.size() + " split categories for field " + name + ", got " + (leftCategories.length + rightCategories.length));
                }
                final Set<String> parentValues = parentFieldValues.get(name);
                // Accepts only values still reachable on this path; accepts everything if the field was not split on yet.
                com.google.common.base.Predicate<String> valueFilter = new com.google.common.base.Predicate<String>() {

                    @Override
                    public boolean apply(String value) {
                        if (parentValues != null) {
                            return parentValues.contains(value);
                        }
                        return true;
                    }
                };
                List<String> leftValues = selectValues(values, leftCategories, valueFilter);
                List<String> rightValues = selectValues(values, rightCategories, valueFilter);
                // Copy before narrowing so that sibling subtrees and the caller's map are unaffected.
                leftFieldValues = new HashMap<>(parentFieldValues);
                leftFieldValues.put(name, new HashSet<>(leftValues));
                rightFieldValues = new HashMap<>(parentFieldValues);
                rightFieldValues.put(name, new HashSet<>(rightValues));
                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Unsupported feature type for a categorical split: " + (feature != null ? feature.getClass().getName() : "null"));
            }
        } else {
            throw new IllegalArgumentException("Unsupported split type: " + split.getClass().getName());
        }
        Node result = new Node();
        Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftFieldValues, miningFunction, schema).setPredicate(leftPredicate);
        Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightFieldValues, miningFunction, schema).setPredicate(rightPredicate);
        result.addNodes(leftChild, rightChild);
        return result;
    } else if (node instanceof LeafNode) {
        LeafNode leafNode = (LeafNode) node;
        Node result = new Node();
        switch(miningFunction) {
            case REGRESSION:
                {
                    String score = ValueUtil.formatValue(leafNode.prediction());
                    result.setScore(score);
                }
                break;
            case CLASSIFICATION:
                {
                    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
                    // The prediction is used as an index into the label's values.
                    int index = ValueUtil.asInt(leafNode.prediction());
                    result.setScore(categoricalLabel.getValue(index));
                    ImpurityCalculator impurityCalculator = leafNode.impurityStats();
                    result.setRecordCount((double) impurityCalculator.count());
                    double[] stats = impurityCalculator.stats();
                    for (int i = 0; i < stats.length; i++) {
                        ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), stats[i]);
                        result.addScoreDistributions(scoreDistribution);
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException("Unsupported mining function: " + miningFunction);
        }
        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node type: " + (node != null ? node.getClass().getName() : "null"));
    }
}
Also used : HashSet(java.util.HashSet) Set(java.util.Set) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) LeafNode(org.apache.spark.ml.tree.LeafNode) FieldName(org.dmg.pmml.FieldName) BinaryFeature(org.jpmml.converter.BinaryFeature) ScoreDistribution(org.dmg.pmml.ScoreDistribution) ContinuousFeature(org.jpmml.converter.ContinuousFeature) ImpurityCalculator(org.apache.spark.mllib.tree.impurity.ImpurityCalculator) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit) CategoricalLabel(org.jpmml.converter.CategoricalLabel) InternalNode(org.apache.spark.ml.tree.InternalNode) Split(org.apache.spark.ml.tree.Split) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit)

Example 2 with BinaryFeature

use of org.jpmml.converter.BinaryFeature in project jpmml-sparkml by jpmml.

the class OneHotEncoderModelConverter method encodeFeatures.

/**
 * Expands every input column of the one-hot encoder model into binary indicator features.
 *
 * <p>For each input column, the categorical feature registered under that column is looked up
 * and one {@link BinaryFeature} is created per category value (all but the last value when
 * the transformer's {@code dropLast} flag is set). The indicators of one column are wrapped
 * into a single {@link BinarizedCategoricalFeature}.
 *
 * @param encoder the encoder that holds the features of the input columns
 * @return one binarized categorical feature per input column, in input column order
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    OneHotEncoderModel model = getTransformer();
    String[] inputColumns = model.getInputCols();
    boolean skipLastCategory = model.getDropLast();
    List<Feature> encodedColumns = new ArrayList<>();
    for (String inputColumn : inputColumns) {
        CategoricalFeature source = (CategoricalFeature) encoder.getOnlyFeature(inputColumn);
        List<String> allCategories = source.getValues();
        // With dropLast, the final category gets no indicator of its own.
        List<String> categories = skipLastCategory ? allCategories.subList(0, allCategories.size() - 1) : allCategories;
        List<BinaryFeature> indicators = new ArrayList<>();
        for (int pos = 0; pos < categories.size(); pos++) {
            BinaryFeature indicator = new BinaryFeature(encoder, source.getName(), DataType.STRING, categories.get(pos));
            indicators.add(indicator);
        }
        BinarizedCategoricalFeature binarized = new BinarizedCategoricalFeature(encoder, source.getName(), source.getDataType(), indicators);
        encodedColumns.add(binarized);
    }
    return encodedColumns;
}
Also used : ArrayList(java.util.ArrayList) BinarizedCategoricalFeature(org.jpmml.sparkml.BinarizedCategoricalFeature) BinaryFeature(org.jpmml.converter.BinaryFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BinaryFeature(org.jpmml.converter.BinaryFeature) BinarizedCategoricalFeature(org.jpmml.sparkml.BinarizedCategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BinarizedCategoricalFeature(org.jpmml.sparkml.BinarizedCategoricalFeature) OneHotEncoderModel(org.apache.spark.ml.feature.OneHotEncoderModel)

Example 3 with BinaryFeature

use of org.jpmml.converter.BinaryFeature in project jpmml-r by jpmml.

the class Formula method addField.

/**
 * Registers a categorical field together with one binary feature per category.
 *
 * <p>The field itself is registered under its own name, as a {@link BooleanFeature} when it is
 * a boolean field whose category values are exactly {@code "false"}, {@code "true"}, and as a
 * plain {@link CategoricalFeature} otherwise. In addition, for every category a
 * {@link BinaryFeature} is registered under the field name concatenated with the category name.
 *
 * @param field the field to register
 * @param categoryNames the name suffixes under which the per-category binary features are stored
 * @param categoryValues the category values, positionally matching {@code categoryNames}
 * @throws IllegalArgumentException if the two lists differ in size
 */
public void addField(Field<?> field, List<String> categoryNames, List<String> categoryValues) {
    RExpEncoder encoder = getEncoder();
    if (categoryNames.size() != categoryValues.size()) {
        throw new IllegalArgumentException("Category names and category values must be of equal size, got " + categoryNames.size() + " name(s) and " + categoryValues.size() + " value(s)");
    }
    CategoricalFeature categoricalFeature;
    // Only the exact ("false", "true") value pair on a boolean field qualifies as a boolean feature.
    if ((DataType.BOOLEAN).equals(field.getDataType()) && (categoryValues.size() == 2) && ("false").equals(categoryValues.get(0)) && ("true").equals(categoryValues.get(1))) {
        categoricalFeature = new BooleanFeature(encoder, field);
    } else {
        categoricalFeature = new CategoricalFeature(encoder, field, categoryValues);
    }
    putFeature(field.getName(), categoricalFeature);
    for (int i = 0; i < categoryNames.size(); i++) {
        String categoryName = categoryNames.get(i);
        String categoryValue = categoryValues.get(i);
        BinaryFeature binaryFeature = new BinaryFeature(encoder, field, categoryValue);
        // The lookup key is the plain concatenation of field name and category name, without a separator.
        putFeature(FieldName.create((field.getName()).getValue() + categoryName), binaryFeature);
    }
    this.fields.add(field);
}
Also used : BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature)

Example 4 with BinaryFeature

use of org.jpmml.converter.BinaryFeature in project jpmml-r by jpmml.

the class ScorecardConverter method encodeModel.

/**
 * Encodes the wrapped R model object as a PMML {@link Scorecard}.
 *
 * <p>The model object must carry a {@code coefficients} element and an {@code sc.conf} element
 * with {@code odds}, {@code base_points} and {@code pdo} entries. Every schema feature must be
 * a {@link BinaryFeature}; features that share a field name are grouped into one
 * {@link Characteristic}, each contributing an attribute whose partial score is
 * {@code -coefficient * factor}, where {@code factor = pdo / ln(2)}. Every characteristic
 * additionally ends with a catch-all attribute (predicate {@link True}) scoring {@code 0}.
 * The initial score is {@code base_points - ln(odds) * factor - intercept * factor}, the
 * intercept term being omitted when there is no intercept coefficient.
 *
 * @param schema supplies the features and the label
 * @return the scorecard model
 * @throws IllegalArgumentException if the scorecard configuration is missing, if the number of
 *         coefficients does not match the number of features (plus intercept), or if a feature
 *         is not a binary feature
 */
@Override
public Scorecard encodeModel(Schema schema) {
    RGenericVector glm = getObject();
    RDoubleVector coefficients = (RDoubleVector) glm.getValue("coefficients");
    // NOTE(review): "family" is not used below; the lookup is kept because it may be what
    // rejects model objects that lack this element - confirm before removing.
    RGenericVector family = (RGenericVector) glm.getValue("family");
    RGenericVector scConf;
    try {
        scConf = (RGenericVector) glm.getValue("sc.conf");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No scorecard configuration information. Please initialize the \'sc.conf\' element", iae);
    }
    Double intercept = coefficients.getValue(LMConverter.INTERCEPT, true);
    List<? extends Feature> features = schema.getFeatures();
    int expectedCoefficientCount = features.size() + (intercept != null ? 1 : 0);
    if (coefficients.size() != expectedCoefficientCount) {
        throw new IllegalArgumentException("Expected " + expectedCoefficientCount + " coefficient(s) for " + features.size() + " feature(s), got " + coefficients.size());
    }
    RNumberVector<?> odds = (RNumberVector<?>) scConf.getValue("odds");
    RNumberVector<?> basePoints = (RNumberVector<?>) scConf.getValue("base_points");
    RNumberVector<?> pdo = (RNumberVector<?>) scConf.getValue("pdo");
    double factor = (pdo.asScalar()).doubleValue() / Math.log(2);
    // Insertion-ordered, so that characteristics appear in the order their fields are first seen.
    Map<FieldName, Characteristic> fieldCharacteristics = new LinkedHashMap<>();
    for (Feature feature : features) {
        FieldName name = feature.getName();
        if (!(feature instanceof BinaryFeature)) {
            throw new IllegalArgumentException("Feature " + name + " is not a binary feature: " + feature.getClass().getName());
        }
        BinaryFeature binaryFeature = (BinaryFeature) feature;
        Double coefficient = getFeatureCoefficient(feature, coefficients);
        Characteristic characteristic = fieldCharacteristics.get(name);
        if (characteristic == null) {
            characteristic = new Characteristic().setName(FeatureUtil.createName("score", feature));
            fieldCharacteristics.put(name, characteristic);
        }
        SimplePredicate simplePredicate = new SimplePredicate().setField(binaryFeature.getName()).setOperator(SimplePredicate.Operator.EQUAL).setValue(binaryFeature.getValue());
        Attribute attribute = new Attribute().setPartialScore(formatScore(-1d * coefficient * factor)).setPredicate(simplePredicate);
        characteristic.addAttributes(attribute);
    }
    Characteristics characteristics = new Characteristics();
    for (Characteristic characteristic : fieldCharacteristics.values()) {
        // Catch-all attribute: contributes nothing when none of the category predicates match.
        Attribute attribute = new Attribute().setPartialScore(0d).setPredicate(new True());
        characteristic.addAttributes(attribute);
        characteristics.addCharacteristics(characteristic);
    }
    Scorecard scorecard = new Scorecard(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), characteristics).setInitialScore(formatScore((basePoints.asScalar()).doubleValue() - Math.log((odds.asScalar()).doubleValue()) * factor - (intercept != null ? intercept * factor : 0))).setUseReasonCodes(false);
    return scorecard;
}
Also used : Attribute(org.dmg.pmml.scorecard.Attribute) Characteristic(org.dmg.pmml.scorecard.Characteristic) True(org.dmg.pmml.True) BinaryFeature(org.jpmml.converter.BinaryFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) SimplePredicate(org.dmg.pmml.SimplePredicate) LinkedHashMap(java.util.LinkedHashMap) Characteristics(org.dmg.pmml.scorecard.Characteristics) Scorecard(org.dmg.pmml.scorecard.Scorecard) FieldName(org.dmg.pmml.FieldName) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map)

Example 5 with BinaryFeature

use of org.jpmml.converter.BinaryFeature in project jpmml-sparkml by jpmml.

the class OneHotEncoderConverter method encodeFeatures.

/**
 * Expands the input column of the one-hot encoder into binary indicator features.
 *
 * <p>One {@link BinaryFeature} is created per category value of the input column's categorical
 * feature. The last category value is left out when {@code dropLast} is in effect, which is
 * the case when the parameter is explicitly set to true or not set at all.
 *
 * @param encoder the encoder that holds the feature of the input column
 * @return the binary indicator features, in category value order
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    OneHotEncoder oneHotEncoder = getTransformer();
    Option<Object> explicitDropLast = oneHotEncoder.get(oneHotEncoder.dropLast());
    // An unset parameter is treated as dropLast = true.
    boolean skipLastCategory = explicitDropLast.isDefined() ? (Boolean) explicitDropLast.get() : true;
    CategoricalFeature source = (CategoricalFeature) encoder.getOnlyFeature(oneHotEncoder.getInputCol());
    List<String> allCategories = source.getValues();
    List<String> categories = skipLastCategory ? allCategories.subList(0, allCategories.size() - 1) : allCategories;
    List<Feature> indicators = new ArrayList<>();
    for (int pos = 0; pos < categories.size(); pos++) {
        BinaryFeature indicator = new BinaryFeature(encoder, source.getName(), DataType.STRING, categories.get(pos));
        indicators.add(indicator);
    }
    return indicators;
}
Also used : OneHotEncoder(org.apache.spark.ml.feature.OneHotEncoder) ArrayList(java.util.ArrayList) BinaryFeature(org.jpmml.converter.BinaryFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BinaryFeature(org.jpmml.converter.BinaryFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature)

Aggregations

BinaryFeature (org.jpmml.converter.BinaryFeature): 5; CategoricalFeature (org.jpmml.converter.CategoricalFeature): 4; Feature (org.jpmml.converter.Feature): 4; ArrayList (java.util.ArrayList): 2; FieldName (org.dmg.pmml.FieldName): 2; SimplePredicate (org.dmg.pmml.SimplePredicate): 2; BooleanFeature (org.jpmml.converter.BooleanFeature): 2; HashSet (java.util.HashSet): 1; LinkedHashMap (java.util.LinkedHashMap): 1; Map (java.util.Map): 1; Set (java.util.Set): 1; OneHotEncoder (org.apache.spark.ml.feature.OneHotEncoder): 1; OneHotEncoderModel (org.apache.spark.ml.feature.OneHotEncoderModel): 1; CategoricalSplit (org.apache.spark.ml.tree.CategoricalSplit): 1; ContinuousSplit (org.apache.spark.ml.tree.ContinuousSplit): 1; InternalNode (org.apache.spark.ml.tree.InternalNode): 1; LeafNode (org.apache.spark.ml.tree.LeafNode): 1; Split (org.apache.spark.ml.tree.Split): 1; ImpurityCalculator (org.apache.spark.mllib.tree.impurity.ImpurityCalculator): 1; Predicate (org.dmg.pmml.Predicate): 1