Search in sources :

Example 61 with DataField

Use of org.dmg.pmml.DataField in the jpmml-r project by jpmml.

Class RandomForestConverter, method encodeFormula.

/**
 * Builds the R model formula for this randomForest object and registers its label and features
 * with the encoder.
 *
 * <p>The label is the formula field at the 1-based position stored in the "response" attribute of
 * {@code terms}. When {@code y} is an {@link RIntegerVector}, the label is re-encoded as
 * categorical using {@code RExpUtil.getFactorLevels(y)}. One feature is then added for every name
 * listed in {@code forest$xlevels}.
 *
 * @param encoder the encoder that receives the label and the features
 * @throws IllegalArgumentException if the "response" attribute equals 0
 */
private void encodeFormula(RExpEncoder encoder) {
    RGenericVector model = getObject();
    RGenericVector forest = (RGenericVector) model.getValue("forest");
    RNumberVector<?> y = (RNumberVector<?>) model.getValue("y", true);
    RExp terms = model.getValue("terms");
    final RNumberVector<?> ncat = (RNumberVector<?>) forest.getValue("ncat");
    final RGenericVector xlevels = (RGenericVector) forest.getValue("xlevels");
    RIntegerVector response = (RIntegerVector) terms.getAttributeValue("response");

    FormulaContext context = new FormulaContext() {

        @Override
        public List<String> getCategories(String variable) {
            // Only variables whose ncat entry exceeds 1 get categories (taken from xlevels)
            boolean categorical = ncat != null
                && ncat.hasValue(variable)
                && ncat.getValue(variable).doubleValue() > 1d;
            if (!categorical) {
                return null;
            }
            RStringVector levels = (RStringVector) xlevels.getValue(variable);
            return levels.getValues();
        }

        @Override
        public RGenericVector getData() {
            // No data frame is handed to the formula builder
            return null;
        }
    };

    Formula formula = FormulaUtil.createFormula(terms, context, encoder);

    // Dependent variable; a response index of 0 is rejected
    int responseIndex = response.asScalar();
    if (responseIndex == 0) {
        throw new IllegalArgumentException();
    }
    DataField labelField = (DataField) formula.getField(responseIndex - 1);
    if (y instanceof RIntegerVector) {
        labelField = (DataField) encoder.toCategorical(labelField.getName(), RExpUtil.getFactorLevels(y));
    }
    encoder.setLabel(labelField);

    // Independent variables, in xlevels name order
    RStringVector featureNames = xlevels.names();
    for (int index = 0; index < featureNames.size(); index++) {
        FieldName name = FieldName.create(featureNames.getValue(index));
        encoder.addFeature(formula.resolveFeature(name));
    }
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) DataField(org.dmg.pmml.DataField)

Example 62 with DataField

Use of org.dmg.pmml.DataField in the jpmml-r project by jpmml.

Class RExpEncoder, method addFeature.

/**
 * Wraps a PMML field in a converter {@link Feature} and registers that feature with this encoder.
 *
 * <p>Categorical fields become a {@link CategoricalFeature}; that constructor takes a
 * {@link DataField}, so categorical fields of any other kind are rejected explicitly instead of
 * failing on a cast. Continuous fields become a {@link ContinuousFeature}.
 *
 * @param field the field to wrap
 * @throws IllegalArgumentException if the field has no op type, has an op type other than
 *     CATEGORICAL or CONTINUOUS, or is CATEGORICAL but not a {@link DataField}
 */
public void addFeature(Field<?> field) {
    OpType opType = field.getOpType();
    // A null enum in a switch would throw a bare NullPointerException
    if (opType == null) {
        throw new IllegalArgumentException("Field " + field.getName() + " has no operational type");
    }
    Feature feature;
    switch(opType) {
        case CATEGORICAL:
            if (!(field instanceof DataField)) {
                throw new IllegalArgumentException("Categorical field " + field.getName() + " is not a DataField");
            }
            feature = new CategoricalFeature(this, (DataField) field);
            break;
        case CONTINUOUS:
            feature = new ContinuousFeature(this, field);
            break;
        default:
            throw new IllegalArgumentException("Unsupported operational type " + opType + " for field " + field.getName());
    }
    addFeature(feature);
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) OpType(org.dmg.pmml.OpType) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature)

Example 63 with DataField

Use of org.dmg.pmml.DataField in the jpmml-sparkml project by jpmml.

Class ImputerModelConverter, method encodeFeatures.

