Search in sources :

Example 11 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class ModelConverter method encodeSchema.

/**
 * Assembles the {@link Schema} (label plus features) for the transformer returned by
 * {@code getTransformer()}.
 *
 * <p>When the transformer exposes a label column, the label is derived from the single
 * feature registered for that column, according to the mining function. For classification,
 * a continuous label feature is converted into a categorical one whose categories are the
 * class indices {@code "0" .. "numClasses - 1"} (two classes are assumed when the transformer
 * is not a {@code ClassificationModel}), and the encoder's feature for the label column is
 * replaced accordingly. For regression, the label field is made continuous and its data type
 * is set to {@code DOUBLE}.
 *
 * @param encoder the encoder that holds the fields and features produced so far
 * @return a schema built from the resolved label (possibly {@code null}) and the features of
 *     the transformer's features column
 * @throws IllegalArgumentException if the label feature type or the mining function is not
 *     supported, or if the number of target categories or features disagrees with the model
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    final T transformer = getTransformer();

    Label label = null;
    if (transformer instanceof HasLabelCol) {
        final String labelColumn = ((HasLabelCol) transformer).getLabelCol();
        final Feature labelFeature = encoder.getOnlyFeature(labelColumn);
        final MiningFunction function = getMiningFunction();
        switch (function) {
            case CLASSIFICATION:
                if (labelFeature instanceof CategoricalFeature) {
                    // The label is already categorical: reuse its data field as-is.
                    CategoricalFeature categoricalLabelFeature = (CategoricalFeature) labelFeature;
                    DataField labelDataField = encoder.getDataField(categoricalLabelFeature.getName());
                    label = new CategoricalLabel(labelDataField);
                } else if (labelFeature instanceof ContinuousFeature) {
                    ContinuousFeature continuousLabelFeature = (ContinuousFeature) labelFeature;
                    int classCount = (transformer instanceof ClassificationModel)
                        ? ((ClassificationModel<?, ?>) transformer).numClasses()
                        : 2;
                    // Categories are the stringified class indices 0 .. classCount - 1.
                    List<String> classLabels = new ArrayList<>();
                    for (int classIndex = 0; classIndex < classCount; classIndex++) {
                        classLabels.add(Integer.toString(classIndex));
                    }
                    Field<?> categoricalField = encoder.toCategorical(continuousLabelFeature.getName(), classLabels);
                    encoder.putOnlyFeature(labelColumn, new CategoricalFeature(encoder, categoricalField, classLabels));
                    label = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), classLabels);
                } else {
                    throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                }
                break;
            case REGRESSION:
                Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                continuousField.setDataType(DataType.DOUBLE);
                label = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                break;
            default:
                throw new IllegalArgumentException("Mining function " + function + " is not supported");
        }
    }

    // Cross-check the number of label categories against the model's class count.
    if (transformer instanceof ClassificationModel) {
        ClassificationModel<?, ?> classifier = (ClassificationModel<?, ?>) transformer;
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        int expectedCategories = classifier.numClasses();
        int actualCategories = categoricalLabel.size();
        if (expectedCategories != actualCategories) {
            throw new IllegalArgumentException("Expected " + expectedCategories + " target categories, got " + actualCategories + " target categories");
        }
    }

    List<Feature> features = encoder.getFeatures(transformer.getFeaturesCol());

    // Cross-check the feature count; a model-reported count of -1 skips the check.
    if (transformer instanceof PredictionModel) {
        int expectedFeatures = ((PredictionModel<?, ?>) transformer).numFeatures();
        if (expectedFeatures != -1 && expectedFeatures != features.size()) {
            throw new IllegalArgumentException("Expected " + expectedFeatures + " features, got " + features.size() + " features");
        }
    }

    return new Schema(label, features);
}
Also used : Schema(org.jpmml.converter.Schema) ContinuousLabel(org.jpmml.converter.ContinuousLabel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) ArrayList(java.util.ArrayList) PredictionModel(org.apache.spark.ml.PredictionModel) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) HasLabelCol(org.apache.spark.ml.param.shared.HasLabelCol) OutputField(org.dmg.pmml.OutputField) Field(org.dmg.pmml.Field) DataField(org.dmg.pmml.DataField) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) CategoricalLabel(org.jpmml.converter.CategoricalLabel) MiningFunction(org.dmg.pmml.MiningFunction) ClassificationModel(org.apache.spark.ml.classification.ClassificationModel) ContinuousLabel(org.jpmml.converter.ContinuousLabel)

Example 12 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class SparkMLEncoder method getFeatures.

/**
 * Returns the features registered for a column, or a single default feature when none
 * have been registered.
 *
 * <p>For an unregistered column, the data field named after the column is looked up (and
 * created if missing), and one feature is chosen from its data type: a wildcard feature for
 * {@code STRING}, a continuous feature for {@code INTEGER} and {@code DOUBLE}, a boolean
 * feature for {@code BOOLEAN}. The fallback list is returned without being stored in
 * {@code columnFeatures}.
 *
 * @param column the column name
 * @return the registered features, or a singleton list holding the default feature
 * @throws IllegalArgumentException if the data field has any other data type
 */
