Search in sources :

Example 6 with Feature

use of org.jpmml.converter.Feature in project jpmml-r by jpmml.

the class SVMConverter method scale.

/**
 * Applies the per-column scaling recorded in the SVM object to the given features.
 *
 * <p>
 * For every column whose <code>scaled</code> flag is set, the feature is converted to a continuous feature
 * and replaced by a feature backed by a derived field that computes <code>(feature - center) / scale</code>.
 * The center and scale values are looked up by column name in the <code>scaled:center</code> and
 * <code>scaled:scale</code> elements of <code>x.scale</code>. A derived field that the encoder already
 * holds under the same name is reused instead of being created again. Columns that are not flagged
 * are passed through unchanged.
 * </p>
 *
 * @param features The input features, one per column of the <code>SV</code> matrix, in column order.
 * @param encoder The encoder that is queried for, and receives, the derived fields.
 * @return A new list holding one (possibly replaced) feature per column, in column order.
 * @throws IllegalArgumentException If the number of <code>scaled</code> flags differs from the number of
 * columns or features, or if a column is flagged as scaled but no scaling parameters are available.
 */
private List<Feature> scale(List<Feature> features, RExpEncoder encoder) {
    RGenericVector svm = getObject();
    RDoubleVector sv = (RDoubleVector) svm.getValue("SV");
    RBooleanVector scaled = (RBooleanVector) svm.getValue("scaled");
    RGenericVector xScale = (RGenericVector) svm.getValue("x.scale");
    RStringVector columnNames = sv.dimnames(1);
    if ((scaled.size() != columnNames.size()) || (scaled.size() != features.size())) {
        throw new IllegalArgumentException("Expected " + scaled.size() + " column(s) and feature(s), got " + columnNames.size() + " column(s) and " + features.size() + " feature(s)");
    }
    RDoubleVector xScaledCenter = null;
    RDoubleVector xScaledScale = null;
    if (xScale != null) {
        xScaledCenter = (RDoubleVector) xScale.getValue("scaled:center");
        xScaledScale = (RDoubleVector) xScale.getValue("scaled:scale");
    }
    List<Feature> result = new ArrayList<>();
    for (int i = 0; i < columnNames.size(); i++) {
        String columnName = columnNames.getValue(i);
        Feature feature = features.get(i);
        if (scaled.getValue(i)) {
            feature = feature.toContinuousFeature();
            FieldName name = FeatureUtil.createName("scale", feature);
            DerivedField derivedField = encoder.getDerivedField(name);
            if (derivedField == null) {
                // Fail with a descriptive error instead of a NullPointerException when the scaling parameters are absent
                if (xScaledCenter == null || xScaledScale == null) {
                    throw new IllegalArgumentException("Column \'" + columnName + "\' is flagged as scaled, but its scaling parameters are not available");
                }
                Double center = xScaledCenter.getValue(columnName);
                Double scale = xScaledScale.getValue(columnName);
                // (feature - center) / scale
                Apply apply = PMMLUtil.createApply("/", PMMLUtil.createApply("-", feature.ref(), PMMLUtil.createConstant(center)), PMMLUtil.createConstant(scale));
                derivedField = encoder.createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, apply);
            }
            feature = new ContinuousFeature(encoder, derivedField);
        }
        result.add(feature);
    }
    return result;
}
Also used : Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField)

Example 7 with Feature

use of org.jpmml.converter.Feature in project jpmml-r by jpmml.

the class ElmNNConverter method encodeSchema.