/**
 * Encodes a fitted Spark {@code ImputerModel} by attaching a {@link MissingValueDecorator} to each
 * input column's field.
 *
 * <p>Each input column must resolve to a {@link DataField}. Its surrogate value, read from the
 * single row of {@code surrogateDF}, becomes the missing-value replacement. The treatment method
 * comes from {@code parseStrategy(strategy)}. When the model's {@code missingValue} is non-null and
 * not NaN, it is also added as an explicit missing value.
 *
 * @param encoder the encoder holding the input features
 * @return the input features, in {@code inputCols} order
 * @throws IllegalArgumentException if the input and output column counts differ, the surrogate
 *     frame does not have exactly one row, or an input column is not backed by a DataField
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    ImputerModel transformer = getTransformer();
    Double missingValue = transformer.getMissingValue();
    String strategy = transformer.getStrategy();
    Dataset<Row> surrogateDF = transformer.surrogateDF();
    String[] inputCols = transformer.getInputCols();
    String[] outputCols = transformer.getOutputCols();
    if (inputCols.length != outputCols.length) {
        throw new IllegalArgumentException("Expected equal numbers of input and output columns, got " + inputCols.length + " and " + outputCols.length);
    }
    MissingValueTreatmentMethod missingValueTreatmentMethod = parseStrategy(strategy);
    List<Row> surrogateRows = surrogateDF.collectAsList();
    if (surrogateRows.size() != 1) {
        throw new IllegalArgumentException("Expected exactly one surrogate row, got " + surrogateRows.size());
    }
    Row surrogateRow = surrogateRows.get(0);
    // Loop-invariant: a null or NaN missingValue is never registered as an explicit missing value
    boolean hasExplicitMissingValue = (missingValue != null && !missingValue.isNaN());
    List<Feature> result = new ArrayList<>(inputCols.length);
    for (String inputCol : inputCols) {
        Feature feature = encoder.getOnlyFeature(inputCol);
        Field<?> field = encoder.getField(feature.getName());
        if (!(field instanceof DataField)) {
            throw new IllegalArgumentException("Input column " + inputCol + " is not backed by a DataField");
        }
        Object surrogate = surrogateRow.getAs(inputCol);
        MissingValueDecorator missingValueDecorator = new MissingValueDecorator()
            .setMissingValueReplacement(ValueUtil.formatValue(surrogate))
            .setMissingValueTreatment(missingValueTreatmentMethod);
        if (hasExplicitMissingValue) {
            missingValueDecorator.addValues(ValueUtil.formatValue(missingValue));
        }
        encoder.addDecorator(feature.getName(), missingValueDecorator);
        result.add(feature);
    }
    return result;
}
Also used : ArrayList(java.util.ArrayList) MissingValueDecorator(org.jpmml.converter.MissingValueDecorator) Feature(org.jpmml.converter.Feature) DataField(org.dmg.pmml.DataField) ImputerModel(org.apache.spark.ml.feature.ImputerModel) Row(org.apache.spark.sql.Row) MissingValueTreatmentMethod(org.dmg.pmml.MissingValueTreatmentMethod)

Example 64 with DataField

Use of org.dmg.pmml.DataField in the jpmml-sparkml project by jpmml.

Class IndexToStringConverter, method encodeFeatures.

/**
 * Encodes a Spark {@code IndexToString} transformer as a new categorical string field.
 *
 * <p>The field is named by {@code formatName(transformer)} and its valid values are the
 * transformer's labels.
 *
 * @param encoder the encoder in which the data field is created
 * @return a single-element list holding the categorical feature for the new field
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    IndexToString indexToString = getTransformer();
    DataField labelField = encoder.createDataField(
        formatName(indexToString),
        OpType.CATEGORICAL,
        DataType.STRING,
        Arrays.asList(indexToString.getLabels()));
    Feature labelFeature = new CategoricalFeature(encoder, labelField);
    return Collections.singletonList(labelFeature);
}
Also used : DataField(org.dmg.pmml.DataField) IndexToString(org.apache.spark.ml.feature.IndexToString) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature)

Example 65 with DataField

Use of org.dmg.pmml.DataField in the jpmml-sparkml project by jpmml.

Class StringIndexerModelConverter, method encodeFeatures.

/**
 * Encodes a fitted Spark {@code StringIndexerModel} as a categorical feature over the model's
 * labels.
 *
 * <p>The {@code handleInvalid} strategy is validated before anything else, whatever kind of field
 * the input resolves to. Only {@code "keep"} and {@code "error"} are supported.
 * <ul>
 *   <li>If the input becomes a {@link DataField}, an {@link InvalidValueDecorator} is attached:
 *       AS_IS for "keep", RETURN_INVALID for "error".</li>
 *   <li>For "keep", values outside the labels are mapped to {@code LABEL_UNKNOWN} through a
 *       derived "handleInvalid" field, and {@code LABEL_UNKNOWN} is appended to the categories.</li>
 * </ul>
 *
 * @param encoder the encoder holding the input feature
 * @return a single-element list holding the categorical feature
 * @throws IllegalArgumentException if {@code handleInvalid} is unsupported, or the categorical
 *     field is neither a DataField nor a DerivedField
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StringIndexerModel transformer = getTransformer();
    Feature feature = encoder.getOnlyFeature(transformer.getInputCol());
    List<String> categories = new ArrayList<>();
    categories.addAll(Arrays.asList(transformer.labels()));
    String handleInvalid = transformer.getHandleInvalid();
    // Validate first; previously a DerivedField input let unsupported strategies (e.g. "skip") pass silently
    InvalidValueTreatmentMethod invalidValueTreatmentMethod = parseHandleInvalid(handleInvalid);
    Field<?> field = encoder.toCategorical(feature.getName(), categories);
    if (field instanceof DataField) {
        DataField dataField = (DataField) field;
        InvalidValueDecorator invalidValueDecorator = new InvalidValueDecorator().setInvalidValueTreatment(invalidValueTreatmentMethod);
        encoder.addDecorator(dataField.getName(), invalidValueDecorator);
    } else if (field instanceof DerivedField) {
        // No decorator is attached to derived fields (presumably decorators target DataFields only)
    } else {
        throw new IllegalArgumentException("Unsupported field type for feature " + feature.getName());
    }
    if ("keep".equals(handleInvalid)) {
        // The isIn set is built from the original labels, before LABEL_UNKNOWN is appended
        Apply setApply = PMMLUtil.createApply("isIn", feature.ref());
        for (String category : categories) {
            setApply.addExpressions(PMMLUtil.createConstant(category, feature.getDataType()));
        }
        categories.add(StringIndexerModelConverter.LABEL_UNKNOWN);
        Apply apply = PMMLUtil.createApply("if", setApply, feature.ref(), PMMLUtil.createConstant(StringIndexerModelConverter.LABEL_UNKNOWN, DataType.STRING));
        field = encoder.createDerivedField(FeatureUtil.createName("handleInvalid", feature), OpType.CATEGORICAL, feature.getDataType(), apply);
    }
    return Collections.<Feature>singletonList(new CategoricalFeature(encoder, field, categories));
}

/**
 * Maps a Spark {@code handleInvalid} strategy to the PMML invalid value treatment.
 *
 * @param handleInvalid the strategy name
 * @return AS_IS for "keep", RETURN_INVALID for "error"
 * @throws IllegalArgumentException for any other strategy (e.g. "skip", which cannot be expressed here)
 */
