Search in sources :

Example 1 with InteractionFeature

use of org.jpmml.converter.InteractionFeature in project jpmml-sparkml by jpmml.

the class InteractionConverter method encodeFeatures.

/**
 * Encodes the interaction between the features of all input columns.
 *
 * The features of the first input column are taken as-is. Every subsequent
 * input column is crossed with the running result: each (left, right) pair
 * becomes one {@link InteractionFeature} of type {@code DataType.DOUBLE}, named
 * {@code <col1>:<col2>:...[<index>]}, where the index counts the pairs in
 * iteration order (left-major).
 *
 * @param encoder the encoder that supplies the features of each input column
 * @return the features of the single input column, the crossed features of
 *         several input columns, or an empty list if there are no input columns
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    Interaction transformer = getTransformer();
    StringBuilder label = new StringBuilder();
    List<Feature> combined = new ArrayList<>();
    int position = 0;
    for (String inputCol : transformer.getInputCols()) {
        List<Feature> columnFeatures = encoder.getFeatures(inputCol);
        if (position++ == 0) {
            // The first column seeds both the name and the running result.
            label.append(inputCol);
            combined = columnFeatures;
            continue;
        }
        label.append(':').append(inputCol);
        String prefix = label.toString();
        List<Feature> products = new ArrayList<>();
        for (Feature left : combined) {
            for (Feature right : columnFeatures) {
                // The size of the output list doubles as the running pair index.
                FieldName fieldName = FieldName.create(prefix + "[" + products.size() + "]");
                products.add(new InteractionFeature(encoder, fieldName, DataType.DOUBLE, Arrays.asList(left, right)));
            }
        }
        combined = products;
    }
    return combined;
}
Also used : InteractionFeature(org.jpmml.converter.InteractionFeature) Interaction(org.apache.spark.ml.feature.Interaction) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature)

Example 2 with InteractionFeature

use of org.jpmml.converter.InteractionFeature in project jpmml-r by jpmml.

the class EarthConverter method encodeSchema.

/**
 * Registers the label and the features of the wrapped "earth" object with the encoder.
 *
 * Reads the {@code dirs}, {@code cuts}, {@code selected.terms}, {@code coefficients},
 * {@code terms} and {@code xlevels} elements of the object. The label is the field
 * named by the single column name of {@code coefficients}. Each selected term
 * (except the first one) contributes one feature, built from the row of
 * {@code dirs}/{@code cuts} that the term refers to:
 * <ul>
 *   <li>direction {@code 0}: the predictor does not take part in the term;</li>
 *   <li>direction {@code -1} or {@code 1}: the predictor enters via a hinge function,
 *       materialized as a (reused or newly created) derived field;</li>
 *   <li>direction {@code 2}: the predictor's feature is used unchanged.</li>
 * </ul>
 * A term with several participating predictors becomes an {@link InteractionFeature}
 * named after the corresponding row of {@code dirs}.
 *
 * @param encoder the encoder that receives the label and the features
 * @throws IllegalArgumentException if the {@code xlevels} element cannot be read,
 *         if the dimension names of {@code dirs} and {@code cuts} differ,
 *         if a direction value is not one of {@code -1, 0, 1, 2},
 *         or if a selected term involves no predictor at all
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector earth = getObject();
    RDoubleVector dirs = (RDoubleVector) earth.getValue("dirs");
    RDoubleVector cuts = (RDoubleVector) earth.getValue("cuts");
    RDoubleVector selectedTerms = (RDoubleVector) earth.getValue("selected.terms");
    RDoubleVector coefficients = (RDoubleVector) earth.getValue("coefficients");
    RExp terms = earth.getValue("terms");
    final RGenericVector xlevels;
    try {
        xlevels = (RGenericVector) earth.getValue("xlevels");
    } catch (IllegalArgumentException iae) {
        // Re-thrown with a hint for the user; the original failure is kept as the cause.
        throw new IllegalArgumentException("No variable levels information. Please initialize the \'xlevels\' element", iae);
    }
    RStringVector dirsRows = dirs.dimnames(0);
    RStringVector dirsColumns = dirs.dimnames(1);
    RStringVector cutsRows = cuts.dimnames(0);
    RStringVector cutsColumns = cuts.dimnames(1);
    // Both matrices are indexed with the same (row, column) pairs below, so their labels must agree.
    if (!(dirsRows.getValues()).equals(cutsRows.getValues()) || !(dirsColumns.getValues()).equals(cutsColumns.getValues())) {
        throw new IllegalArgumentException("The row and column names of the 'dirs' and 'cuts' elements do not match");
    }
    int rows = dirsRows.size();
    int columns = dirsColumns.size();
    List<String> predictorNames = dirsColumns.getValues();
    FormulaContext context = new FormulaContext() {

        @Override
        public List<String> getCategories(String variable) {
            // Category levels are known only for the variables listed in 'xlevels'.
            if (xlevels.hasValue(variable)) {
                RStringVector levels = (RStringVector) xlevels.getValue(variable);
                return levels.getValues();
            }
            return null;
        }

        @Override
        public RGenericVector getData() {
            return null;
        }
    };
    Formula formula = FormulaUtil.createFormula(terms, context, encoder);
    // Dependent variable
    {
        RStringVector yNames = coefficients.dimnames(1);
        FieldName name = FieldName.create(yNames.asScalar());
        DataField dataField = (DataField) encoder.getField(name);
        encoder.setLabel(dataField);
    }
    // Independent variables
    // NOTE(review): the first selected term is skipped - presumably the intercept; confirm.
    for (int i = 1; i < selectedTerms.size(); i++) {
        // Selected term values are 1-based row numbers; convert to a 0-based row index.
        int termIndex = ValueUtil.asInt(selectedTerms.getValue(i)) - 1;
        List<Double> dirsRow = FortranMatrixUtil.getRow(dirs.getValues(), rows, columns, termIndex);
        List<Double> cutsRow = FortranMatrixUtil.getRow(cuts.getValues(), rows, columns, termIndex);
        List<Feature> features = new ArrayList<>();
        predictors: for (int j = 0; j < predictorNames.size(); j++) {
            String predictorName = predictorNames.get(j);
            int dir = ValueUtil.asInt(dirsRow.get(j));
            double cut = cutsRow.get(j);
            if (dir == 0) {
                continue predictors;
            }
            Feature feature = formula.resolveFeature(predictorName);
            switch(dir) {
                case -1:
                case 1:
                    {
                        feature = feature.toContinuousFeature();
                        FieldName name = FieldName.create(formatHingeFunction(dir, feature, cut));
                        // Reuse the derived field if the same hinge function was already created for another term.
                        DerivedField derivedField = encoder.getDerivedField(name);
                        if (derivedField == null) {
                            Apply apply = createHingeFunction(dir, feature, cut);
                            derivedField = encoder.createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, apply);
                        }
                        feature = new ContinuousFeature(encoder, derivedField);
                    }
                    break;
                case 2:
                    // The resolved feature is used as-is.
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported direction " + dir + " for predictor " + predictorName);
            }
            features.add(feature);
        }
        Feature feature;
        if (features.size() == 1) {
            feature = features.get(0);
        } else if (features.size() > 1) {
            // Name the interaction after the same matrix row that the data was read from.
            // (Previously indexed by 'i', the position within 'selected.terms', which
            // addresses a different row whenever the selected terms are not the leading rows.)
            feature = new InteractionFeature(encoder, FieldName.create(dirsRows.getValue(termIndex)), DataType.DOUBLE, features);
        } else {
            throw new IllegalArgumentException("The term at row index " + termIndex + " does not involve any predictor");
        }
        encoder.addFeature(feature);
    }
}
Also used : InteractionFeature(org.jpmml.converter.InteractionFeature) Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField)

Example 3 with InteractionFeature

use of org.jpmml.converter.InteractionFeature in project jpmml-r by jpmml.

the class Formula method resolveFeature.

/**
 * Resolves a feature by its textual name.
 *
 * The name is split into variables. A name consisting of exactly one variable
 * is resolved directly via the {@code FieldName}-based lookup. Otherwise every
 * variable is resolved on its own and the results are wrapped into an
 * {@link InteractionFeature} of type {@code DataType.DOUBLE} carrying the full name.
 *
 * @param name the feature name, possibly denoting several variables
 * @return the resolved feature
 */
