Search in sources :

Example 6 with DataField

Use of org.dmg.pmml.DataField in project jpmml-r by jpmml.

From the class RangerConverter, method encodeSchema.

/**
 * Declares the label and the features of a ranger model on the given encoder.
 *
 * <p>
 * The label is always named {@code _target}. For the {@code "Regression"} tree type it is a
 * continuous double field; for {@code "Classification"} and {@code "Probability estimation"}
 * it is a categorical field whose categories are read from the {@code levels} element of the forest.
 * Every entry of {@code independent.variable.names} becomes a feature: a categorical string field
 * when {@code variable.levels} holds an entry under that name, a continuous double field otherwise.
 * </p>
 *
 * @param encoder the encoder that creates the data fields and receives the label and the features
 * @throws IllegalArgumentException if the {@code forest} or {@code variable.levels} element cannot be
 *         looked up, if the tree type is not one of the three handled values, or if an independent
 *         variable is flagged as not ordered
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector ranger = getObject();
    RGenericVector forest;
    try {
        forest = (RGenericVector) ranger.getValue("forest");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No forest information. Please initialize the \'forest\' element", iae);
    }
    RGenericVector variableLevels;
    try {
        variableLevels = (RGenericVector) ranger.getValue("variable.levels");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No variable levels information. Please initialize the \'variable.levels\' element", iae);
    }
    RStringVector treeType = (RStringVector) ranger.getValue("treetype");
    // Dependent variable
    {
        FieldName name = FieldName.create("_target");
        DataField dataField;
        String treeTypeValue = treeType.asScalar();
        switch(treeTypeValue) {
            case "Regression":
                {
                    dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
                }
                break;
            case "Classification":
            case "Probability estimation":
                {
                    RStringVector levels = (RStringVector) forest.getValue("levels");
                    // The data type is passed as null here, unlike the categorical features below
                    dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, levels.getValues());
                }
                break;
            default:
                // Report the offending value, in line with the descriptive messages used above
                throw new IllegalArgumentException("Tree type \'" + treeTypeValue + "\' is not supported");
        }
        encoder.setLabel(dataField);
    }
    RBooleanVector isOrdered = (RBooleanVector) forest.getValue("is.ordered");
    RStringVector independentVariableNames = (RStringVector) forest.getValue("independent.variable.names");
    // Independent variables
    for (int i = 0; i < independentVariableNames.size(); i++) {
        // NOTE(review): the flag is read at offset i + 1 - presumably 'is.ordered' carries a leading
        // entry for the dependent variable; confirm against the ranger object layout
        if (!isOrdered.getValue(i + 1)) {
            throw new IllegalArgumentException("Independent variable \'" + independentVariableNames.getValue(i) + "\' is not ordered");
        }
        String independentVariableName = independentVariableNames.getValue(i);
        FieldName name = FieldName.create(independentVariableName);
        DataField dataField;
        if (variableLevels.hasValue(independentVariableName)) {
            RStringVector levels = (RStringVector) variableLevels.getValue(independentVariableName);
            dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING, levels.getValues());
        } else {
            dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
        }
        encoder.addFeature(dataField);
    }
}
Also used : DataField(org.dmg.pmml.DataField) FieldName(org.dmg.pmml.FieldName)

Example 7 with DataField

Use of org.dmg.pmml.DataField in project jpmml-r by jpmml.

From the class BinaryTreeConverter, method encodeResponse.

/**
 * Declares the response variable as the label on the given encoder and records the mining function.
 *
 * <p>
 * The variable name is the single name of the {@code variables} attribute. When its
 * {@code is_nominal} flag is TRUE, the mining function is set to classification and a categorical
 * field is created, with the data type derived from the variable's {@code class} attribute and the
 * categories taken from the {@code levels} attribute. When the flag is FALSE, the mining function is
 * set to regression and a continuous double field is created.
 * </p>
 *
 * @param responses the object carrying the {@code variables}, {@code is_nominal} and {@code levels} attributes
 * @param encoder the encoder that creates the data field and receives the label
 * @throws IllegalArgumentException if the {@code is_nominal} flag is neither TRUE nor FALSE
 */
