
Example 6 with PipelineModel

Use of org.apache.spark.ml.PipelineModel in the jpmml-sparkml project by jpmml.

From the class ConverterUtil, method getTransformers:

private static Iterable<Transformer> getTransformers(PipelineModel pipelineModel) {
    List<Transformer> transformers = new ArrayList<>();
    transformers.add(pipelineModel);
    // Maps a composite model to its member transformers; returns null for a leaf transformer
    Function<Transformer, List<Transformer>> function = new Function<Transformer, List<Transformer>>() {

        @Override
        public List<Transformer> apply(Transformer transformer) {
            if (transformer instanceof PipelineModel) {
                PipelineModel pipelineModel = (PipelineModel) transformer;
                return Arrays.asList(pipelineModel.stages());
            } else if (transformer instanceof CrossValidatorModel) {
                CrossValidatorModel crossValidatorModel = (CrossValidatorModel) transformer;
                return Collections.<Transformer>singletonList(crossValidatorModel.bestModel());
            } else if (transformer instanceof TrainValidationSplitModel) {
                TrainValidationSplitModel trainValidationSplitModel = (TrainValidationSplitModel) transformer;
                return Collections.<Transformer>singletonList(trainValidationSplitModel.bestModel());
            }
            return null;
        }
    };
    // Repeatedly expand composite transformers in place until only leaf transformers remain
    while (true) {
        ListIterator<Transformer> transformerIt = transformers.listIterator();
        boolean modified = false;
        while (transformerIt.hasNext()) {
            Transformer transformer = transformerIt.next();
            List<Transformer> childTransformers = function.apply(transformer);
            if (childTransformers != null) {
                transformerIt.remove();
                for (Transformer childTransformer : childTransformers) {
                    transformerIt.add(childTransformer);
                }
                modified = true;
            }
        }
        if (!modified) {
            break;
        }
    }
    return transformers;
}
Also used: Function (com.google.common.base.Function), CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel), Transformer (org.apache.spark.ml.Transformer), TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel), ArrayList (java.util.ArrayList), List (java.util.List), PipelineModel (org.apache.spark.ml.PipelineModel)
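The kind of nesting that getTransformers(..) unrolls arises whenever a fitted Pipeline contains another Pipeline, a CrossValidator, or a TrainValidationSplit among its stages. The sketch below shows one way such a nested PipelineModel could be produced with the standard Spark ML API; the column names, the StringIndexer/VectorAssembler/LogisticRegression stages and the NestedPipelineExample class are illustrative assumptions, not part of the jpmml-sparkml sources.

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class NestedPipelineExample {

    public static PipelineModel fitNestedPipeline(Dataset<Row> trainingData) {
        // Inner pipeline: feature engineering stages (column names are hypothetical)
        Pipeline featurePipeline = new Pipeline().setStages(new PipelineStage[] {
            new StringIndexer().setInputCol("category").setOutputCol("categoryIndex"),
            new VectorAssembler().setInputCols(new String[] { "categoryIndex", "amount" }).setOutputCol("features")
        });
        // Outer pipeline: nests the feature pipeline in front of the estimator.
        // Fitting it yields a PipelineModel whose first stage is itself a PipelineModel,
        // which is exactly the kind of composite stage that getTransformers(..) flattens.
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {
            featurePipeline,
            new LogisticRegression().setLabelCol("label").setFeaturesCol("features")
        });
        return pipeline.fit(trainingData);
    }
}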

Example 7 with PipelineModel

Use of org.apache.spark.ml.PipelineModel in the jpmml-sparkml project by jpmml.

From the class ConverterUtil, method toPMML:

public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();
    SparkMLEncoder encoder = new SparkMLEncoder(schema);
    List<org.dmg.pmml.Model> models = new ArrayList<>();
    // Flatten the pipeline into its leaf transformers, and dispatch each one to a feature- or model converter
    Iterable<Transformer> transformers = getTransformers(pipelineModel);
    for (Transformer transformer : transformers) {
        TransformerConverter<?> converter = ConverterUtil.createConverter(transformer);
        if (converter instanceof FeatureConverter) {
            FeatureConverter<?> featureConverter = (FeatureConverter<?>) converter;
            featureConverter.registerFeatures(encoder);
        } else if (converter instanceof ModelConverter) {
            ModelConverter<?> modelConverter = (ModelConverter<?>) converter;
            org.dmg.pmml.Model model = modelConverter.registerModel(encoder);
            models.add(model);
        } else {
            throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
        }
    }
    org.dmg.pmml.Model rootModel;
    if (models.size() == 1) {
        rootModel = Iterables.getOnlyElement(models);
    } else if (models.size() > 1) {
        // Collect the target fields of all member models, then chain the models into a single MiningModel
        List<MiningField> targetMiningFields = new ArrayList<>();
        for (org.dmg.pmml.Model model : models) {
            MiningSchema miningSchema = model.getMiningSchema();
            List<MiningField> miningFields = miningSchema.getMiningFields();
            for (MiningField miningField : miningFields) {
                MiningField.UsageType usageType = miningField.getUsageType();
                switch(usageType) {
                    case PREDICTED:
                    case TARGET:
                        targetMiningFields.add(miningField);
                        break;
                    default:
                        break;
                }
            }
        }
        MiningSchema miningSchema = new MiningSchema(targetMiningFields);
        MiningModel miningModel = MiningModelUtil.createModelChain(models, new Schema(null, Collections.<Feature>emptyList())).setMiningSchema(miningSchema);
        rootModel = miningModel;
    } else {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }
    PMML pmml = encoder.encodePMML(rootModel);
    return pmml;
}
Also used: MiningField (org.dmg.pmml.MiningField), Transformer (org.apache.spark.ml.Transformer), MiningSchema (org.dmg.pmml.MiningSchema), Schema (org.jpmml.converter.Schema), ArrayList (java.util.ArrayList), Feature (org.jpmml.converter.Feature), MiningModel (org.dmg.pmml.mining.MiningModel), PipelineModel (org.apache.spark.ml.PipelineModel), TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel), CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel), PMML (org.dmg.pmml.PMML), List (java.util.List)
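A hedged usage sketch for toPMML(StructType, PipelineModel): derive the schema from the training DataFrame, convert the fitted pipeline, and marshal the resulting PMML document to a file. The PmmlExportExample class, the output file name and the training DataFrame are assumptions, as is the org.jpmml.sparkml package for ConverterUtil; JAXBUtil comes from the companion jpmml-model library.

import java.io.FileOutputStream;
import java.io.OutputStream;

import javax.xml.transform.stream.StreamResult;

import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.PMML;
import org.jpmml.model.JAXBUtil;
import org.jpmml.sparkml.ConverterUtil;

public class PmmlExportExample {

    public static void export(Dataset<Row> trainingData, PipelineModel pipelineModel) throws Exception {
        // The schema of the training DataFrame drives the field declarations in the PMML data dictionary
        StructType schema = trainingData.schema();
        PMML pmml = ConverterUtil.toPMML(schema, pipelineModel);
        // Marshal the PMML class model tree to an XML file (JAXBUtil is part of jpmml-model)
        try (OutputStream os = new FileOutputStream("pipeline.pmml")) {
            JAXBUtil.marshalPMML(pmml, new StreamResult(os));
        }
    }
}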

Example 8 with PipelineModel

Use of org.apache.spark.ml.PipelineModel in the jpmml-sparkml project by jpmml.

From the class ConverterTest, method createBatch:

@Override
protected ArchiveBatch createBatch(String name, String dataset, Predicate<FieldName> predicate) {
    Predicate<FieldName> excludePredictionFields = excludeFields(FieldName.create("prediction"), FieldName.create("pmml(prediction)"));
    if (predicate == null) {
        predicate = excludePredictionFields;
    } else {
        predicate = Predicates.and(predicate, excludePredictionFields);
    }
    ArchiveBatch result = new IntegrationTestBatch(name, dataset, predicate) {

        @Override
        public IntegrationTest getIntegrationTest() {
            return ConverterTest.this;
        }

        @Override
        public PMML getPMML() throws Exception {
            // The DataFrame schema is stored as JSON alongside the serialized pipeline
            StructType schema;
            try (InputStream is = open("/schema/" + getDataset() + ".json")) {
                String json = CharStreams.toString(new InputStreamReader(is, "UTF-8"));
                schema = (StructType) DataType.fromJson(json);
            }
            // The fitted PipelineModel is stored as a ZIP archive of Spark's native model directory
            PipelineModel pipelineModel;
            try (InputStream is = open("/pipeline/" + getName() + getDataset() + ".zip")) {
                File tmpZipFile = File.createTempFile(getName() + getDataset(), ".zip");
                try (OutputStream os = new FileOutputStream(tmpZipFile)) {
                    ByteStreams.copy(is, os);
                }
                File tmpDir = File.createTempFile(getName() + getDataset(), "");
                if (!tmpDir.delete()) {
                    throw new IOException();
                }
                ZipUtil.uncompress(tmpZipFile, tmpDir);
                MLReader<PipelineModel> mlReader = new PipelineModel.PipelineModelReader();
                mlReader.session(ConverterTest.sparkSession);
                pipelineModel = mlReader.load(tmpDir.getAbsolutePath());
            }
            PMML pmml = ConverterUtil.toPMML(schema, pipelineModel);
            ensureValidity(pmml);
            return pmml;
        }
    };
    return result;
}
Also used: IntegrationTestBatch (org.jpmml.evaluator.IntegrationTestBatch), StructType (org.apache.spark.sql.types.StructType), InputStreamReader (java.io.InputStreamReader), InputStream (java.io.InputStream), OutputStream (java.io.OutputStream), FileOutputStream (java.io.FileOutputStream), IOException (java.io.IOException), PipelineModel (org.apache.spark.ml.PipelineModel), ArchiveBatch (org.jpmml.evaluator.ArchiveBatch), PMML (org.dmg.pmml.PMML), FieldName (org.dmg.pmml.FieldName), File (java.io.File)
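For context, the archive that getPMML() opens from "/pipeline/" + getName() + getDataset() + ".zip" can be produced by saving a fitted PipelineModel with Spark's MLWriter and zipping the resulting directory. The sketch below is an assumption about how such an archive might be built; it uses plain java.util.zip instead of the project's ZipUtil helper, and the PipelineArchiver class and paths are hypothetical.

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;

import org.apache.spark.ml.PipelineModel;

public class PipelineArchiver {

    public static void archive(PipelineModel pipelineModel, File zipFile) throws IOException {
        // Spark serializes the fitted model as a directory of metadata and Parquet files
        File tmpDir = Files.createTempDirectory("pipeline").toFile();
        pipelineModel.write().overwrite().save(tmpDir.getAbsolutePath());
        // Zip the directory so that the test harness can unpack it again before loading
        Path root = tmpDir.toPath();
        try (ZipOutputStream zos = new ZipOutputStream(new FileOutputStream(zipFile));
             Stream<Path> paths = Files.walk(root)) {
            List<Path> files = paths.filter(Files::isRegularFile).collect(Collectors.toList());
            for (Path file : files) {
                // Store entries with forward slashes, relative to the model directory root
                zos.putNextEntry(new ZipEntry(root.relativize(file).toString().replace(File.separatorChar, '/')));
                Files.copy(file, zos);
                zos.closeEntry();
            }
        }
    }
}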

Aggregations

PipelineModel (org.apache.spark.ml.PipelineModel): 8 usages
Pipeline (org.apache.spark.ml.Pipeline): 3 usages
Transformer (org.apache.spark.ml.Transformer): 3 usages
PMML (org.dmg.pmml.PMML): 3 usages
File (java.io.File): 2 usages
FileOutputStream (java.io.FileOutputStream): 2 usages
IOException (java.io.IOException): 2 usages
InputStream (java.io.InputStream): 2 usages
InputStreamReader (java.io.InputStreamReader): 2 usages
OutputStream (java.io.OutputStream): 2 usages
ArrayList (java.util.ArrayList): 2 usages
LinkedHashMap (java.util.LinkedHashMap): 2 usages
List (java.util.List): 2 usages
CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel): 2 usages
TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel): 2 usages
Dataset (org.apache.spark.sql.Dataset): 2 usages
Row (org.apache.spark.sql.Row): 2 usages
StructType (org.apache.spark.sql.types.StructType): 2 usages
Feature (org.jpmml.converter.Feature): 2 usages
Function (com.google.common.base.Function): 1 usage