Search in sources:

Example 6 with ContinuousFeature

Use of org.jpmml.converter.ContinuousFeature in project jpmml-r by jpmml.

The class PreProcessEncoder, method filter.

/**
 * Builds a transformed copy of {@code schema} in which pre-processed features are replaced.
 *
 * <p>For every feature handed to the transformation, {@code encodeExpression} is consulted. If it
 * yields {@code null}, the feature passes through untouched. Otherwise the expression is wrapped in
 * a continuous, double-typed derived field named via
 * {@code FeatureUtil.createName("preProcess", feature)}, and a {@link ContinuousFeature} over that
 * field takes the original feature's place.
 *
 * @param schema the schema whose features should be filtered
 * @return the schema produced by {@code schema.toTransformedSchema(...)}
 */
private Schema filter(Schema schema) {
    Function<Feature, Feature> preProcessor = (Feature original) -> {
        Expression encoded = encodeExpression(original);

        // No pre-processing step applies to this feature: keep it as-is
        if (encoded == null) {
            return original;
        }

        DerivedField preProcessedField = createDerivedField(FeatureUtil.createName("preProcess", original), OpType.CONTINUOUS, DataType.DOUBLE, encoded);

        return new ContinuousFeature(PreProcessEncoder.this, preProcessedField);
    };

    return schema.toTransformedSchema(preProcessor);
}
Also used : Function(java.util.function.Function) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Expression(org.dmg.pmml.Expression) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DerivedField(org.dmg.pmml.DerivedField)

Example 7 with ContinuousFeature

Use of org.jpmml.converter.ContinuousFeature in project jpmml-sparkml by jpmml.

The class TreeModelUtil, method encodeNode.