private static InvalidValueTreatmentMethod parseHandleInvalid(String handleInvalid) {
    switch(handleInvalid) {
        case "keep":
            return InvalidValueTreatmentMethod.AS_IS;
        case "error":
            return InvalidValueTreatmentMethod.RETURN_INVALID;
        default:
            throw new IllegalArgumentException("Unsupported handleInvalid strategy: " + handleInvalid);
    }
}
Also used : Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) StringIndexerModel(org.apache.spark.ml.feature.StringIndexerModel) InvalidValueTreatmentMethod(org.dmg.pmml.InvalidValueTreatmentMethod) InvalidValueDecorator(org.jpmml.converter.InvalidValueDecorator) DataField(org.dmg.pmml.DataField) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

DataField (org.dmg.pmml.DataField)101 Test (org.junit.Test)51 DataDictionary (org.dmg.pmml.DataDictionary)42 MiningField (org.dmg.pmml.MiningField)42 MiningSchema (org.dmg.pmml.MiningSchema)30 PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField)28 RegressionModel (org.dmg.pmml.regression.RegressionModel)27 CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary)27 FieldName (org.dmg.pmml.FieldName)24 Model (org.dmg.pmml.Model)24 PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField)22 DataType (org.dmg.pmml.DataType)19 OutputField (org.dmg.pmml.OutputField)19 PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField)19 PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField)18 ArrayList (java.util.ArrayList)17 List (java.util.List)17 PMML (org.dmg.pmml.PMML)17 Collectors (java.util.stream.Collectors)16 OpType (org.dmg.pmml.OpType)15