Search in sources :

Example 1 with VerificationField

use of org.dmg.pmml.VerificationField in project jpmml-r by jpmml.

In class ModelConverter, method encodeVerificationData.

/**
 * Pairs each R column with its name and converts it into PMML verification data.
 *
 * <p>Column {@code i} is matched with {@code names.get(i)}. The values stored for a column depend on its runtime type:
 * <ul>
 *   <li>{@code RDoubleVector}: a Guava {@code Lists.transform} view of the column values in which NaN cells are mapped to {@code null};</li>
 *   <li>{@code RFactorVector}: the result of {@code getFactorValues()};</li>
 *   <li>any other {@code RVector}: the result of {@code getValues()}.</li>
 * </ul>
 *
 * @param columns the columns to encode; every element is cast to {@code RVector}
 * @param names the field names, looked up by column position
 * @return an insertion-ordered map from the newly created verification fields to their column values
 */
protected static Map<VerificationField, List<?>> encodeVerificationData(List<? extends RExp> columns, List<String> names) {
    // Stateless, so a single instance can serve every double column.
    Function<Double, Double> nanToNull = new Function<Double, Double>() {

        @Override
        public Double apply(Double value) {
            return (value.isNaN() ? null : value);
        }
    };
    Map<VerificationField, List<?>> data = new LinkedHashMap<>();
    int index = 0;
    for (RExp rexp : columns) {
        String fieldName = names.get(index);
        RVector<?> vector = (RVector<?>) rexp;
        List<?> cells;
        if (vector instanceof RDoubleVector) {
            cells = Lists.transform((List) vector.getValues(), nanToNull);
        } else if (vector instanceof RFactorVector) {
            cells = ((RFactorVector) vector).getFactorValues();
        } else {
            cells = vector.getValues();
        }
        data.put(ModelUtil.createVerificationField(fieldName), cells);
        index++;
    }
    return data;
}
Also used : VerificationField(org.dmg.pmml.VerificationField) LinkedHashMap(java.util.LinkedHashMap) Function(com.google.common.base.Function) List(java.util.List)

Example 2 with VerificationField

use of org.dmg.pmml.VerificationField in project jpmml-r by jpmml.

In class ModelConverter, method encodePMML.

/**
 * Encodes the wrapped R object as a PMML document and, when the object carries a
 * {@code "verification"} entry with result data, attaches model verification to the model.
 *
 * <p>The verification entry is read from an S4 attribute or from a generic vector element,
 * depending on the runtime type of the object. Its {@code "output_values"} take precedence over
 * {@code "target_values"}; when neither is present no model verification is set.
 *
 * @param encoder the encoder used to build the schema and the final PMML document
 * @return the encoded PMML document
 */
@Override
public PMML encodePMML(RExpEncoder encoder) {
    RExp object = getObject();
    RGenericVector verification = null;
    // NOTE(review): the trailing "false" argument presumably marks the lookup as optional - confirm.
    if (object instanceof S4Object) {
        verification = ((S4Object) object).getGenericAttribute("verification", false);
    } else if (object instanceof RGenericVector) {
        verification = ((RGenericVector) object).getGenericElement("verification", false);
    }
    encodeSchema(encoder);
    Schema schema = encoder.createSchema();
    Model model = encode(schema);
    if (verification != null) {
        RDoubleVector precision = verification.getDoubleElement("precision");
        RDoubleVector zeroThreshold = verification.getDoubleElement("zeroThreshold");
        VerificationMap data = new VerificationMap(precision.asScalar(), zeroThreshold.asScalar());
        RGenericVector activeValues = verification.getGenericElement("active_values");
        RGenericVector targetValues = verification.getGenericElement("target_values", false);
        RGenericVector outputValues = verification.getGenericElement("output_values", false);
        if (activeValues != null) {
            data.putInputData(encodeActiveValues(activeValues));
        }
        boolean hasResultData = true;
        if (outputValues != null) {
            data.putResultData(encodeOutputValues(outputValues));
        } else if (targetValues != null) {
            Label label = schema.getLabel();
            String name = label.getName();
            // Drop any previously registered field that refers to the label, before the target values are added.
            Iterator<VerificationField> fieldIt = data.keySet().iterator();
            while (fieldIt.hasNext()) {
                VerificationField candidate = fieldIt.next();
                if ((candidate.requireField()).equals(name)) {
                    fieldIt.remove();
                }
            }
            data.putResultData(encodeTargetValues(targetValues, label));
        } else {
            // Neither target nor output values: nothing to verify against.
            hasResultData = false;
        }
        if (hasResultData) {
            model.setModelVerification(ModelUtil.createModelVerification(data));
        }
    }
    return encoder.encodePMML(model);
}
Also used : Schema(org.jpmml.converter.Schema) Label(org.jpmml.converter.Label) VerificationField(org.dmg.pmml.VerificationField) Model(org.dmg.pmml.Model) Iterator(java.util.Iterator) Collection(java.util.Collection) PMML(org.dmg.pmml.PMML)

