Search in sources :

Example 11 with Schema

use of org.jpmml.converter.Schema in project jpmml-r by jpmml.

the class ModelConverter method encodePMML.

/**
 * Builds a PMML document with the help of the supplied encoder.
 *
 * <p>The steps run in a fixed order: {@code encodeSchema(encoder)} is invoked, a schema is
 * obtained from {@code encoder.createSchema()}, the model is encoded against that schema, and
 * the encoder wraps the resulting model into the returned PMML object.
 *
 * @param encoder the encoder passed to {@code encodeSchema} and used to create the schema and
 *                the final document
 * @return the object returned by {@code encoder.encodePMML} for the encoded model
 */
public PMML encodePMML(RExpEncoder encoder) {
    // Invoked ahead of createSchema(); any return value is ignored.
    encodeSchema(encoder);

    Model encodedModel = encodeModel(encoder.createSchema());

    return encoder.encodePMML(encodedModel);
}
Also used : Schema(org.jpmml.converter.Schema) Model(org.dmg.pmml.Model) PMML(org.dmg.pmml.PMML)

Example 12 with Schema

use of org.jpmml.converter.Schema in project jpmml-sparkml by jpmml.

the class TreeModelUtil method encodeDecisionTreeEnsemble.

/**
 * Encodes every member tree of an ensemble model as a separate tree model.
 *
 * <p>All trees are encoded against the same schema, namely the result of
 * {@code schema.toAnonymousSchema()}, and share the given predicate manager.
 *
 * @param model            the ensemble whose {@code trees()} are encoded, in array order
 * @param predicateManager passed unchanged to {@code encodeDecisionTree} for each tree
 * @param schema           the schema from which the per-tree (anonymous) schema is derived
 * @return a new mutable list with one encoded tree model per member tree, in the same order
 */
public static <M extends Model<M> & TreeEnsembleModel<T>, T extends Model<T> & DecisionTreeModel> List<TreeModel> encodeDecisionTreeEnsemble(M model, PredicateManager predicateManager, Schema schema) {
    // Derived once up front and reused for every member tree.
    final Schema anonymousSchema = schema.toAnonymousSchema();
    final List<TreeModel> encodedTrees = new ArrayList<>();

    for (T memberTree : model.trees()) {
        encodedTrees.add(encodeDecisionTree(memberTree, predicateManager, anonymousSchema));
    }

    return encodedTrees;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) DecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList)

Example 13 with Schema

use of org.jpmml.converter.Schema in project jpmml-sparkml by jpmml.

the class ConverterUtil method toPMML.

/**
 * Converts a fitted pipeline into a PMML document.
 *
 * <p>Each transformer of the pipeline is mapped to a converter. Feature converters register
 * their features with the encoder; model converters register a model, which is collected.
 * A single collected model becomes the root model as-is. Several collected models are combined
 * into a model chain whose mining schema consists of the PREDICTED and TARGET mining fields of
 * all collected models.
 *
 * @param schema        the dataset schema handed to the {@code SparkMLEncoder} constructor
 * @param pipelineModel the fitted pipeline whose transformers are converted
 * @return the PMML object produced by the encoder for the root model
 * @throws IllegalArgumentException if a converter is neither a feature converter nor a model
 *                                  converter, or if the pipeline yields no model at all
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();

    final SparkMLEncoder encoder = new SparkMLEncoder(schema);
    final List<org.dmg.pmml.Model> stageModels = new ArrayList<>();

    Iterable<Transformer> stages = getTransformers(pipelineModel);
    for (Transformer stage : stages) {
        TransformerConverter<?> converter = ConverterUtil.createConverter(stage);

        if (converter instanceof FeatureConverter) {
            ((FeatureConverter<?>) converter).registerFeatures(encoder);
            continue;
        }

        if (converter instanceof ModelConverter) {
            stageModels.add(((ModelConverter<?>) converter).registerModel(encoder));
            continue;
        }

        throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
    }

    // Checked only after every stage has been visited, so converter errors surface first.
    if (stageModels.isEmpty()) {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }

    final org.dmg.pmml.Model rootModel;
    if (stageModels.size() == 1) {
        rootModel = Iterables.getOnlyElement(stageModels);
    } else {
        // Gather the target-like mining fields of all models, in encounter order.
        List<MiningField> chainTargetFields = new ArrayList<>();
        for (org.dmg.pmml.Model stageModel : stageModels) {
            for (MiningField field : stageModel.getMiningSchema().getMiningFields()) {
                switch (field.getUsageType()) {
                    case PREDICTED:
                    case TARGET:
                        chainTargetFields.add(field);
                        break;
                    default:
                        // Every other usage type is left out of the chain's mining schema.
                        break;
                }
            }
        }

        MiningSchema chainMiningSchema = new MiningSchema(chainTargetFields);
        // The chain itself is created with a schema that has no label and no features.
        Schema emptySchema = new Schema(null, Collections.<Feature>emptyList());
        rootModel = MiningModelUtil.createModelChain(stageModels, emptySchema).setMiningSchema(chainMiningSchema);
    }

    return encoder.encodePMML(rootModel);
}
Also used : MiningField(org.dmg.pmml.MiningField) Transformer(org.apache.spark.ml.Transformer) MiningSchema(org.dmg.pmml.MiningSchema) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) MiningModel(org.dmg.pmml.mining.MiningModel) PipelineModel(org.apache.spark.ml.PipelineModel) TrainValidationSplitModel(org.apache.spark.ml.tuning.TrainValidationSplitModel) CrossValidatorModel(org.apache.spark.ml.tuning.CrossValidatorModel) PMML(org.dmg.pmml.PMML) List(java.util.List)

Example 14 with Schema

use of org.jpmml.converter.Schema in project jpmml-sparkml by jpmml.

the class ModelConverter method registerModel.

/**
 * Encodes this converter's model and prepends any extra output fields to it.
 *
 * <p>The schema is encoded first and the model is encoded against it. The output fields
 * returned by {@code registerOutputFields} are then inserted at the front of the output of the
 * model returned by {@code getLastModel}; an {@code Output} element is created on that model
 * if it has none yet.
 *
 * @param encoder the encoder used for encoding the schema and registering output fields
 * @return the encoded model (the same object that {@code encodeModel} returned)
 */
