Search in sources :

Example 11 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.

Class RandomForestConverter, method encodeClassification.

/**
 * Builds a PMML classification ensemble from the {@code forest} component of an R randomForest model.
 *
 * <p>Every tree is read as column {@code tree} of the forest's column-major matrices.
 * <ul>
 *   <li>{@code treemap} stores both daughter columns of a tree back to back, so its per-tree
 *       column has {@code 2 * nrnodes} entries.</li>
 *   <li>{@code nodepred}, {@code bestvar} and {@code xbestsplit} have one {@code nrnodes}-long
 *       column per tree.</li>
 * </ul>
 * The trees are combined by majority vote, and the model gets a double-typed probability output.
 *
 * @param forest the forest component holding the matrices listed above
 * @param schema the conversion schema; its label is cast to {@link CategoricalLabel}
 * @return the classification MiningModel
 */
private MiningModel encodeClassification(RGenericVector forest, final Schema schema) {
    RNumberVector<?> bestvar = (RNumberVector<?>) forest.getValue("bestvar");
    RNumberVector<?> treemap = (RNumberVector<?>) forest.getValue("treemap");
    RIntegerVector nodepred = (RIntegerVector) forest.getValue("nodepred");
    RDoubleVector xbestsplit = (RDoubleVector) forest.getValue("xbestsplit");
    RIntegerVector nrnodes = (RIntegerVector) forest.getValue("nrnodes");
    RDoubleVector ntree = (RDoubleVector) forest.getValue("ntree");

    final int nodeCount = nrnodes.asScalar();
    final int treeCount = ValueUtil.asInt(ntree.asScalar());

    final CategoricalLabel label = (CategoricalLabel) schema.getLabel();

    // A node prediction p is rendered as label.getValue(p - 1), i.e. predictions are
    // treated as 1-based indices into the label's category values.
    ScoreEncoder<Integer> classEncoder = new ScoreEncoder<Integer>() {

        @Override
        public String encode(Integer classIndex) {
            return label.getValue(classIndex - 1);
        }
    };

    Schema treeSchema = schema.toAnonymousSchema();

    List<TreeModel> trees = new ArrayList<>();
    for (int tree = 0; tree < treeCount; tree++) {
        // Both daughter columns of this tree, stored as one 2 * nodeCount slice.
        List<? extends Number> daughters = FortranMatrixUtil.getColumn(treemap.getValues(), 2 * nodeCount, treeCount, tree);

        trees.add(encodeTreeModel(
            MiningFunction.CLASSIFICATION,
            classEncoder,
            FortranMatrixUtil.getColumn(daughters, nodeCount, 2, 0), // left daughters
            FortranMatrixUtil.getColumn(daughters, nodeCount, 2, 1), // right daughters
            FortranMatrixUtil.getColumn(nodepred.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(bestvar.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(xbestsplit.getValues(), nodeCount, treeCount, tree),
            treeSchema));
    }

    return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, trees))
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, label));
}
Also used : Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList) TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) CategoricalLabel(org.jpmml.converter.CategoricalLabel)

Example 12 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.

Class RandomForestConverter, method encodeRegression.

/**
 * Builds a PMML regression ensemble from the {@code forest} component of an R randomForest model.
 *
 * <p>{@code leftDaughter}, {@code rightDaughter}, {@code nodepred}, {@code bestvar} and
 * {@code xbestsplit} are read as column-major matrices with one {@code nrnodes}-long column per
 * tree. Leaf values are rendered with {@code ValueUtil.formatValue}, and the trees are combined by
 * averaging.
 *
 * @param forest the forest component holding the matrices listed above
 * @param schema the conversion schema; its label defines the mining schema
 * @return the regression MiningModel
 */
private MiningModel encodeRegression(RGenericVector forest, final Schema schema) {
    RNumberVector<?> leftDaughter = (RNumberVector<?>) forest.getValue("leftDaughter");
    RNumberVector<?> rightDaughter = (RNumberVector<?>) forest.getValue("rightDaughter");
    RDoubleVector nodepred = (RDoubleVector) forest.getValue("nodepred");
    RNumberVector<?> bestvar = (RNumberVector<?>) forest.getValue("bestvar");
    RDoubleVector xbestsplit = (RDoubleVector) forest.getValue("xbestsplit");
    RIntegerVector nrnodes = (RIntegerVector) forest.getValue("nrnodes");
    RDoubleVector ntree = (RDoubleVector) forest.getValue("ntree");

    ScoreEncoder<Double> valueEncoder = new ScoreEncoder<Double>() {

        @Override
        public String encode(Double score) {
            return ValueUtil.formatValue(score);
        }
    };

    final int nodeCount = nrnodes.asScalar();
    final int treeCount = ValueUtil.asInt(ntree.asScalar());

    Schema treeSchema = schema.toAnonymousSchema();

    List<TreeModel> trees = new ArrayList<>();
    for (int tree = 0; tree < treeCount; tree++) {
        trees.add(encodeTreeModel(
            MiningFunction.REGRESSION,
            valueEncoder,
            FortranMatrixUtil.getColumn(leftDaughter.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(rightDaughter.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(nodepred.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(bestvar.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(xbestsplit.getValues(), nodeCount, treeCount, tree),
            treeSchema));
    }

    return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees));
}
Also used : Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList) TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel)

Example 13 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.

Class RandomForestClassificationModelConverter, method encodeModel.

