Search in sources:

Example 1 with Transformer

use of org.apache.spark.ml.Transformer in project jpmml-sparkml by jpmml.

The class RFormulaModelConverter, method registerFeatures.

/**
 * Registers the features produced by this RFormula model with the encoder.
 *
 * <p>If the formula's resolved label column differs from the model's configured label column,
 * the features already registered under the resolved label are also registered under the
 * configured label column name. Afterwards, every stage of the model's internal pipeline is
 * turned into a feature converter (via {@code ConverterUtil.createFeatureConverter}) and asked
 * to register its own features, in stage order.
 *
 * @param encoder the encoder that collects features by column name
 */
@Override
public void registerFeatures(SparkMLEncoder encoder) {
    RFormulaModel rFormulaModel = getTransformer();

    String resolvedLabel = rFormulaModel.resolvedFormula().label();
    String configuredLabel = rFormulaModel.getLabelCol();

    // Alias the label features under the configured column name when the two names differ.
    if (!resolvedLabel.equals(configuredLabel)) {
        List<Feature> labelFeatures = encoder.getFeatures(resolvedLabel);
        encoder.putFeatures(configuredLabel, labelFeatures);
    }

    for (Transformer pipelineStage : rFormulaModel.pipelineModel().stages()) {
        ConverterUtil.createFeatureConverter(pipelineStage).registerFeatures(encoder);
    }
}
Also used : Transformer(org.apache.spark.ml.Transformer) ResolvedRFormula(org.apache.spark.ml.feature.ResolvedRFormula) RFormulaModel(org.apache.spark.ml.feature.RFormulaModel) Feature(org.jpmml.converter.Feature) PipelineModel(org.apache.spark.ml.PipelineModel)

Example 2 with Transformer

use of org.apache.spark.ml.Transformer in project jpmml-sparkml by jpmml.

The class ConverterUtil, method getTransformers.

/**
 * Flattens a pipeline model into the ordered list of its leaf transformers.
 *
 * <p>A {@link PipelineModel} is replaced by its stages. A {@link CrossValidatorModel} or a
 * {@link TrainValidationSplitModel} is replaced by its best model. The expansion is applied
 * recursively, so nested pipelines and tuning models are fully unwrapped. The relative order of
 * the resulting transformers matches their position in the original pipeline. A composite with
 * no children (e.g. an empty pipeline) contributes nothing.
 *
 * @param pipelineModel the root pipeline model
 * @return the leaf transformers, in pipeline order
 */
private static Iterable<Transformer> getTransformers(PipelineModel pipelineModel) {
    List<Transformer> transformers = new ArrayList<>();
    expandTransformer(pipelineModel, transformers);
    return transformers;
}

/**
 * Appends {@code transformer} to {@code result}, or, if it is a composite (pipeline or tuning
 * model), appends the recursive expansion of its children in its place.
 *
 * <p>A single depth-first pass replaces the earlier fixed-point loop. That loop rescanned the
 * whole list once per nesting level and shifted the tail of the backing ArrayList on every
 * in-place remove/add.
 */
private static void expandTransformer(Transformer transformer, List<Transformer> result) {
    if (transformer instanceof PipelineModel) {
        PipelineModel pipelineModel = (PipelineModel) transformer;
        for (Transformer stage : pipelineModel.stages()) {
            expandTransformer(stage, result);
        }
    } else if (transformer instanceof CrossValidatorModel) {
        CrossValidatorModel crossValidatorModel = (CrossValidatorModel) transformer;
        expandTransformer(crossValidatorModel.bestModel(), result);
    } else if (transformer instanceof TrainValidationSplitModel) {
        TrainValidationSplitModel trainValidationSplitModel = (TrainValidationSplitModel) transformer;
        expandTransformer(trainValidationSplitModel.bestModel(), result);
    } else {
        // Leaf transformer: keep it as-is.
        result.add(transformer);
    }
}
Also used : Function(com.google.common.base.Function) CrossValidatorModel(org.apache.spark.ml.tuning.CrossValidatorModel) Transformer(org.apache.spark.ml.Transformer) TrainValidationSplitModel(org.apache.spark.ml.tuning.TrainValidationSplitModel) ArrayList(java.util.ArrayList) List(java.util.List) PipelineModel(org.apache.spark.ml.PipelineModel)

Example 3 with Transformer

use of org.apache.spark.ml.Transformer in project jpmml-sparkml by jpmml.

The class ConverterUtil, method toPMML.

/**
 * Converts a fitted Spark ML pipeline into a PMML document.
 *
 * <p>The pipeline is flattened into its leaf transformers and each one is converted.
 * Feature converters register their output features with a shared encoder. Model converters
 * produce PMML models.
 *
 * <p>A single model becomes the root model directly. Several models are chained into a
 * {@link MiningModel}. Its mining schema lists every {@code PREDICTED} or {@code TARGET}
 * mining field of the member models.
 *
 * @param schema the schema of the pipeline's input data
 * @param pipelineModel the fitted pipeline to convert
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a transformer's converter is neither a feature converter
 *     nor a model converter, or if the pipeline contains no models
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();

    SparkMLEncoder encoder = new SparkMLEncoder(schema);
    List<org.dmg.pmml.Model> models = new ArrayList<>();

    for (Transformer transformer : getTransformers(pipelineModel)) {
        TransformerConverter<?> converter = ConverterUtil.createConverter(transformer);

        if (converter instanceof FeatureConverter) {
            ((FeatureConverter<?>) converter).registerFeatures(encoder);
            continue;
        }

        if (converter instanceof ModelConverter) {
            org.dmg.pmml.Model model = ((ModelConverter<?>) converter).registerModel(encoder);
            models.add(model);
            continue;
        }

        throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
    }

    if (models.isEmpty()) {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }

    org.dmg.pmml.Model rootModel;
    if (models.size() == 1) {
        rootModel = models.get(0);
    } else {
        MiningSchema chainMiningSchema = new MiningSchema(collectTargetMiningFields(models));
        MiningModel modelChain = MiningModelUtil.createModelChain(models, new Schema(null, Collections.<Feature>emptyList())).setMiningSchema(chainMiningSchema);
        rootModel = modelChain;
    }

    return encoder.encodePMML(rootModel);
}

/**
 * Collects the mining fields whose usage type is {@code PREDICTED} or {@code TARGET} from all
 * given models. The result is ordered by model, then by position within each mining schema.
 */