public Feature resolveFeature(String name) {
    RExpEncoder encoder = getEncoder();
    List<String> parts = split(name);
    // Guard clause: a lone variable is not an interaction.
    if (parts.size() == 1) {
        return resolveFeature(FieldName.create(name));
    }
    List<Feature> components = new ArrayList<>();
    for (String part : parts) {
        components.add(resolveFeature(FieldName.create(part)));
    }
    FieldName interactionName = FieldName.create(name);
    return new InteractionFeature(encoder, interactionName, DataType.DOUBLE, components);
}
Also used : InteractionFeature(org.jpmml.converter.InteractionFeature) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) PowerFeature(org.jpmml.converter.PowerFeature) BinaryFeature(org.jpmml.converter.BinaryFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature)

Aggregations

ArrayList (java.util.ArrayList): 3, Feature (org.jpmml.converter.Feature): 3, InteractionFeature (org.jpmml.converter.InteractionFeature): 3, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 2, Interaction (org.apache.spark.ml.feature.Interaction): 1, Apply (org.dmg.pmml.Apply): 1, DataField (org.dmg.pmml.DataField): 1, DerivedField (org.dmg.pmml.DerivedField): 1, FieldName (org.dmg.pmml.FieldName): 1, BinaryFeature (org.jpmml.converter.BinaryFeature): 1, BooleanFeature (org.jpmml.converter.BooleanFeature): 1, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 1, PowerFeature (org.jpmml.converter.PowerFeature): 1