/**
 * Encodes the wrapped Spark random-forest classifier as a PMML classification ensemble.
 *
 * <p>Each decision tree is encoded by {@code TreeModelUtil.encodeDecisionTreeEnsemble}, and the
 * segments are combined with the {@code AVERAGE} multiple-model method.
 *
 * @param schema the conversion schema; its label defines the mining schema
 * @return the classification MiningModel
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RandomForestClassificationModel forest = getTransformer();

    List<TreeModel> trees = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees));
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) RandomForestClassificationModel(org.apache.spark.ml.classification.RandomForestClassificationModel)

Example 14 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.

Class RandomForestRegressionModelConverter, method encodeModel.

/**
 * Encodes the wrapped Spark random-forest regressor as a PMML regression ensemble.
 *
 * <p>Each decision tree is encoded by {@code TreeModelUtil.encodeDecisionTreeEnsemble}, and the
 * segment predictions are combined with the {@code AVERAGE} multiple-model method.
 *
 * @param schema the conversion schema; its label defines the mining schema
 * @return the regression MiningModel
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RandomForestRegressionModel forest = getTransformer();

    List<TreeModel> trees = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees));
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) RandomForestRegressionModel(org.apache.spark.ml.regression.RandomForestRegressionModel) MiningModel(org.dmg.pmml.mining.MiningModel)

Example 15 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.

Class ConverterUtil, method toPMML.

/**
 * Converts a fitted Spark ML pipeline into a PMML document.
 *
 * <p>Each pipeline stage is wrapped in a converter, in order.
 * <ul>
 *   <li>Feature converters register their features with the shared encoder.</li>
 *   <li>Model converters register and contribute a model.</li>
 * </ul>
 * A single model becomes the root model directly. Several models are joined into a model chain,
 * whose mining schema lists every PREDICTED or TARGET mining field of the sub-models.
 *
 * @param schema the Spark input schema
 * @param pipelineModel the fitted pipeline to convert
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a stage yields a converter that is neither a
 *         FeatureConverter nor a ModelConverter, or if the pipeline contains no models
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();

    SparkMLEncoder encoder = new SparkMLEncoder(schema);

    List<org.dmg.pmml.Model> models = new ArrayList<>();

    for (Transformer stage : getTransformers(pipelineModel)) {
        TransformerConverter<?> stageConverter = ConverterUtil.createConverter(stage);

        if (stageConverter instanceof FeatureConverter) {
            ((FeatureConverter<?>) stageConverter).registerFeatures(encoder);
            continue;
        }

        if (stageConverter instanceof ModelConverter) {
            models.add(((ModelConverter<?>) stageConverter).registerModel(encoder));
            continue;
        }

        throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + stageConverter);
    }

    if (models.isEmpty()) {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }

    org.dmg.pmml.Model rootModel;
    if (models.size() == 1) {
        rootModel = models.get(0);
    } else {
        // Target fields are gathered before the chain is built, matching the original statement order.
        MiningSchema chainMiningSchema = new MiningSchema(collectTargetMiningFields(models));

        rootModel = MiningModelUtil.createModelChain(models, new Schema(null, Collections.<Feature>emptyList()))
            .setMiningSchema(chainMiningSchema);
    }

    return encoder.encodePMML(rootModel);
}

/**
 * Returns, in model order, every mining field whose usage type is PREDICTED or TARGET.
 */
private static List<MiningField> collectTargetMiningFields(List<org.dmg.pmml.Model> models) {
    List<MiningField> targets = new ArrayList<>();

    for (org.dmg.pmml.Model model : models) {
        for (MiningField field : model.getMiningSchema().getMiningFields()) {
            switch (field.getUsageType()) {
                case PREDICTED:
                case TARGET:
                    targets.add(field);
                    break;
                default:
                    break;
            }
        }
    }

    return targets;
}
Also used: MiningField(org.dmg.pmml.MiningField) Transformer(org.apache.spark.ml.Transformer) MiningSchema(org.dmg.pmml.MiningSchema) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) MiningModel(org.dmg.pmml.mining.MiningModel) PipelineModel(org.apache.spark.ml.PipelineModel) TrainValidationSplitModel(org.apache.spark.ml.tuning.TrainValidationSplitModel) CrossValidatorModel(org.apache.spark.ml.tuning.CrossValidatorModel) PMML(org.dmg.pmml.PMML) List(java.util.List)

Aggregations

MiningModel (org.dmg.pmml.mining.MiningModel)17 TreeModel (org.dmg.pmml.tree.TreeModel)12 Schema (org.jpmml.converter.Schema)9 ArrayList (java.util.ArrayList)6 ContinuousLabel (org.jpmml.converter.ContinuousLabel)5 Node (org.dmg.pmml.tree.Node)3 CategoricalLabel (org.jpmml.converter.CategoricalLabel)3 List (java.util.List)1 PipelineModel (org.apache.spark.ml.PipelineModel)1 Transformer (org.apache.spark.ml.Transformer)1 GBTClassificationModel (org.apache.spark.ml.classification.GBTClassificationModel)1 RandomForestClassificationModel (org.apache.spark.ml.classification.RandomForestClassificationModel)1 GBTRegressionModel (org.apache.spark.ml.regression.GBTRegressionModel)1 RandomForestRegressionModel (org.apache.spark.ml.regression.RandomForestRegressionModel)1 CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel)1 TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel)1 FieldName (org.dmg.pmml.FieldName)1 FieldRef (org.dmg.pmml.FieldRef)1 MiningField (org.dmg.pmml.MiningField)1 MiningSchema (org.dmg.pmml.MiningSchema)1