private static List<MiningField> collectTargetMiningFields(List<org.dmg.pmml.Model> models) {
    List<MiningField> targetFields = new ArrayList<>();
    for (org.dmg.pmml.Model member : models) {
        for (MiningField field : member.getMiningSchema().getMiningFields()) {
            switch (field.getUsageType()) {
                case PREDICTED:
                case TARGET:
                    targetFields.add(field);
                    break;
                default:
                    break;
            }
        }
    }
    return targetFields;
}
Also used : MiningField(org.dmg.pmml.MiningField) Transformer(org.apache.spark.ml.Transformer) MiningSchema(org.dmg.pmml.MiningSchema) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) MiningModel(org.dmg.pmml.mining.MiningModel) PipelineModel(org.apache.spark.ml.PipelineModel) TrainValidationSplitModel(org.apache.spark.ml.tuning.TrainValidationSplitModel) CrossValidatorModel(org.apache.spark.ml.tuning.CrossValidatorModel) PMML(org.dmg.pmml.PMML) List(java.util.List)

Example 4 with Transformer

use of org.apache.spark.ml.Transformer in project jpmml-sparkml by jpmml.

The class FeatureConverter, method registerFeatures.

/**
 * Encodes this converter's features and registers them with the encoder under the
 * transformer's output column name(s).
 *
 * <p>The behavior depends on which interface the transformer implements:
 * <ul>
 *   <li>{@link HasOutputCol} (single output): all encoded features are registered under the
 *       output column.</li>
 *   <li>{@link HasOutputCols} (multiple outputs): encoded features are matched positionally with
 *       the output columns. A {@code BinarizedCategoricalFeature} is registered as its list of
 *       binary features; any other feature is registered on its own.</li>
 *   <li>Neither: nothing is encoded or registered.</li>
 * </ul>
 *
 * @param encoder the encoder that collects features by column name
 * @throws IllegalArgumentException if a multi-output transformer encodes a number of features
 *     different from its number of output columns
 */
public void registerFeatures(SparkMLEncoder encoder) {
    Transformer transformer = getTransformer();
    if (transformer instanceof HasOutputCol) {
        HasOutputCol hasOutputCol = (HasOutputCol) transformer;
        String outputCol = hasOutputCol.getOutputCol();
        List<Feature> features = encodeFeatures(encoder);
        encoder.putFeatures(outputCol, features);
    } else if (transformer instanceof HasOutputCols) {
        HasOutputCols hasOutputCols = (HasOutputCols) transformer;
        String[] outputCols = hasOutputCols.getOutputCols();
        List<Feature> features = encodeFeatures(encoder);
        // One encoded feature per output column is required for the positional pairing below.
        if (outputCols.length != features.size()) {
            throw new IllegalArgumentException("Expected " + outputCols.length + " features, got " + features.size() + " features");
        }
        for (int i = 0; i < outputCols.length; i++) {
            String outputCol = outputCols[i];
            Feature feature = features.get(i);
            if (feature instanceof BinarizedCategoricalFeature) {
                BinarizedCategoricalFeature binarizedCategoricalFeature = (BinarizedCategoricalFeature) feature;
                // The raw cast re-views the binary feature list as List<Feature>. The unchecked
                // conversion is confined to this one declaration instead of leaking a raw type
                // into the putFeatures call.
                // NOTE(review): this presumes getBinaryFeatures() holds Feature subtypes and that
                // the encoder only reads the list; confirm putFeatures never adds to it.
                @SuppressWarnings({"rawtypes", "unchecked"})
                List<Feature> binaryFeatures = (List) binarizedCategoricalFeature.getBinaryFeatures();
                encoder.putFeatures(outputCol, binaryFeatures);
            } else {
                encoder.putOnlyFeature(outputCol, feature);
            }
        }
    }
}
Also used : Transformer(org.apache.spark.ml.Transformer) List(java.util.List) HasOutputCols(org.apache.spark.ml.param.shared.HasOutputCols) Feature(org.jpmml.converter.Feature) HasOutputCol(org.apache.spark.ml.param.shared.HasOutputCol)

Aggregations

Transformer (org.apache.spark.ml.Transformer)4 List (java.util.List)3 PipelineModel (org.apache.spark.ml.PipelineModel)3 Feature (org.jpmml.converter.Feature)3 ArrayList (java.util.ArrayList)2 CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel)2 TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel)2 Function (com.google.common.base.Function)1 RFormulaModel (org.apache.spark.ml.feature.RFormulaModel)1 ResolvedRFormula (org.apache.spark.ml.feature.ResolvedRFormula)1 HasOutputCol (org.apache.spark.ml.param.shared.HasOutputCol)1 HasOutputCols (org.apache.spark.ml.param.shared.HasOutputCols)1 MiningField (org.dmg.pmml.MiningField)1 MiningSchema (org.dmg.pmml.MiningSchema)1 PMML (org.dmg.pmml.PMML)1 MiningModel (org.dmg.pmml.mining.MiningModel)1 Schema (org.jpmml.converter.Schema)1