private void encodeResponse(S4Object responses, RExpEncoder encoder) {
    RGenericVector variables = (RGenericVector) responses.getAttributeValue("variables");
    RBooleanVector is_nominal = (RBooleanVector) responses.getAttributeValue("is_nominal");
    RGenericVector levels = (RGenericVector) responses.getAttributeValue("levels");
    RStringVector variableNames = variables.names();
    String variableName = variableNames.asScalar();
    DataField dataField;
    // Boxed on purpose: the flag is compared against both TRUE and FALSE, so a null value
    // reaches the final branch instead of failing on unboxing
    Boolean categorical = is_nominal.getValue(variableName);
    if ((Boolean.TRUE).equals(categorical)) {
        this.miningFunction = MiningFunction.CLASSIFICATION;
        RExp targetVariable = variables.getValue(variableName);
        RStringVector targetVariableClass = (RStringVector) targetVariable.getAttributeValue("class");
        RStringVector targetCategories = (RStringVector) levels.getValue(variableName);
        dataField = encoder.createDataField(FieldName.create(variableName), OpType.CATEGORICAL, RExpUtil.getDataType(targetVariableClass.asScalar()), targetCategories.getValues());
    } else if ((Boolean.FALSE).equals(categorical)) {
        this.miningFunction = MiningFunction.REGRESSION;
        dataField = encoder.createDataField(FieldName.create(variableName), OpType.CONTINUOUS, DataType.DOUBLE);
    } else {
        // Name the variable and the value seen so that the failure can be diagnosed
        throw new IllegalArgumentException("Expected the \'is_nominal\' flag of response variable \'" + variableName + "\' to be TRUE or FALSE, got " + categorical);
    }
    encoder.setLabel(dataField);
}
Also used : DataField(org.dmg.pmml.DataField)

Example 8 with DataField

Use of org.dmg.pmml.DataField in project jpmml-r by jpmml.

From the class ElmNNConverter, method encodeSchema.

/**
 * Declares the label and the features of an elmNN model on the given encoder.
 *
 * <p>
 * A formula is built from the {@code terms} attribute of the {@code model} element. A non-zero
 * {@code response} attribute selects the formula field (at the one-based position it names) that
 * becomes the label. Every entry of the {@code columns} attribute is resolved to a feature, except
 * a leading {@code "(Intercept)"} entry, which is skipped.
 * </p>
 *
 * @param encoder the encoder that receives the label and the features
 * @throws IllegalArgumentException if the {@code model} element cannot be looked up
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector elmNN = getObject();

    final RGenericVector modelFrame;
    try {
        modelFrame = (RGenericVector) elmNN.getValue("model");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No model frame information. Please initialize the \'model\' element", iae);
    }

    RExp terms = modelFrame.getAttributeValue("terms");
    RIntegerVector responseAttribute = (RIntegerVector) terms.getAttributeValue("response");
    RStringVector columnNames = (RStringVector) terms.getAttributeValue("columns");

    FormulaContext formulaContext = new ModelFrameFormulaContext(modelFrame);
    Formula formula = FormulaUtil.createFormula(terms, formulaContext, encoder);

    // Label: a response index of zero means that there is nothing to declare
    int responseIndex = responseAttribute.asScalar();
    boolean hasResponse = (responseIndex != 0);
    if (hasResponse) {
        // The index is one-based, the formula fields are zero-based
        encoder.setLabel((DataField) formula.getField(responseIndex - 1));
    }

    // Features: everything except a leading intercept column
    for (int index = 0; index < columnNames.size(); index++) {
        String columnName = columnNames.getValue(index);

        boolean leadingIntercept = (index == 0) && "(Intercept)".equals(columnName);
        if (!leadingIntercept) {
            encoder.addFeature(formula.resolveFeature(columnName));
        }
    }
}
Also used : DataField(org.dmg.pmml.DataField) Feature(org.jpmml.converter.Feature)

Example 9 with DataField

Use of org.dmg.pmml.DataField in project jpmml-r by jpmml.

From the class RandomForestConverter, method encodeNonFormula.

/**
 * Declares the label and the features of a randomForest model that was fitted without a formula.
 *
 * <p>
 * The label is always named {@code _target}: a categorical field (categories from the factor levels
 * of {@code y}) when {@code y} is an integer vector, a continuous double field otherwise.
 * One feature is declared per entry of {@code forest$ncat}: categorical (categories from the
 * {@code xlevels} entry at the same position) when the {@code ncat} value is greater than one,
 * continuous double otherwise. Feature names come from {@code xNames}, falling back to the names of
 * {@code xlevels} when {@code xNames} is null.
 * </p>
 *
 * @param encoder the encoder that creates the data fields and receives the label and the features
 */