public List<Feature> getFeatures(String column) {
    final List<Feature> registered = this.columnFeatures.get(column);
    if (registered != null) {
        return registered;
    }

    final FieldName fieldName = FieldName.create(column);
    DataField field = getDataField(fieldName);
    if (field == null) {
        field = createDataField(fieldName);
    }

    final DataType type = field.getDataType();
    switch (type) {
        case STRING:
            return Collections.<Feature>singletonList(new WildcardFeature(this, field));
        case INTEGER:
        case DOUBLE:
            return Collections.<Feature>singletonList(new ContinuousFeature(this, field));
        case BOOLEAN:
            return Collections.<Feature>singletonList(new BooleanFeature(this, field));
        default:
            throw new IllegalArgumentException("Data type " + type + " is not supported");
    }
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) DataField(org.dmg.pmml.DataField) DataType(org.dmg.pmml.DataType) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) WildcardFeature(org.jpmml.converter.WildcardFeature) FieldName(org.dmg.pmml.FieldName) BooleanFeature(org.jpmml.converter.BooleanFeature) WildcardFeature(org.jpmml.converter.WildcardFeature)

Example 13 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class SparkMLEncoder method getFeatures.

/**
 * Selects a subset of a column's features by position.
 *
 * @param column the column name, resolved through {@code getFeatures(String)}
 * @param indices positions into the column's feature list, in the desired output order
 * @return a new list holding the feature at each requested position
 */
public List<Feature> getFeatures(String column, int[] indices) {
    final List<Feature> columnFeatures = getFeatures(column);
    final List<Feature> selected = new ArrayList<>();
    for (int position : indices) {
        selected.add(columnFeatures.get(position));
    }
    return selected;
}
Also used : ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) WildcardFeature(org.jpmml.converter.WildcardFeature)

Example 14 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class TermFeature method createApply.

/**
 * Builds an {@code Apply} element that invokes this feature's defined function with two
 * arguments: a reference to the underlying feature and this feature's value as a
 * {@code STRING} constant.
 *
 * @return the function application
 */
public Apply createApply() {
    final DefineFunction function = getDefineFunction();
    final Feature source = getFeature();
    final Constant valueConstant = PMMLUtil.createConstant(getValue(), DataType.STRING);
    final String functionName = function.getName();
    return PMMLUtil.createApply(functionName, source.ref(), valueConstant);
}
Also used : Constant(org.dmg.pmml.Constant) DefineFunction(org.dmg.pmml.DefineFunction) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature)

Example 15 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class StandardScalerModelConverter method encodeFeatures.

