Search in sources:

Example 1 with HasPredictionCol

Use of org.apache.spark.ml.param.shared.HasPredictionCol in project jpmml-sparkml by jpmml.

Class PMMLBuilder, method build.

/**
 * Converts the configured Spark ML pipeline into a PMML document.
 *
 * <p>Every transformer returned by {@code getTransformers(pipelineModel)} is handed to a
 * {@link ConverterFactory}:
 * <ul>
 *   <li>feature converters register features with the {@link SparkMLEncoder};</li>
 *   <li>model converters register a PMML model with it;</li>
 *   <li>anything else is rejected.</li>
 * </ul>
 * When more than one model is registered, the models are joined with
 * {@code MiningModelUtil.createModelChain}.
 *
 * <p>Derived fields that were registered after the last model are removed from the encoder.
 * They are re-attached to the final model's {@link Output} as {@code TRANSFORMED_VALUE}
 * output fields.
 *
 * <p>Verification data is attached only when all three of these hold:
 * <ul>
 *   <li>a {@code Verification} is configured;</li>
 *   <li>a model exists;</li>
 *   <li>at least one prediction or probability column was collected.</li>
 * </ul>
 * The inputs are the model's ACTIVE mining fields, read from {@code getDataset()}. The
 * expected outputs come from {@code getTransformedDataset()}.
 *
 * @return the PMML produced by {@code encoder.encodePMML(model)}. Verification is set on
 *     {@code model} after encoding, so this presumably relies on the PMML holding a reference
 *     to that same model object (TODO confirm).
 * @throws IllegalArgumentException if a transformer's converter is neither a
 *     {@link FeatureConverter} nor a {@link ModelConverter}
 */