private void encodeNonFormula(RExpEncoder encoder) {
    RGenericVector randomForest = getObject();
    RGenericVector forest = (RGenericVector) randomForest.getValue("forest");

    // NOTE(review): the second argument presumably marks these elements as optional - confirm
    RNumberVector<?> y = (RNumberVector<?>) randomForest.getValue("y", true);
    RStringVector featureNames = (RStringVector) randomForest.getValue("xNames", true);

    RNumberVector<?> ncat = (RNumberVector<?>) forest.getValue("ncat");
    RGenericVector xlevels = (RGenericVector) forest.getValue("xlevels");

    if (featureNames == null) {
        featureNames = xlevels.names();
    }

    // Dependent variable
    FieldName targetName = FieldName.create("_target");
    DataField targetField;
    if (!(y instanceof RIntegerVector)) {
        targetField = encoder.createDataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE);
    } else {
        targetField = encoder.createDataField(targetName, OpType.CATEGORICAL, null, RExpUtil.getFactorLevels(y));
    }
    encoder.setLabel(targetField);

    // Independent variables
    for (int index = 0; index < ncat.size(); index++) {
        FieldName featureName = FieldName.create(featureNames.getValue(index));

        // Written as a negated "greater than" so that a NaN value still counts as continuous
        boolean continuous = !((ncat.getValue(index)).doubleValue() > 1d);

        DataField featureField;
        if (continuous) {
            featureField = encoder.createDataField(featureName, OpType.CONTINUOUS, DataType.DOUBLE);
        } else {
            RStringVector categories = (RStringVector) xlevels.getValue(index);
            featureField = encoder.createDataField(featureName, OpType.CATEGORICAL, null, categories.getValues());
        }
        encoder.addFeature(featureField);
    }
}
Also used : DataField(org.dmg.pmml.DataField) FieldName(org.dmg.pmml.FieldName)

Example 10 with DataField

Use of org.dmg.pmml.DataField in project jpmml-sparkml by jpmml.

From the class ModelConverter, method encodeSchema.