Example 3 with VerificationField

use of org.dmg.pmml.VerificationField in project jpmml-sparkml by jpmml.

In class PMMLBuilder, method build.

/**
 * Builds a PMML document from the configured schema, pipeline model, options and
 * (optional) verification data.
 *
 * <p>The method walks the transformers of the pipeline model in order. Feature converters
 * register their features with the encoder; model converters register a model and contribute
 * their prediction and probability column names. Derived fields that appear after the last
 * model are moved into the output of the final model. When verification data is available and
 * at least one prediction or probability column was collected, model verification is attached
 * to the model.
 *
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a transformer's converter is neither a
 *         {@code FeatureConverter} nor a {@code ModelConverter}
 */
public PMML build() {
    StructType schema = getSchema();
    PipelineModel pipelineModel = getPipelineModel();
    Map<RegexKey, ? extends Map<String, ?>> options = getOptions();
    Verification verification = getVerification();
    ConverterFactory converterFactory = new ConverterFactory(options);
    SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory);
    // NOTE(review): the snapshots taken below only make sense if this map reflects later encoder changes - presumably a live view; confirm.
    Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields();
    List<org.dmg.pmml.Model> models = new ArrayList<>();
    List<String> predictionColumns = new ArrayList<>();
    List<String> probabilityColumns = new ArrayList<>();
    // Transformations preceding the last model
    List<FieldName> preProcessorNames = Collections.emptyList();
    Iterable<Transformer> transformers = getTransformers(pipelineModel);
    for (Transformer transformer : transformers) {
        TransformerConverter<?> converter = converterFactory.newConverter(transformer);
        if (converter instanceof FeatureConverter) {
            FeatureConverter<?> featureConverter = (FeatureConverter<?>) converter;
            featureConverter.registerFeatures(encoder);
        } else if (converter instanceof ModelConverter) {
            ModelConverter<?> modelConverter = (ModelConverter<?>) converter;
            org.dmg.pmml.Model model = modelConverter.registerModel(encoder);
            models.add(model);
            // Labeled block: feature importances are recorded only when the converter supports them and the option is enabled.
            featureImportances: if (modelConverter instanceof HasFeatureImportances) {
                HasFeatureImportances hasFeatureImportances = (HasFeatureImportances) modelConverter;
                Boolean estimateFeatureImportances = (Boolean) modelConverter.getOption(HasTreeOptions.OPTION_ESTIMATE_FEATURE_IMPORTANCES, Boolean.FALSE);
                if (!estimateFeatureImportances) {
                    break featureImportances;
                }
                List<Double> featureImportances = VectorUtil.toList(hasFeatureImportances.getFeatureImportances());
                List<Feature> features = modelConverter.getFeatures(encoder);
                SchemaUtil.checkSize(featureImportances.size(), features);
                // Importances and features are paired by position.
                for (int i = 0; i < featureImportances.size(); i++) {
                    Double featureImportance = featureImportances.get(i);
                    Feature feature = features.get(i);
                    encoder.addFeatureImportance(model, feature, featureImportance);
                }
            }
            // Labeled block: collect the prediction column, except for the special case below.
            hasPredictionCol: if (transformer instanceof HasPredictionCol) {
                HasPredictionCol hasPredictionCol = (HasPredictionCol) transformer;
                // XXX
                // Skips the prediction column of a GeneralizedLinearRegressionModel that was encoded as a classification model.
                // NOTE(review): the reason for this exclusion is not visible here - confirm.
                if ((transformer instanceof GeneralizedLinearRegressionModel) && (MiningFunction.CLASSIFICATION).equals(model.getMiningFunction())) {
                    break hasPredictionCol;
                }
                predictionColumns.add(hasPredictionCol.getPredictionCol());
            }
            if (transformer instanceof HasProbabilityCol) {
                HasProbabilityCol hasProbabilityCol = (HasProbabilityCol) transformer;
                probabilityColumns.add(hasProbabilityCol.getProbabilityCol());
            }
            // Re-snapshot after every model, so that the final value reflects the state right after the last model.
            preProcessorNames = new ArrayList<>(derivedFields.keySet());
        } else {
            throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + ", got " + (converter != null ? ("class " + (converter.getClass()).getName()) : null));
        }
    }
    // Transformations following the last model
    List<FieldName> postProcessorNames = new ArrayList<>(derivedFields.keySet());
    postProcessorNames.removeAll(preProcessorNames);
    // No models yields null, one model is used as-is, several models are combined into a model chain.
    org.dmg.pmml.Model model;
    if (models.size() == 0) {
        model = null;
    } else if (models.size() == 1) {
        model = Iterables.getOnlyElement(models);
    } else {
        model = MiningModelUtil.createModelChain(models);
    }
    // Move every post-model derived field out of the encoder and into the final model's output as a transformed-value output field.
    if ((model != null) && (postProcessorNames.size() > 0)) {
        org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);
        Output output = ModelUtil.ensureOutput(finalModel);
        for (FieldName postProcessorName : postProcessorNames) {
            DerivedField derivedField = derivedFields.get(postProcessorName);
            encoder.removeDerivedField(postProcessorName);
            OutputField outputField = new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType()).setResultFeature(ResultFeature.TRANSFORMED_VALUE).setExpression(derivedField.getExpression());
            output.addOutputFields(outputField);
        }
    }
    PMML pmml = encoder.encodePMML(model);
    // Verification is attached to the model after the PMML document has been encoded.
    // NOTE(review): this relies on the document referencing the same model instance - confirm.
    if ((model != null) && (predictionColumns.size() > 0 || probabilityColumns.size() > 0) && (verification != null)) {
        Dataset<Row> dataset = verification.getDataset();
        Dataset<Row> transformedDataset = verification.getTransformedDataset();
        Double precision = verification.getPrecision();
        Double zeroThreshold = verification.getZeroThreshold();
        // Input columns are the names of the model's active mining fields.
        List<String> inputColumns = new ArrayList<>();
        MiningSchema miningSchema = model.getMiningSchema();
        List<MiningField> miningFields = miningSchema.getMiningFields();
        for (MiningField miningField : miningFields) {
            MiningField.UsageType usageType = miningField.getUsageType();
            switch(usageType) {
                case ACTIVE:
                    FieldName name = miningField.getName();
                    inputColumns.add(name.getValue());
                    break;
                default:
                    break;
            }
        }
        // Insertion order: inputs first, then predictions, then per-element probabilities.
        Map<VerificationField, List<?>> data = new LinkedHashMap<>();
        // Inputs are read from the original dataset and carry no precision or zero threshold.
        for (String inputColumn : inputColumns) {
            VerificationField verificationField = ModelUtil.createVerificationField(FieldName.create(inputColumn));
            data.put(verificationField, getColumn(dataset, inputColumn));
        }
        // Predictions are read from the transformed dataset.
        for (String predictionColumn : predictionColumns) {
            Feature feature = encoder.getOnlyFeature(predictionColumn);
            VerificationField verificationField = ModelUtil.createVerificationField(feature.getName()).setPrecision(precision).setZeroThreshold(zeroThreshold);
            data.put(verificationField, getColumn(transformedDataset, predictionColumn));
        }
        // A probability column yields one verification field per feature, taken from vector position i.
        for (String probabilityColumn : probabilityColumns) {
            List<Feature> features = encoder.getFeatures(probabilityColumn);
            for (int i = 0; i < features.size(); i++) {
                Feature feature = features.get(i);
                VerificationField verificationField = ModelUtil.createVerificationField(feature.getName()).setPrecision(precision).setZeroThreshold(zeroThreshold);
                data.put(verificationField, getVectorColumn(transformedDataset, probabilityColumn, i));
            }
        }
        model.setModelVerification(ModelUtil.createModelVerification(data));
    }
    return pmml;
}
Also used : MiningField(org.dmg.pmml.MiningField) StructType(org.apache.spark.sql.types.StructType) HasProbabilityCol(org.apache.spark.ml.param.shared.HasProbabilityCol) GeneralizedLinearRegressionModel(org.apache.spark.ml.regression.GeneralizedLinearRegressionModel) ArrayList(java.util.ArrayList) ResultFeature(org.dmg.pmml.ResultFeature) Feature(org.jpmml.converter.Feature) LinkedHashMap(java.util.LinkedHashMap) HasFeatureImportances(org.jpmml.sparkml.model.HasFeatureImportances) ArrayList(java.util.ArrayList) List(java.util.List) HasPredictionCol(org.apache.spark.ml.param.shared.HasPredictionCol) MiningSchema(org.dmg.pmml.MiningSchema) OutputField(org.dmg.pmml.OutputField) Row(org.apache.spark.sql.Row) Transformer(org.apache.spark.ml.Transformer) PipelineModel(org.apache.spark.ml.PipelineModel) Output(org.dmg.pmml.Output) FieldName(org.dmg.pmml.FieldName) VerificationField(org.dmg.pmml.VerificationField) GeneralizedLinearRegressionModel(org.apache.spark.ml.regression.GeneralizedLinearRegressionModel) PipelineModel(org.apache.spark.ml.PipelineModel) TrainValidationSplitModel(org.apache.spark.ml.tuning.TrainValidationSplitModel) CrossValidatorModel(org.apache.spark.ml.tuning.CrossValidatorModel) PMML(org.dmg.pmml.PMML) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

VerificationField (org.dmg.pmml.VerificationField): 3, LinkedHashMap (java.util.LinkedHashMap): 2, List (java.util.List): 2, PMML (org.dmg.pmml.PMML): 2, Function (com.google.common.base.Function): 1, ArrayList (java.util.ArrayList): 1, Collection (java.util.Collection): 1, Iterator (java.util.Iterator): 1, PipelineModel (org.apache.spark.ml.PipelineModel): 1, Transformer (org.apache.spark.ml.Transformer): 1, HasPredictionCol (org.apache.spark.ml.param.shared.HasPredictionCol): 1, HasProbabilityCol (org.apache.spark.ml.param.shared.HasProbabilityCol): 1, GeneralizedLinearRegressionModel (org.apache.spark.ml.regression.GeneralizedLinearRegressionModel): 1, CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel): 1, TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel): 1, Row (org.apache.spark.sql.Row): 1, StructType (org.apache.spark.sql.types.StructType): 1, DerivedField (org.dmg.pmml.DerivedField): 1, FieldName (org.dmg.pmml.FieldName): 1, MiningField (org.dmg.pmml.MiningField): 1