/**
 * Encodes a {@code StandardScalerModel} as one derived continuous field per input feature.
 *
 * <p>Each input feature is converted to a continuous feature and then, depending on the
 * transformer's settings, centered by subtracting the mean (skipped when the mean is zero)
 * and scaled by multiplying with the reciprocal of the standard deviation (skipped when the
 * standard deviation is one).
 *
 * <p>A standard deviation of exactly zero is encoded as a scaling factor of {@code 0} rather
 * than {@code 1 / 0}: Spark's StandardScaler is documented to output {@code 0.0} for such
 * features, and an infinite constant would not reproduce that.
 *
 * @param encoder the encoder that supplies the input features and receives the derived fields
 * @return the scaled features, in the same order as the input features
 * @throws IllegalArgumentException if centering or scaling is enabled and the corresponding
 *     vector's size differs from the number of input features
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StandardScalerModel transformer = getTransformer();
    List<Feature> features = encoder.getFeatures(transformer.getInputCol());
    Vector mean = transformer.mean();
    if (transformer.getWithMean() && mean.size() != features.size()) {
        throw new IllegalArgumentException("Expected " + features.size() + " mean values, got " + mean.size() + " mean values");
    }
    Vector std = transformer.std();
    if (transformer.getWithStd() && std.size() != features.size()) {
        throw new IllegalArgumentException("Expected " + features.size() + " standard deviation values, got " + std.size() + " standard deviation values");
    }
    List<Feature> result = new ArrayList<>();
    for (int i = 0; i < features.size(); i++) {
        Feature feature = features.get(i);
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        Expression expression = continuousFeature.ref();
        if (transformer.getWithMean()) {
            double meanValue = mean.apply(i);
            // Subtracting zero is a no-op; keep the expression minimal.
            if (!ValueUtil.isZero(meanValue)) {
                expression = PMMLUtil.createApply("-", expression, PMMLUtil.createConstant(meanValue));
            }
        }
        if (transformer.getWithStd()) {
            double stdValue = std.apply(i);
            // Multiplying by one is a no-op; keep the expression minimal.
            if (!ValueUtil.isOne(stdValue)) {
                // Zero-variance feature: scale by 0 instead of dividing by zero.
                double factor = (stdValue != 0d) ? (1d / stdValue) : 0d;
                expression = PMMLUtil.createApply("*", expression, PMMLUtil.createConstant(factor));
            }
        }
        DerivedField derivedField = encoder.createDerivedField(formatName(transformer, i), OpType.CONTINUOUS, DataType.DOUBLE, expression);
        result.add(new ContinuousFeature(encoder, derivedField));
    }
    return result;
}
Also used : StandardScalerModel(org.apache.spark.ml.feature.StandardScalerModel) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Expression(org.dmg.pmml.Expression) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Vector(org.apache.spark.ml.linalg.Vector) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

Feature (org.jpmml.converter.Feature): 53, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 30, ArrayList (java.util.ArrayList): 27, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 19, DerivedField (org.dmg.pmml.DerivedField): 14, DataField (org.dmg.pmml.DataField): 13, FieldName (org.dmg.pmml.FieldName): 10, Apply (org.dmg.pmml.Apply): 9, BooleanFeature (org.jpmml.converter.BooleanFeature): 9, BinaryFeature (org.jpmml.converter.BinaryFeature): 7, List (java.util.List): 6, Expression (org.dmg.pmml.Expression): 6, SimplePredicate (org.dmg.pmml.SimplePredicate): 6, Vector (org.apache.spark.ml.linalg.Vector): 5, Predicate (org.dmg.pmml.Predicate): 5, Node (org.dmg.pmml.tree.Node): 5, DocumentFeature (org.jpmml.sparkml.DocumentFeature): 5, InteractionFeature (org.jpmml.converter.InteractionFeature): 4, DocumentBuilder (javax.xml.parsers.DocumentBuilder): 3, Transformer (org.apache.spark.ml.Transformer): 3