public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder) {
    Schema schema = encodeSchema(encoder);
    Label label = schema.getLabel();
    org.dmg.pmml.Model model = encodeModel(schema);

    List<OutputField> extraOutputFields = registerOutputFields(label, encoder);
    if (extraOutputFields == null || extraOutputFields.isEmpty()) {
        // Nothing to add; the encoded model is returned untouched.
        return model;
    }

    org.dmg.pmml.Model outputOwner = getLastModel(model);
    Output output = outputOwner.getOutput();
    if (output == null) {
        output = new Output();
        outputOwner.setOutput(output);
    }

    // Inserted at index 0 so that these fields come before any already-present output fields.
    output.getOutputFields().addAll(0, extraOutputFields);

    return model;
}
Also used : Schema(org.jpmml.converter.Schema) Output(org.dmg.pmml.Output) ContinuousLabel(org.jpmml.converter.ContinuousLabel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) OutputField(org.dmg.pmml.OutputField)

Example 15 with Schema

use of org.jpmml.converter.Schema in project jpmml-sparkml by jpmml.

the class GBTClassificationModelConverter method encodeModel.

/**
 * Encodes the gradient-boosted classification model as a binary logistic classification.
 *
 * <p>The member trees are encoded against a segment schema that keeps the features of the
 * given schema but uses an unnamed continuous label of type double. They are combined into a
 * regression mining model by weighted sum, using the model's tree weights, with a predicted
 * output field named {@code gbtValue}. That mining model is then handed to
 * {@code MiningModelUtil.createBinaryLogisticClassification} together with the original schema.
 *
 * @param schema the schema of the classification problem; its features are reused for the
 *               member trees
 * @return the mining model produced by {@code createBinaryLogisticClassification}
 * @throws IllegalArgumentException if the loss type is anything other than {@code "logistic"}
 * @throws NullPointerException     if the loss type is {@code null}
 */
@Override
public MiningModel encodeModel(Schema schema) {
    GBTClassificationModel gbtModel = getTransformer();

    String lossType = gbtModel.getLossType();
    // Deliberately invoked on lossType itself: a null loss type fails with a
    // NullPointerException instead of being reported as an unsupported loss function.
    if (!lossType.equals("logistic")) {
        throw new IllegalArgumentException("Loss function " + lossType + " is not supported");
    }

    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());
    List<TreeModel> memberTrees = TreeModelUtil.encodeDecisionTreeEnsemble(gbtModel, segmentSchema);

    MiningModel gbtValueModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(segmentSchema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, memberTrees, Doubles.asList(gbtModel.treeWeights())))
        .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbtValue"), OpType.CONTINUOUS, DataType.DOUBLE));

    // NOTE(review): the meaning of the 2d / 0d arguments is defined by
    // createBinaryLogisticClassification and is not visible here - confirm against its docs.
    return MiningModelUtil.createBinaryLogisticClassification(gbtValueModel, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, false, schema);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) GBTClassificationModel(org.apache.spark.ml.classification.GBTClassificationModel) Schema(org.jpmml.converter.Schema) ContinuousLabel(org.jpmml.converter.ContinuousLabel)

Aggregations

Schema (org.jpmml.converter.Schema)15 ArrayList (java.util.ArrayList)9 MiningModel (org.dmg.pmml.mining.MiningModel)9 TreeModel (org.dmg.pmml.tree.TreeModel)8 ContinuousLabel (org.jpmml.converter.ContinuousLabel)6 CategoricalLabel (org.jpmml.converter.CategoricalLabel)4 Model (org.dmg.pmml.Model)2 OutputField (org.dmg.pmml.OutputField)2 PMML (org.dmg.pmml.PMML)2 Feature (org.jpmml.converter.Feature)2 Label (org.jpmml.converter.Label)2 List (java.util.List)1 PipelineModel (org.apache.spark.ml.PipelineModel)1 PredictionModel (org.apache.spark.ml.PredictionModel)1 Transformer (org.apache.spark.ml.Transformer)1 ClassificationModel (org.apache.spark.ml.classification.ClassificationModel)1 GBTClassificationModel (org.apache.spark.ml.classification.GBTClassificationModel)1 HasLabelCol (org.apache.spark.ml.param.shared.HasLabelCol)1 DecisionTreeModel (org.apache.spark.ml.tree.DecisionTreeModel)1 CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel)1