/**
 * Recursively converts a Spark ML decision tree node into a PMML tree {@link Node}.
 *
 * <p>For an {@link InternalNode}, the split is translated into a pair of predicates, one per
 * child, and both children are encoded recursively with those predicates attached. For a
 * {@link LeafNode}, the node's prediction becomes the PMML score. For classification, the
 * impurity statistics are additionally emitted as the record count and one score distribution
 * per target category.
 *
 * @param node the Spark ML tree node to encode
 * @param predicateManager factory used to create the split predicates
 * @param parentFieldValues for each categorical field, the category values still reachable at
 *     this node; a field absent from the map is unrestricted. The map is never modified, because
 *     narrowed restrictions are written into fresh copies.
 * @param miningFunction the mining function; only REGRESSION and CLASSIFICATION are supported at leaves
 * @param schema supplies the feature for each split's feature index and, for classification, the label
 * @return the encoded PMML node; no predicate is set on it here (recursive callers attach one)
 * @throws IllegalArgumentException if the node, split or feature type is not supported, or the
 *     split is inconsistent with the feature it refers to
 * @throws UnsupportedOperationException if a leaf is reached with an unsupported mining function
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {
    if (node instanceof InternalNode) {
        InternalNode internalNode = (InternalNode) node;
        // Both children inherit the parent's category restrictions unless a categorical split narrows them below
        Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
        Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;
        Predicate leftPredicate;
        Predicate rightPredicate;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());
        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;
            double threshold = continuousSplit.threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                // Only the midpoint threshold is accepted; the left child maps to value 0, the right child to value 1
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Expected a threshold of 0.5 for boolean feature " + booleanFeature + ", got " + threshold);
                }
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                String value = ValueUtil.formatValue(threshold);
                // Left child: feature <= threshold; right child: feature > threshold
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                // The split must send exactly one of TRUE/FALSE to each side; the predicates compare against the feature's single value
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Unexpected categories for binary feature " + binaryFeature + ": left " + Arrays.toString(leftCategories) + ", right " + Arrays.toString(rightCategories));
                }
                String value = ValueUtil.formatValue(binaryFeature.getValue());
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName name = categoricalFeature.getName();
                List<String> values = categoricalFeature.getValues();
                if (values.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException("Categorical feature " + name + " has " + values.size() + " categories, but the split partitions " + (leftCategories.length + rightCategories.length) + " categories");
                }
                final Set<String> parentValues = parentFieldValues.get(name);
                // Keep only the values that are still reachable at this node (all values when the field is unrestricted)
                com.google.common.base.Predicate<String> valueFilter = value -> parentValues == null || parentValues.contains(value);
                List<String> leftValues = selectValues(values, leftCategories, valueFilter);
                List<String> rightValues = selectValues(values, rightCategories, valueFilter);
                // Copy before narrowing so that the caller's map (and sibling subtrees) are unaffected
                leftFieldValues = new HashMap<>(parentFieldValues);
                leftFieldValues.put(name, new HashSet<>(leftValues));
                rightFieldValues = new HashMap<>(parentFieldValues);
                rightFieldValues.put(name, new HashSet<>(rightValues));
                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Expected a binary or categorical feature for a categorical split, got " + feature);
            }
        } else {
            throw new IllegalArgumentException("Unsupported split " + split);
        }
        Node result = new Node();
        Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftFieldValues, miningFunction, schema).setPredicate(leftPredicate);
        Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightFieldValues, miningFunction, schema).setPredicate(rightPredicate);
        result.addNodes(leftChild, rightChild);
        return result;
    } else if (node instanceof LeafNode) {
        LeafNode leafNode = (LeafNode) node;
        Node result = new Node();
        switch(miningFunction) {
            case REGRESSION:
                {
                    String score = ValueUtil.formatValue(leafNode.prediction());
                    result.setScore(score);
                }
                break;
            case CLASSIFICATION:
                {
                    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
                    // The prediction is the index of the winning target category
                    int index = ValueUtil.asInt(leafNode.prediction());
                    result.setScore(categoricalLabel.getValue(index));
                    ImpurityCalculator impurityCalculator = leafNode.impurityStats();
                    result.setRecordCount((double) impurityCalculator.count());
                    // One score distribution per target category, using the i-th impurity statistic as its record count
                    double[] stats = impurityCalculator.stats();
                    for (int i = 0; i < stats.length; i++) {
                        ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), stats[i]);
                        result.addScoreDistributions(scoreDistribution);
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException("Mining function " + miningFunction + " is not supported");
        }
        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node " + node);
    }
}
Also used : HashSet(java.util.HashSet) Set(java.util.Set) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) LeafNode(org.apache.spark.ml.tree.LeafNode) FieldName(org.dmg.pmml.FieldName) BinaryFeature(org.jpmml.converter.BinaryFeature) ScoreDistribution(org.dmg.pmml.ScoreDistribution) ContinuousFeature(org.jpmml.converter.ContinuousFeature) ImpurityCalculator(org.apache.spark.mllib.tree.impurity.ImpurityCalculator) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit) CategoricalLabel(org.jpmml.converter.CategoricalLabel) InternalNode(org.apache.spark.ml.tree.InternalNode) Split(org.apache.spark.ml.tree.Split) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit)

Example 8 with ContinuousFeature

Use of org.jpmml.converter.ContinuousFeature in project jpmml-sparkml by jpmml.

The class ModelConverter, method encodeSchema.

/**
 * Encodes the label and the feature vector of the wrapped Spark ML model into a {@link Schema}.
 *
 * <p>When the model has a label column, its feature is turned into a label according to the
 * mining function:
 * <ul>
 *   <li>CLASSIFICATION: a categorical feature is used directly. A continuous feature is
 *   re-registered as categorical with the categories "0" .. "numClasses - 1"; numClasses defaults
 *   to 2 unless the model is a {@link ClassificationModel}.</li>
 *   <li>REGRESSION: the field is made continuous and its data type is set to double.</li>
 * </ul>
 * The label's category count (classification models) and the feature count (prediction models)
 * are then validated against the model.
 *
 * @param encoder the encoder holding the features registered for the model's columns
 * @return a schema made of the (possibly {@code null}) label and the model's features
 * @throws IllegalArgumentException if the label feature, mining function, label type, category
 *     count or feature count is not compatible with the model
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T model = getTransformer();
    Label label = null;
    if (model instanceof HasLabelCol) {
        HasLabelCol hasLabelCol = (HasLabelCol) model;
        String labelCol = hasLabelCol.getLabelCol();
        Feature feature = encoder.getOnlyFeature(labelCol);
        MiningFunction miningFunction = getMiningFunction();
        switch(miningFunction) {
            case CLASSIFICATION:
                {
                    if (feature instanceof CategoricalFeature) {
                        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                        DataField dataField = encoder.getDataField(categoricalFeature.getName());
                        label = new CategoricalLabel(dataField);
                    } else if (feature instanceof ContinuousFeature) {
                        ContinuousFeature continuousFeature = (ContinuousFeature) feature;
                        int numClasses = 2;
                        if (model instanceof ClassificationModel) {
                            ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>) model;
                            numClasses = classificationModel.numClasses();
                        }
                        // Class indices 0 .. numClasses - 1, formatted as strings
                        List<String> categories = new ArrayList<>();
                        for (int i = 0; i < numClasses; i++) {
                            categories.add(String.valueOf(i));
                        }
                        Field<?> field = encoder.toCategorical(continuousFeature.getName(), categories);
                        // Replace the label column's feature so that later lookups see the categorical version
                        encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, field, categories));
                        label = new CategoricalLabel(field.getName(), field.getDataType(), categories);
                    } else {
                        throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
                    }
                }
                break;
            case REGRESSION:
                {
                    Field<?> field = encoder.toContinuous(feature.getName());
                    field.setDataType(DataType.DOUBLE);
                    label = new ContinuousLabel(field.getName(), field.getDataType());
                }
                break;
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }
    if (model instanceof ClassificationModel) {
        ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>) model;
        // Guard the cast: a missing or continuous label would otherwise fail with an opaque NPE/ClassCastException
        if (!(label instanceof CategoricalLabel)) {
            throw new IllegalArgumentException("Expected a categorical label, got " + label);
        }
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        int numClasses = classificationModel.numClasses();
        if (numClasses != categoricalLabel.size()) {
            throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + categoricalLabel.size() + " target categories");
        }
    }
    String featuresCol = model.getFeaturesCol();
    List<Feature> features = encoder.getFeatures(featuresCol);
    if (model instanceof PredictionModel) {
        PredictionModel<?, ?> predictionModel = (PredictionModel<?, ?>) model;
        int numFeatures = predictionModel.numFeatures();
        // A value of -1 skips the feature count check
        if (numFeatures != -1 && features.size() != numFeatures) {
            throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
        }
    }
    Schema result = new Schema(label, features);
    return result;
}
Also used : Schema(org.jpmml.converter.Schema) ContinuousLabel(org.jpmml.converter.ContinuousLabel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) ArrayList(java.util.ArrayList) PredictionModel(org.apache.spark.ml.PredictionModel) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) HasLabelCol(org.apache.spark.ml.param.shared.HasLabelCol) OutputField(org.dmg.pmml.OutputField) Field(org.dmg.pmml.Field) DataField(org.dmg.pmml.DataField) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) CategoricalLabel(org.jpmml.converter.CategoricalLabel) MiningFunction(org.dmg.pmml.MiningFunction) ClassificationModel(org.apache.spark.ml.classification.ClassificationModel) ContinuousLabel(org.jpmml.converter.ContinuousLabel)

Example 9 with ContinuousFeature

Use of org.jpmml.converter.ContinuousFeature in project jpmml-sparkml by jpmml.

The class RegressionModelConverter, method registerOutputFields.

/**
 * Declares the continuous predicted output field of the regression model.
 *
 * <p>The field is named after the model's prediction column and typed like the label. A
 * {@link ContinuousFeature} over it is registered with the encoder under that column, so that
 * downstream stages can consume the prediction.
 *
 * @param label the model's label, whose data type the prediction adopts
 * @param encoder the encoder that receives the prediction column's feature
 * @return a singleton list containing the predicted output field
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    String predictionCol = getTransformer().getPredictionCol();
    FieldName predictedName = FieldName.create(predictionCol);

    OutputField predictedField = ModelUtil.createPredictedField(predictedName, label.getDataType(), OpType.CONTINUOUS);

    ContinuousFeature predictedFeature = new ContinuousFeature(encoder, predictedField.getName(), predictedField.getDataType());
    encoder.putOnlyFeature(predictionCol, predictedFeature);

    return Collections.singletonList(predictedField);
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) OutputField(org.dmg.pmml.OutputField)

Example 10 with ContinuousFeature

Use of org.jpmml.converter.ContinuousFeature in project jpmml-sparkml by jpmml.

The class SparkMLEncoder, method getFeatures.

/**
 * Looks up the features associated with a column.
 *
 * <p>Features previously registered in {@code columnFeatures} are returned as-is. Otherwise the
 * column is treated as an input field: its data field is looked up (or created), and a single
 * feature is derived from the field's data type. STRING gives a {@link WildcardFeature}; INTEGER
 * and DOUBLE give a {@link ContinuousFeature}; BOOLEAN gives a {@link BooleanFeature}. The derived
 * feature is not stored in {@code columnFeatures}.
 *
 * @param column the column name
 * @return the registered features, or a singleton list holding the derived feature
 * @throws IllegalArgumentException if the data field's type is not one of the supported types
 */
