Search in sources :

Example 21 with DataField

use of org.dmg.pmml.DataField in project jpmml-r by jpmml.

the class RandomForestConverter method encodeFormula.

/**
 * Registers the label (dependent variable) and the features (independent variables) of the
 * random forest object with the given encoder.
 *
 * <p>A formula is built from the "terms" element of the model object. The dependent variable is
 * located through the one-based "response" attribute of the terms; one feature is then added for
 * every name found in the "xlevels" element of the forest.</p>
 *
 * @param encoder the encoder that receives the label and the features.
 * @throws IllegalArgumentException if the "response" attribute of the terms is zero, i.e. the
 *         terms do not point at a dependent variable.
 */
private void encodeFormula(RExpEncoder encoder) {
    RGenericVector randomForest = getObject();
    RGenericVector forest = (RGenericVector) randomForest.getValue("forest");
    // The second argument marks "y" as optional, so this may be null.
    RNumberVector<?> y = (RNumberVector<?>) randomForest.getValue("y", true);
    RExp terms = randomForest.getValue("terms");
    final RNumberVector<?> ncat = (RNumberVector<?>) forest.getValue("ncat");
    final RGenericVector xlevels = (RGenericVector) forest.getValue("xlevels");
    RIntegerVector response = (RIntegerVector) terms.getAttributeValue("response");
    FormulaContext context = new FormulaContext() {

        /**
         * Returns the levels of the variable when its "ncat" entry is greater than one,
         * and {@code null} otherwise.
         */
        @Override
        public List<String> getCategories(String variable) {
            if (ncat != null && ncat.hasValue(variable) && (ncat.getValue(variable)).doubleValue() > 1d) {
                RStringVector levels = (RStringVector) xlevels.getValue(variable);
                return levels.getValues();
            }
            return null;
        }

        /** No data is supplied to the formula by this converter. */
        @Override
        public RGenericVector getData() {
            return null;
        }
    };
    Formula formula = FormulaUtil.createFormula(terms, context, encoder);
    // Dependent variable
    int responseIndex = response.asScalar();
    if (responseIndex == 0) {
        throw new IllegalArgumentException("The \"response\" attribute of the terms is 0; a dependent variable is required");
    }
    // The "response" attribute is one-based, the formula fields are zero-based.
    DataField dataField = (DataField) formula.getField(responseIndex - 1);
    if (y instanceof RIntegerVector) {
        // An integer-valued y is re-encoded as a categorical field using its factor levels.
        dataField = (DataField) encoder.toCategorical(dataField.getName(), RExpUtil.getFactorLevels(y));
    }
    encoder.setLabel(dataField);
    // Independent variables
    RStringVector xlevelNames = xlevels.names();
    for (int i = 0; i < xlevelNames.size(); i++) {
        String xlevelName = xlevelNames.getValue(i);
        Feature feature = formula.resolveFeature(FieldName.create(xlevelName));
        encoder.addFeature(feature);
    }
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) DataField(org.dmg.pmml.DataField)

Example 22 with DataField

use of org.dmg.pmml.DataField in project jpmml-r by jpmml.

the class RExpEncoder method addFeature.

/**
 * Wraps the field into a feature that matches its operational type and registers that feature
 * through the {@code addFeature(Feature)} overload.
 *
 * @param field the field to expose as a feature; a categorical field is cast to
 *        {@link DataField}.
 * @throws IllegalArgumentException if the operational type of the field is neither
 *         {@code CATEGORICAL} nor {@code CONTINUOUS}.
 */
public void addFeature(Field<?> field) {
    Feature feature;
    OpType opType = field.getOpType();
    switch(opType) {
        case CATEGORICAL:
            feature = new CategoricalFeature(this, (DataField) field);
            break;
        case CONTINUOUS:
            feature = new ContinuousFeature(this, field);
            break;
        default:
            // Name the offending type so that the failure can be diagnosed from the stack trace alone.
            throw new IllegalArgumentException("Unsupported operational type: " + opType);
    }
    addFeature(feature);
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) OpType(org.dmg.pmml.OpType) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature)

Example 23 with DataField

use of org.dmg.pmml.DataField in project jpmml-sparkml by jpmml.

the class ImputerModelConverter method encodeFeatures.