/**
 * Builds the schema (label plus features) of the wrapped Spark ML model.
 *
 * <p>
 * A label is resolved only when the model implements {@code HasLabelCol}. For classification, a
 * categorical label feature is used as-is, while a continuous label feature is converted to a
 * categorical field whose categories are the strings {@code "0"} to {@code numClasses - 1}
 * ({@code numClasses} defaults to 2 unless the model is a {@code ClassificationModel}). For
 * regression, the label field is converted to a continuous double field.
 * The category and feature counts are then cross-checked against what the model reports.
 * </p>
 *
 * @param encoder the encoder that holds the features registered for the label and features columns
 * @return a schema made of the resolved label (possibly null) and the features of the features column
 * @throws IllegalArgumentException if the label feature or the mining function is not supported, or
 *         if the number of target categories or features disagrees with the model
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T model = getTransformer();

    Label label = null;

    if (model instanceof HasLabelCol) {
        String labelCol = ((HasLabelCol) model).getLabelCol();
        Feature labelFeature = encoder.getOnlyFeature(labelCol);
        MiningFunction miningFunction = getMiningFunction();

        switch (miningFunction) {
            case CLASSIFICATION:
                if (labelFeature instanceof CategoricalFeature) {
                    CategoricalFeature categoricalLabelFeature = (CategoricalFeature) labelFeature;

                    label = new CategoricalLabel(encoder.getDataField(categoricalLabelFeature.getName()));
                } else if (labelFeature instanceof ContinuousFeature) {
                    ContinuousFeature continuousLabelFeature = (ContinuousFeature) labelFeature;

                    // Binary unless the model itself reports the number of classes
                    int classCount = (model instanceof ClassificationModel)
                        ? ((ClassificationModel<?, ?>) model).numClasses()
                        : 2;

                    List<String> categories = new ArrayList<>();
                    for (int classIndex = 0; classIndex < classCount; classIndex++) {
                        categories.add(String.valueOf(classIndex));
                    }

                    Field<?> categoricalField = encoder.toCategorical(continuousLabelFeature.getName(), categories);

                    // Replace the registered label feature with its categorical counterpart
                    encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, categoricalField, categories));

                    label = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), categories);
                } else {
                    throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                }
                break;
            case REGRESSION:
                Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                continuousField.setDataType(DataType.DOUBLE);

                label = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                break;
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    if (model instanceof ClassificationModel) {
        ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>) model;

        // NOTE(review): this cast and the size() call below assume that a categorical label was
        // resolved above - confirm that every ClassificationModel also implements HasLabelCol
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;

        int expectedCategories = classificationModel.numClasses();
        int actualCategories = categoricalLabel.size();
        if (expectedCategories != actualCategories) {
            throw new IllegalArgumentException("Expected " + expectedCategories + " target categories, got " + actualCategories + " target categories");
        }
    }

    List<Feature> features = encoder.getFeatures(model.getFeaturesCol());

    if (model instanceof PredictionModel) {
        int expectedFeatures = ((PredictionModel<?, ?>) model).numFeatures();

        // A reported count of -1 is exempt from the check
        if (expectedFeatures != -1 && features.size() != expectedFeatures) {
            throw new IllegalArgumentException("Expected " + expectedFeatures + " features, got " + features.size() + " features");
        }
    }

    return new Schema(label, features);
}
Also used : Schema(org.jpmml.converter.Schema) ContinuousLabel(org.jpmml.converter.ContinuousLabel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) ArrayList(java.util.ArrayList) PredictionModel(org.apache.spark.ml.PredictionModel) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) HasLabelCol(org.apache.spark.ml.param.shared.HasLabelCol) OutputField(org.dmg.pmml.OutputField) Field(org.dmg.pmml.Field) DataField(org.dmg.pmml.DataField) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) CategoricalLabel(org.jpmml.converter.CategoricalLabel) MiningFunction(org.dmg.pmml.MiningFunction) ClassificationModel(org.apache.spark.ml.classification.ClassificationModel) ContinuousLabel(org.jpmml.converter.ContinuousLabel)

Aggregations

DataField (org.dmg.pmml.DataField): 26; Feature (org.jpmml.converter.Feature): 13; FieldName (org.dmg.pmml.FieldName): 12; ArrayList (java.util.ArrayList): 9; ContinuousFeature (org.jpmml.converter.ContinuousFeature): 8; CategoricalFeature (org.jpmml.converter.CategoricalFeature): 5; DataType (org.dmg.pmml.DataType): 4; DerivedField (org.dmg.pmml.DerivedField): 4; OpType (org.dmg.pmml.OpType): 4; Apply (org.dmg.pmml.Apply): 3; CategoricalLabel (org.jpmml.converter.CategoricalLabel): 3; ContinuousLabel (org.jpmml.converter.ContinuousLabel): 3; Label (org.jpmml.converter.Label): 3; Function (com.google.common.base.Function): 2; MiningFunction (org.dmg.pmml.MiningFunction): 2; BooleanFeature (org.jpmml.converter.BooleanFeature): 2; InputField (org.jpmml.evaluator.InputField): 2; OutputField (org.jpmml.evaluator.OutputField): 2; TargetField (org.jpmml.evaluator.TargetField): 2; Field (org.openscoring.common.Field): 2