public List<Feature> getFeatures(String column) {
    List<Feature> registered = this.columnFeatures.get(column);
    if (registered != null) {
        return registered;
    }

    FieldName fieldName = FieldName.create(column);

    DataField field = getDataField(fieldName);
    if (field == null) {
        field = createDataField(fieldName);
    }

    DataType type = field.getDataType();

    Feature derived;
    switch(type) {
        case STRING:
            derived = new WildcardFeature(this, field);
            break;
        case INTEGER:
        case DOUBLE:
            derived = new ContinuousFeature(this, field);
            break;
        case BOOLEAN:
            derived = new BooleanFeature(this, field);
            break;
        default:
            throw new IllegalArgumentException("Data type " + type + " is not supported");
    }

    return Collections.singletonList(derived);
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) DataType(org.dmg.pmml.DataType) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) WildcardFeature(org.jpmml.converter.WildcardFeature) FieldName(org.dmg.pmml.FieldName) BooleanFeature(org.jpmml.converter.BooleanFeature) WildcardFeature(org.jpmml.converter.WildcardFeature)

Aggregations

ContinuousFeature (org.jpmml.converter.ContinuousFeature)26 Feature (org.jpmml.converter.Feature)23 ArrayList (java.util.ArrayList)13 DerivedField (org.dmg.pmml.DerivedField)13 CategoricalFeature (org.jpmml.converter.CategoricalFeature)12 Apply (org.dmg.pmml.Apply)7 FieldName (org.dmg.pmml.FieldName)7 DataField (org.dmg.pmml.DataField)6 Expression (org.dmg.pmml.Expression)6 Predicate (org.dmg.pmml.Predicate)6 SimplePredicate (org.dmg.pmml.SimplePredicate)6 Node (org.dmg.pmml.tree.Node)6 OutputField (org.dmg.pmml.OutputField)4 BooleanFeature (org.jpmml.converter.BooleanFeature)4 Vector (org.apache.spark.ml.linalg.Vector)3 CategoricalLabel (org.jpmml.converter.CategoricalLabel)3 List (java.util.List)2 DocumentBuilder (javax.xml.parsers.DocumentBuilder)2 DataType (org.dmg.pmml.DataType)2 FieldColumnPair (org.dmg.pmml.FieldColumnPair)2