public PMML build() {
    StructType schema = getSchema();
    PipelineModel pipelineModel = getPipelineModel();
    Map<RegexKey, ? extends Map<String, ?>> options = getOptions();
    Verification verification = getVerification();

    ConverterFactory converterFactory = new ConverterFactory(options);
    SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory);

    // Read again after the loop, so this is presumably a live view of the encoder's
    // derived fields rather than a snapshot.
    Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields();

    List<org.dmg.pmml.Model> models = new ArrayList<>();
    List<String> predictionCols = new ArrayList<>();
    List<String> probabilityCols = new ArrayList<>();

    // Derived field names as of the most recently registered model; every name added later
    // belongs to post-processing.
    List<FieldName> preProcessorNames = Collections.emptyList();

    for (Transformer transformer : getTransformers(pipelineModel)) {
        TransformerConverter<?> converter = converterFactory.newConverter(transformer);

        if (converter instanceof FeatureConverter) {
            ((FeatureConverter<?>) converter).registerFeatures(encoder);
            continue;
        }

        if (!(converter instanceof ModelConverter)) {
            String actual = (converter != null ? ("class " + (converter.getClass()).getName()) : null);
            throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + ", got " + actual);
        }

        ModelConverter<?> modelConverter = (ModelConverter<?>) converter;
        org.dmg.pmml.Model model = modelConverter.registerModel(encoder);
        models.add(model);

        if (modelConverter instanceof HasFeatureImportances) {
            HasFeatureImportances importanceSource = (HasFeatureImportances) modelConverter;
            Boolean estimate = (Boolean) modelConverter.getOption(HasTreeOptions.OPTION_ESTIMATE_FEATURE_IMPORTANCES, Boolean.FALSE);

            if (estimate) {
                List<Double> importances = VectorUtil.toList(importanceSource.getFeatureImportances());
                List<Feature> features = modelConverter.getFeatures(encoder);
                SchemaUtil.checkSize(importances.size(), features);

                // Importances are positionally aligned with the model's features.
                int position = 0;
                for (Double importance : importances) {
                    Feature feature = features.get(position++);
                    encoder.addFeatureImportance(model, feature, importance);
                }
            }
        }

        // XXX (carried over from the original): a GeneralizedLinearRegressionModel whose PMML
        // mining function is CLASSIFICATION does not contribute a prediction column. The reason
        // is not visible in this method.
        if ((transformer instanceof HasPredictionCol)
                && !((transformer instanceof GeneralizedLinearRegressionModel) && (MiningFunction.CLASSIFICATION).equals(model.getMiningFunction()))) {
            predictionCols.add(((HasPredictionCol) transformer).getPredictionCol());
        }

        if (transformer instanceof HasProbabilityCol) {
            probabilityCols.add(((HasProbabilityCol) transformer).getProbabilityCol());
        }

        preProcessorNames = new ArrayList<>(derivedFields.keySet());
    }

    // Whatever was derived after the last model is a post-processing step.
    List<FieldName> postProcessorNames = new ArrayList<>(derivedFields.keySet());
    postProcessorNames.removeAll(preProcessorNames);

    org.dmg.pmml.Model model;
    switch (models.size()) {
        case 0:
            model = null;
            break;
        case 1:
            model = Iterables.getOnlyElement(models);
            break;
        default:
            model = MiningModelUtil.createModelChain(models);
            break;
    }

    if ((model != null) && !postProcessorNames.isEmpty()) {
        Output output = ModelUtil.ensureOutput(MiningModelUtil.getFinalModel(model));

        // Move each post-processing DerivedField out of the encoder and into an OutputField.
        for (FieldName fieldName : postProcessorNames) {
            DerivedField derivedField = derivedFields.get(fieldName);
            encoder.removeDerivedField(fieldName);

            output.addOutputFields(
                new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType())
                    .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                    .setExpression(derivedField.getExpression()));
        }
    }

    PMML pmml = encoder.encodePMML(model);

    boolean nothingToVerify = predictionCols.isEmpty() && probabilityCols.isEmpty();
    if ((model == null) || nothingToVerify || (verification == null)) {
        return pmml;
    }

    Dataset<Row> dataset = verification.getDataset();
    Dataset<Row> transformedDataset = verification.getTransformedDataset();
    Double precision = verification.getPrecision();
    Double zeroThreshold = verification.getZeroThreshold();

    // Only ACTIVE mining fields are supplied to the verifier as inputs.
    List<String> inputColumns = new ArrayList<>();
    for (MiningField miningField : model.getMiningSchema().getMiningFields()) {
        switch (miningField.getUsageType()) {
            case ACTIVE:
                inputColumns.add(miningField.getName().getValue());
                break;
            default:
                break;
        }
    }

    // Insertion order is kept: inputs, then predictions, then probabilities.
    Map<VerificationField, List<?>> data = new LinkedHashMap<>();

    for (String inputColumn : inputColumns) {
        data.put(ModelUtil.createVerificationField(FieldName.create(inputColumn)), getColumn(dataset, inputColumn));
    }

    for (String predictionCol : predictionCols) {
        Feature feature = encoder.getOnlyFeature(predictionCol);
        VerificationField expected = ModelUtil.createVerificationField(feature.getName())
            .setPrecision(precision)
            .setZeroThreshold(zeroThreshold);
        data.put(expected, getColumn(transformedDataset, predictionCol));
    }

    for (String probabilityCol : probabilityCols) {
        List<Feature> features = encoder.getFeatures(probabilityCol);

        // The i-th probability feature is checked against element i of the vector column.
        for (int index = 0; index < features.size(); index++) {
            VerificationField expected = ModelUtil.createVerificationField(features.get(index).getName())
                .setPrecision(precision)
                .setZeroThreshold(zeroThreshold);
            data.put(expected, getVectorColumn(transformedDataset, probabilityCol, index));
        }
    }

    model.setModelVerification(ModelUtil.createModelVerification(data));

    return pmml;
}
Also used: MiningField (org.dmg.pmml.MiningField), StructType (org.apache.spark.sql.types.StructType), HasProbabilityCol (org.apache.spark.ml.param.shared.HasProbabilityCol), GeneralizedLinearRegressionModel (org.apache.spark.ml.regression.GeneralizedLinearRegressionModel), ArrayList (java.util.ArrayList), ResultFeature (org.dmg.pmml.ResultFeature), Feature (org.jpmml.converter.Feature), LinkedHashMap (java.util.LinkedHashMap), HasFeatureImportances (org.jpmml.sparkml.model.HasFeatureImportances), List (java.util.List), HasPredictionCol (org.apache.spark.ml.param.shared.HasPredictionCol), MiningSchema (org.dmg.pmml.MiningSchema), OutputField (org.dmg.pmml.OutputField), Row (org.apache.spark.sql.Row), Transformer (org.apache.spark.ml.Transformer), PipelineModel (org.apache.spark.ml.PipelineModel), Output (org.dmg.pmml.Output), FieldName (org.dmg.pmml.FieldName), VerificationField (org.dmg.pmml.VerificationField), TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel), CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel), PMML (org.dmg.pmml.PMML), DerivedField (org.dmg.pmml.DerivedField)

Aggregations

ArrayList (java.util.ArrayList): 1, LinkedHashMap (java.util.LinkedHashMap): 1, List (java.util.List): 1, PipelineModel (org.apache.spark.ml.PipelineModel): 1, Transformer (org.apache.spark.ml.Transformer): 1, HasPredictionCol (org.apache.spark.ml.param.shared.HasPredictionCol): 1, HasProbabilityCol (org.apache.spark.ml.param.shared.HasProbabilityCol): 1, GeneralizedLinearRegressionModel (org.apache.spark.ml.regression.GeneralizedLinearRegressionModel): 1, CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel): 1, TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel): 1, Row (org.apache.spark.sql.Row): 1, StructType (org.apache.spark.sql.types.StructType): 1, DerivedField (org.dmg.pmml.DerivedField): 1, FieldName (org.dmg.pmml.FieldName): 1, MiningField (org.dmg.pmml.MiningField): 1, MiningSchema (org.dmg.pmml.MiningSchema): 1, Output (org.dmg.pmml.Output): 1, OutputField (org.dmg.pmml.OutputField): 1, PMML (org.dmg.pmml.PMML): 1, ResultFeature (org.dmg.pmml.ResultFeature): 1