/**
 * Registers the label and the features of the elmNN object with the encoder.
 *
 * <p>
 * The formula is rebuilt from the <code>terms</code> attribute of the model frame. A non-zero
 * <code>response</code> index selects the formula field that becomes the label. Every entry of the
 * <code>columns</code> attribute is then resolved to a feature, except for a leading
 * <code>(Intercept)</code> entry, which is skipped.
 * </p>
 *
 * @param encoder The encoder that receives the label and the features.
 * @throws IllegalArgumentException If the <code>model</code> element cannot be obtained.
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector modelFrame = getModelFrame(getObject());
    RExp terms = modelFrame.getAttributeValue("terms");
    RIntegerVector response = (RIntegerVector) terms.getAttributeValue("response");
    RStringVector columns = (RStringVector) terms.getAttributeValue("columns");
    FormulaContext formulaContext = new ModelFrameFormulaContext(modelFrame);
    Formula formula = FormulaUtil.createFormula(terms, formulaContext, encoder);

    // Dependent variable: the response index is one-based, zero means "no response"
    int labelPosition = response.asScalar();
    if (labelPosition != 0) {
        DataField labelField = (DataField) formula.getField(labelPosition - 1);
        encoder.setLabel(labelField);
    }

    // Independent variables: everything except a leading intercept column
    final int columnCount = columns.size();
    for (int index = 0; index < columnCount; index++) {
        String column = columns.getValue(index);
        if (index != 0 || !"(Intercept)".equals(column)) {
            Feature feature = formula.resolveFeature(column);
            encoder.addFeature(feature);
        }
    }
}

/**
 * Fetches the <code>model</code> element of the elmNN object, rethrowing a lookup failure with a
 * more helpful message.
 */
private static RGenericVector getModelFrame(RGenericVector elmNN) {
    try {
        return (RGenericVector) elmNN.getValue("model");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No model frame information. Please initialize the \'model\' element", iae);
    }
}
Also used : DataField(org.dmg.pmml.DataField) Feature(org.jpmml.converter.Feature)

Example 8 with Feature

use of org.jpmml.converter.Feature in project jpmml-r by jpmml.

the class PreProcessEncoder method filter.

/**
 * Produces a transformed schema in which every feature that has a pre-processing expression is
 * replaced by a continuous feature backed by a newly created <code>preProcess</code> derived field.
 * Features for which {@code encodeExpression} yields <code>null</code> are kept as they are.
 *
 * @param schema The schema whose features are to be transformed.
 * @return The result of {@code schema.toTransformedSchema} for the per-feature transformation.
 */
private Schema filter(Schema schema) {
    Function<Feature, Feature> preProcessor = (Feature feature) -> {
        Expression expression = encodeExpression(feature);
        if (expression != null) {
            DerivedField derivedField = createDerivedField(FeatureUtil.createName("preProcess", feature), OpType.CONTINUOUS, DataType.DOUBLE, expression);
            return new ContinuousFeature(this, derivedField);
        }
        // Nothing to apply to this feature
        return feature;
    };
    return schema.toTransformedSchema(preProcessor);
}
Also used : Function(java.util.function.Function) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Expression(org.dmg.pmml.Expression) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DerivedField(org.dmg.pmml.DerivedField)

Example 9 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class LogisticRegressionModelConverter method encodeModel.

/**
 * Encodes the logistic regression transformer as a PMML regression model.
 *
 * <p>
 * A label with exactly two categories yields a binary logistic classification with logit
 * normalization and its output cleared. A label with more than two categories yields one regression
 * table per category, built from the matching row of the coefficient matrix and the matching entry of
 * the intercept vector, combined under softmax normalization.
 * </p>
 *
 * @param schema The schema providing the categorical label and the features.
 * @return The encoded regression model.
 * @throws IllegalArgumentException If the label has fewer than two categories.
 */