/**
 * Attaches a missing value decorator to the field behind every input column of the imputer and
 * returns the corresponding input features.
 *
 * <p>The replacement value of each column is read from the single row of the surrogate dataset,
 * using the input column name as the key. When the transformer's missing value marker is
 * non-null and not NaN, it is additionally registered on the decorator.</p>
 *
 * @param encoder the encoder that holds the features and fields of the input columns.
 * @return the features of the input columns, in input column order.
 * @throws IllegalArgumentException if the numbers of input and output columns differ, if the
 *         surrogate dataset does not consist of exactly one row, or if the field behind an
 *         input feature is not a {@link DataField}.
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    ImputerModel transformer = getTransformer();
    Double missingValue = transformer.getMissingValue();
    String strategy = transformer.getStrategy();
    Dataset<Row> surrogateDF = transformer.surrogateDF();
    String[] inputCols = transformer.getInputCols();
    String[] outputCols = transformer.getOutputCols();
    if (inputCols.length != outputCols.length) {
        throw new IllegalArgumentException("Expected " + inputCols.length + " output column(s), got " + outputCols.length);
    }
    MissingValueTreatmentMethod missingValueTreatmentMethod = parseStrategy(strategy);
    List<Row> surrogateRows = surrogateDF.collectAsList();
    if (surrogateRows.size() != 1) {
        throw new IllegalArgumentException("Expected exactly one surrogate row, got " + surrogateRows.size());
    }
    Row surrogateRow = surrogateRows.get(0);
    List<Feature> result = new ArrayList<>();
    for (String inputCol : inputCols) {
        Feature feature = encoder.getOnlyFeature(inputCol);
        Field<?> field = encoder.getField(feature.getName());
        if (!(field instanceof DataField)) {
            throw new IllegalArgumentException("The field of input column " + inputCol + " is not a data field");
        }
        Object surrogate = surrogateRow.getAs(inputCol);
        MissingValueDecorator missingValueDecorator = new MissingValueDecorator().setMissingValueReplacement(ValueUtil.formatValue(surrogate)).setMissingValueTreatment(missingValueTreatmentMethod);
        // A NaN marker is not registered as an explicit missing value.
        if (missingValue != null && !missingValue.isNaN()) {
            missingValueDecorator.addValues(ValueUtil.formatValue(missingValue));
        }
        encoder.addDecorator(feature.getName(), missingValueDecorator);
        result.add(feature);
    }
    return result;
}
Also used : ArrayList(java.util.ArrayList) MissingValueDecorator(org.jpmml.converter.MissingValueDecorator) Feature(org.jpmml.converter.Feature) DataField(org.dmg.pmml.DataField) ImputerModel(org.apache.spark.ml.feature.ImputerModel) Row(org.apache.spark.sql.Row) MissingValueTreatmentMethod(org.dmg.pmml.MissingValueTreatmentMethod)

Example 24 with DataField

use of org.dmg.pmml.DataField in project jpmml-sparkml by jpmml.

the class IndexToStringConverter method encodeFeatures.

/**
 * Creates a categorical string data field whose values are the labels of the transformer and
 * returns it as the only feature.
 *
 * @param encoder the encoder that creates and owns the new data field.
 * @return a single-element list holding the categorical feature of the new data field.
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    IndexToString indexToString = getTransformer();
    DataField labelField = encoder.createDataField(
        formatName(indexToString),
        OpType.CATEGORICAL,
        DataType.STRING,
        Arrays.asList(indexToString.getLabels()));
    Feature labelFeature = new CategoricalFeature(encoder, labelField);
    return Collections.singletonList(labelFeature);
}
Also used : DataField(org.dmg.pmml.DataField) IndexToString(org.apache.spark.ml.feature.IndexToString) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature)

Example 25 with DataField

use of org.dmg.pmml.DataField in project jpmml-sparkml by jpmml.

the class StringIndexerModelConverter method encodeFeatures.

/**
 * Converts the input feature of the string indexer into a categorical feature over the indexer's
 * labels.
 *
 * <p>When the categorical field is a {@link DataField}, an invalid value decorator is attached to
 * it: "keep" maps to {@code AS_IS}, "error" maps to {@code RETURN_INVALID}, and any other mode is
 * rejected. A {@link DerivedField} is left undecorated. In "keep" mode the result is additionally
 * routed through a derived field that yields the input value when it is one of the labels and
 * {@code LABEL_UNKNOWN} otherwise; {@code LABEL_UNKNOWN} is then appended to the categories.</p>
 *
 * @param encoder the encoder that holds the input feature and receives new fields and decorators.
 * @return a single-element list holding the categorical feature.
 * @throws IllegalArgumentException if the categorical field is neither a data field nor a derived
 *         field, or if the field is a data field and the mode is neither "keep" nor "error".
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StringIndexerModel indexerModel = getTransformer();
    Feature inputFeature = encoder.getOnlyFeature(indexerModel.getInputCol());
    // Mutable copy, because "keep" mode appends one more category further down.
    List<String> labels = new ArrayList<>(Arrays.asList(indexerModel.labels()));
    String handleInvalid = indexerModel.getHandleInvalid();
    Field<?> field = encoder.toCategorical(inputFeature.getName(), labels);

    boolean isDataField = (field instanceof DataField);
    if (!isDataField && !(field instanceof DerivedField)) {
        throw new IllegalArgumentException();
    }
    // Derived fields are deliberately left without a decorator.
    if (isDataField) {
        DataField dataField = (DataField) field;
        InvalidValueTreatmentMethod treatment;
        switch (handleInvalid) {
            case "keep":
                treatment = InvalidValueTreatmentMethod.AS_IS;
                break;
            case "error":
                treatment = InvalidValueTreatmentMethod.RETURN_INVALID;
                break;
            default:
                throw new IllegalArgumentException(handleInvalid);
        }
        InvalidValueDecorator decorator = new InvalidValueDecorator().setInvalidValueTreatment(treatment);
        encoder.addDecorator(dataField.getName(), decorator);
    }

    // Receiver-first equals on purpose: a null mode must keep failing with a NullPointerException.
    if (handleInvalid.equals("keep")) {
        Apply isKnownLabel = PMMLUtil.createApply("isIn", inputFeature.ref());
        for (String label : labels) {
            isKnownLabel.addExpressions(PMMLUtil.createConstant(label, inputFeature.getDataType()));
        }
        // Appended only after the membership test has been built from the original labels.
        labels.add(StringIndexerModelConverter.LABEL_UNKNOWN);
        Apply labelOrUnknown = PMMLUtil.createApply("if", isKnownLabel, inputFeature.ref(), PMMLUtil.createConstant(StringIndexerModelConverter.LABEL_UNKNOWN, DataType.STRING));
        field = encoder.createDerivedField(FeatureUtil.createName("handleInvalid", inputFeature), OpType.CATEGORICAL, inputFeature.getDataType(), labelOrUnknown);
    }

    Feature indexedFeature = new CategoricalFeature(encoder, field, labels);
    return Collections.singletonList(indexedFeature);
}
Also used : Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) StringIndexerModel(org.apache.spark.ml.feature.StringIndexerModel) CategoricalFeature(org.jpmml.converter.CategoricalFeature) InvalidValueTreatmentMethod(org.dmg.pmml.InvalidValueTreatmentMethod) InvalidValueDecorator(org.jpmml.converter.InvalidValueDecorator) DataField(org.dmg.pmml.DataField) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

DataField (org.dmg.pmml.DataField): 26, Feature (org.jpmml.converter.Feature): 13, FieldName (org.dmg.pmml.FieldName): 12, ArrayList (java.util.ArrayList): 9, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 8, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 5, DataType (org.dmg.pmml.DataType): 4, DerivedField (org.dmg.pmml.DerivedField): 4, OpType (org.dmg.pmml.OpType): 4, Apply (org.dmg.pmml.Apply): 3, CategoricalLabel (org.jpmml.converter.CategoricalLabel): 3, ContinuousLabel (org.jpmml.converter.ContinuousLabel): 3, Label (org.jpmml.converter.Label): 3, Function (com.google.common.base.Function): 2, MiningFunction (org.dmg.pmml.MiningFunction): 2, BooleanFeature (org.jpmml.converter.BooleanFeature): 2, InputField (org.jpmml.evaluator.InputField): 2, OutputField (org.jpmml.evaluator.OutputField): 2, TargetField (org.jpmml.evaluator.TargetField): 2, Field (org.openscoring.common.Field): 2