@Override
public RegressionModel encodeModel(Schema schema) {
    LogisticRegressionModel model = getTransformer();
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    final int categoryCount = categoricalLabel.size();

    if (categoryCount < 2) {
        throw new IllegalArgumentException();
    }

    // Binomial case
    if (categoryCount == 2) {
        return RegressionModelUtil.createBinaryLogisticClassification(schema.getFeatures(), VectorUtil.toList(model.coefficients()), model.intercept(), RegressionModel.NormalizationMethod.LOGIT, true, schema)
            .setOutput(null);
    }

    // Multinomial case: one regression table per target category
    Matrix coefficientMatrix = model.coefficientMatrix();
    Vector interceptVector = model.interceptVector();
    List<? extends Feature> features = schema.getFeatures();
    List<RegressionTable> regressionTables = new ArrayList<>();
    for (int row = 0; row < categoryCount; row++) {
        regressionTables.add(
            RegressionModelUtil.createRegressionTable(features, MatrixUtil.getRow(coefficientMatrix, row), interceptVector.apply(row))
                .setTargetCategory(categoricalLabel.getValue(row))
        );
    }
    return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
        .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
}
Also used : Matrix(org.apache.spark.ml.linalg.Matrix) CategoricalLabel(org.jpmml.converter.CategoricalLabel) LogisticRegressionModel(org.apache.spark.ml.classification.LogisticRegressionModel) ArrayList(java.util.ArrayList) List(java.util.List) Vector(org.apache.spark.ml.linalg.Vector) Feature(org.jpmml.converter.Feature) RegressionTable(org.dmg.pmml.regression.RegressionTable) LogisticRegressionModel(org.apache.spark.ml.classification.LogisticRegressionModel) RegressionModel(org.dmg.pmml.regression.RegressionModel)

Example 10 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class TreeModelUtil method encodeNode.

/**
 * Recursively converts a Spark ML decision tree node into a PMML {@link Node}.
 *
 * <p>
 * An internal node becomes a PMML node with exactly two children, the left child followed by the
 * right child. Each child carries the predicate derived from the split:
 * </p>
 * <ul>
 *   <li>continuous split on a boolean feature: equality tests against the first and second feature value (the threshold must be 0.5);</li>
 *   <li>continuous split on any other feature: <code>&lt;= threshold</code> (left) and <code>&gt; threshold</code> (right) on the continuous form of the feature;</li>
 *   <li>categorical split on a binary feature: equality/inequality tests against the feature value;</li>
 *   <li>categorical split on a categorical feature: set predicates over the left/right category values,
 *   restricted to the values still recorded for the field in <code>parentFieldValues</code>.</li>
 * </ul>
 *
 * <p>
 * A leaf node becomes a childless PMML node with a score. For classification it additionally
 * receives a record count and one score distribution per impurity statistics entry.
 * </p>
 *
 * @param node The Spark ML node to convert.
 * @param predicateManager The factory used to create the child predicates.
 * @param parentFieldValues For categorical fields split on by ancestor nodes, the category values
 * selected on the path to this node. The map is not modified; modified copies are handed to the children.
 * @param miningFunction Determines how leaf nodes are scored.
 * @param schema The schema providing the features (by split feature index) and the label.
 * @return The PMML node; its own predicate is left for the caller to set.
 * @throws IllegalArgumentException If the node, split or feature type is not supported, or if the
 * split is inconsistent with the feature.
 * @throws UnsupportedOperationException If the mining function is neither regression nor classification.
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {
    if (node instanceof InternalNode) {
        InternalNode internalNode = (InternalNode) node;
        // Children inherit the parent's field values unless a categorical split narrows them down
        Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
        Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;
        Predicate leftPredicate;
        Predicate rightPredicate;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());
        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;
            double threshold = continuousSplit.threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Expected threshold 0.5 for a boolean feature, got " + threshold);
                }
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                String value = ValueUtil.formatValue(threshold);
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Unexpected categories for a binary feature: left=" + Arrays.toString(leftCategories) + ", right=" + Arrays.toString(rightCategories));
                }
                String value = ValueUtil.formatValue(binaryFeature.getValue());
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName name = categoricalFeature.getName();
                List<String> values = categoricalFeature.getValues();
                if (values.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException("Expected " + values.size() + " split categories, got " + (leftCategories.length + rightCategories.length));
                }
                final Set<String> parentValues = parentFieldValues.get(name);
                // Accept every value when no ancestor has split on this field; otherwise only the values that are still selectable
                com.google.common.base.Predicate<String> valueFilter = new com.google.common.base.Predicate<String>() {

                    @Override
                    public boolean apply(String value) {
                        if (parentValues != null) {
                            return parentValues.contains(value);
                        }
                        return true;
                    }
                };
                List<String> leftValues = selectValues(values, leftCategories, valueFilter);
                List<String> rightValues = selectValues(values, rightCategories, valueFilter);
                // Copy before updating, so that the caller's map stays untouched
                leftFieldValues = new HashMap<>(parentFieldValues);
                leftFieldValues.put(name, new HashSet<>(leftValues));
                rightFieldValues = new HashMap<>(parentFieldValues);
                rightFieldValues.put(name, new HashSet<>(rightValues));
                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Unsupported feature type for a categorical split: " + (feature != null ? feature.getClass().getName() : "null"));
            }
        } else {
            throw new IllegalArgumentException("Unsupported split type: " + (split != null ? split.getClass().getName() : "null"));
        }
        Node result = new Node();
        Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftFieldValues, miningFunction, schema).setPredicate(leftPredicate);
        Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightFieldValues, miningFunction, schema).setPredicate(rightPredicate);
        result.addNodes(leftChild, rightChild);
        return result;
    } else if (node instanceof LeafNode) {
        Node result = new Node();
        switch(miningFunction) {
            case REGRESSION:
                {
                    String score = ValueUtil.formatValue(node.prediction());
                    result.setScore(score);
                }
                break;
            case CLASSIFICATION:
                {
                    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
                    // The prediction is the index of the predicted category
                    int index = ValueUtil.asInt(node.prediction());
                    result.setScore(categoricalLabel.getValue(index));
                    ImpurityCalculator impurityCalculator = node.impurityStats();
                    result.setRecordCount((double) impurityCalculator.count());
                    double[] stats = impurityCalculator.stats();
                    for (int i = 0; i < stats.length; i++) {
                        ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), stats[i]);
                        result.addScoreDistributions(scoreDistribution);
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException("Unsupported mining function: " + miningFunction);
        }
        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node type: " + (node != null ? node.getClass().getName() : "null"));
    }
}
Also used : HashSet(java.util.HashSet) Set(java.util.Set) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) LeafNode(org.apache.spark.ml.tree.LeafNode) FieldName(org.dmg.pmml.FieldName) BinaryFeature(org.jpmml.converter.BinaryFeature) ScoreDistribution(org.dmg.pmml.ScoreDistribution) ContinuousFeature(org.jpmml.converter.ContinuousFeature) ImpurityCalculator(org.apache.spark.mllib.tree.impurity.ImpurityCalculator) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit) CategoricalLabel(org.jpmml.converter.CategoricalLabel) InternalNode(org.apache.spark.ml.tree.InternalNode) Split(org.apache.spark.ml.tree.Split) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit)

Aggregations

Feature (org.jpmml.converter.Feature): 53; ContinuousFeature (org.jpmml.converter.ContinuousFeature): 30; ArrayList (java.util.ArrayList): 27; CategoricalFeature (org.jpmml.converter.CategoricalFeature): 19; DerivedField (org.dmg.pmml.DerivedField): 14; DataField (org.dmg.pmml.DataField): 13; FieldName (org.dmg.pmml.FieldName): 10; Apply (org.dmg.pmml.Apply): 9; BooleanFeature (org.jpmml.converter.BooleanFeature): 9; BinaryFeature (org.jpmml.converter.BinaryFeature): 7; List (java.util.List): 6; Expression (org.dmg.pmml.Expression): 6; SimplePredicate (org.dmg.pmml.SimplePredicate): 6; Vector (org.apache.spark.ml.linalg.Vector): 5; Predicate (org.dmg.pmml.Predicate): 5; Node (org.dmg.pmml.tree.Node): 5; DocumentFeature (org.jpmml.sparkml.DocumentFeature): 5; InteractionFeature (org.jpmml.converter.InteractionFeature): 4; DocumentBuilder (javax.xml.parsers.DocumentBuilder): 3; Transformer (org.apache.spark.ml.Transformer): 3