Search in sources:

Example 1 with Schema

use of org.jpmml.converter.Schema in project jpmml-r by jpmml.

the class GBMConverter method encodeModel.

/**
 * Encodes the object returned by {@code getObject()} as a {@link MiningModel}.
 *
 * <p>The {@code initF}, {@code trees}, {@code c.splits} and {@code distribution}
 * elements are read from that object. Every element of {@code trees} is encoded
 * as a regression tree model against a segment schema that combines a double-typed
 * continuous label (constructed with a {@code null} first argument) with the
 * features of {@code schema}. The resulting trees, the {@code name} element of the
 * distribution and the scalar value of {@code initF} are then handed to
 * {@code encodeMiningModel}.
 *
 * @param schema the schema whose features are reused for the per-tree segment
 *               schema, and which is passed unchanged to {@code encodeMiningModel}
 * @return the mining model produced by {@code encodeMiningModel}
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RGenericVector gbmObject = getObject();

    // Pull out (and type-check via cast) every element needed below.
    RDoubleVector initialValue = (RDoubleVector) gbmObject.getValue("initF");
    RGenericVector treeVector = (RGenericVector) gbmObject.getValue("trees");
    RGenericVector cSplits = (RGenericVector) gbmObject.getValue("c.splits");
    RGenericVector distributionVector = (RGenericVector) gbmObject.getValue("distribution");
    RStringVector distributionName = (RStringVector) distributionVector.getValue("name");

    // Individual trees are always encoded as regression models over a double label.
    Schema treeSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    List<TreeModel> encodedTrees = new ArrayList<>();
    for (int treeIndex = 0; treeIndex < treeVector.size(); treeIndex++) {
        RGenericVector treeData = (RGenericVector) treeVector.getValue(treeIndex);
        encodedTrees.add(encodeTreeModel(MiningFunction.REGRESSION, treeData, cSplits, treeSchema));
    }

    return encodeMiningModel(distributionName, encodedTrees, initialValue.asScalar(), schema);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList) ContinuousLabel(org.jpmml.converter.ContinuousLabel)

Example 2 with Schema

use of org.jpmml.converter.Schema in project jpmml-r by jpmml.

the class RangerConverter method encodeForest.

/**
 * Encodes each tree of a forest object as a {@link TreeModel}.
 *
 * <p>The {@code num.trees}, {@code child.nodeIDs}, {@code split.varIDs} and
 * {@code split.values} elements are read from {@code forest}. The
 * {@code terminal.class.counts} element is looked up with a second argument of
 * {@code true} and is treated as possibly {@code null}; when it is {@code null},
 * {@code null} is passed to the per-tree encoder in its place.
 *
 * @param forest the forest object holding the per-tree vectors
 * @param miningFunction the mining function forwarded to every per-tree encoding
 * @param scoreEncoder the score encoder forwarded to every per-tree encoding
 * @param schema the schema whose anonymous variant ({@code toAnonymousSchema()})
 *               is used for every tree
 * @return the encoded trees, in index order {@code 0 .. num.trees - 1}
 */
private List<TreeModel> encodeForest(RGenericVector forest, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema) {
    RNumberVector<?> numTrees = (RNumberVector<?>) forest.getValue("num.trees");
    RGenericVector childNodeIDs = (RGenericVector) forest.getValue("child.nodeIDs");
    RGenericVector splitVarIDs = (RGenericVector) forest.getValue("split.varIDs");
    RGenericVector splitValues = (RGenericVector) forest.getValue("split.values");
    RGenericVector terminalClassCounts = (RGenericVector) forest.getValue("terminal.class.counts", true);
    Schema segmentSchema = schema.toAnonymousSchema();
    List<TreeModel> treeModels = new ArrayList<>();
    // The tree count does not change while iterating; resolve it once instead of
    // re-reading and re-converting the scalar in the loop condition on every pass.
    int treeCount = ValueUtil.asInt(numTrees.asScalar());
    for (int i = 0; i < treeCount; i++) {
        // Per-tree slices, extracted in the same order the original call evaluated them.
        RGenericVector treeChildNodeIDs = (RGenericVector) childNodeIDs.getValue(i);
        RNumberVector<?> treeSplitVarIDs = (RNumberVector<?>) splitVarIDs.getValue(i);
        RNumberVector<?> treeSplitValues = (RNumberVector<?>) splitValues.getValue(i);
        RGenericVector treeTerminalClassCounts = (terminalClassCounts != null ? (RGenericVector) terminalClassCounts.getValue(i) : null);
        TreeModel treeModel = encodeTreeModel(miningFunction, scoreEncoder, treeChildNodeIDs, treeSplitVarIDs, treeSplitValues, treeTerminalClassCounts, segmentSchema);
        treeModels.add(treeModel);
    }
    return treeModels;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList)

Example 3 with Schema

use of org.jpmml.converter.Schema in project jpmml-r by jpmml.

the class PreProcessEncoder method createSchema.

/**
 * Creates the schema by running the superclass's schema through {@code filter}.
 *
 * @return the result of applying {@code filter} to {@code super.createSchema()}
 */
@Override
public Schema createSchema() {
    return filter(super.createSchema());
}
Also used : Schema(org.jpmml.converter.Schema)

Example 4 with Schema

use of org.jpmml.converter.Schema in project jpmml-sparkml by jpmml.

the class ModelConverter method encodeSchema.

/**
 * Builds the {@link Schema} (label plus features) for the transformer returned
 * by {@code getTransformer()}.
 *
 * <p>If the transformer implements {@code HasLabelCol}, a label is derived from
 * the single feature the encoder holds for the label column:
 * <ul>
 *   <li>{@code CLASSIFICATION} with a categorical feature: the label wraps the
 *       data field of that feature;</li>
 *   <li>{@code CLASSIFICATION} with a continuous feature: the field is converted
 *       to a categorical one with categories {@code "0"}, {@code "1"}, ... up to
 *       the class count (2 unless the transformer is a {@code ClassificationModel}),
 *       and the encoder's feature for the label column is replaced accordingly;</li>
 *   <li>{@code REGRESSION}: the field is converted to a continuous one and its
 *       data type is set to {@code DataType.DOUBLE}.</li>
 * </ul>
 * Otherwise the label stays {@code null}. The class count and the feature count
 * are then cross-checked against what the transformer itself reports.
 *
 * @param encoder the encoder that supplies (and, for continuous classification
 *                labels, receives updated) features and fields
 * @return a schema made of the derived label and the features of the features column
 * @throws IllegalArgumentException if the classification label feature is neither
 *         categorical nor continuous, if the mining function is neither
 *         classification nor regression, or if the number of target categories
 *         or features disagrees with the transformer
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T transformer = getTransformer();

    Label targetLabel = null;
    if (transformer instanceof HasLabelCol) {
        String labelCol = ((HasLabelCol) transformer).getLabelCol();
        Feature labelFeature = encoder.getOnlyFeature(labelCol);
        MiningFunction function = getMiningFunction();
        switch (function) {
            case CLASSIFICATION: {
                if (labelFeature instanceof CategoricalFeature) {
                    DataField labelDataField = encoder.getDataField(((CategoricalFeature) labelFeature).getName());
                    targetLabel = new CategoricalLabel(labelDataField);
                } else if (labelFeature instanceof ContinuousFeature) {
                    ContinuousFeature numericLabelFeature = (ContinuousFeature) labelFeature;
                    // Default to a binary target unless the transformer reports its own class count.
                    int classCount = (transformer instanceof ClassificationModel)
                        ? ((ClassificationModel<?, ?>) transformer).numClasses()
                        : 2;
                    // Category values are the class indices rendered as strings.
                    List<String> classValues = new ArrayList<>();
                    for (int classIndex = 0; classIndex < classCount; classIndex++) {
                        classValues.add(String.valueOf(classIndex));
                    }
                    Field<?> categoricalField = encoder.toCategorical(numericLabelFeature.getName(), classValues);
                    // Re-register the label column so later lookups see the categorical view.
                    encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, categoricalField, classValues));
                    targetLabel = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), classValues);
                } else {
                    throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                }
                break;
            }
            case REGRESSION: {
                Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                continuousField.setDataType(DataType.DOUBLE);
                targetLabel = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                break;
            }
            default:
                throw new IllegalArgumentException("Mining function " + function + " is not supported");
        }
    }

    // Cross-check: the label must carry as many categories as the classifier reports.
    if (transformer instanceof ClassificationModel) {
        ClassificationModel<?, ?> classifier = (ClassificationModel<?, ?>) transformer;
        // NOTE(review): this cast assumes a classification model always got a
        // categorical label above; nothing here guards against another label type or null.
        CategoricalLabel categoricalTarget = (CategoricalLabel) targetLabel;
        int expectedClasses = classifier.numClasses();
        if (expectedClasses != categoricalTarget.size()) {
            throw new IllegalArgumentException("Expected " + expectedClasses + " target categories, got " + categoricalTarget.size() + " target categories");
        }
    }

    List<Feature> inputFeatures = encoder.getFeatures(transformer.getFeaturesCol());

    // Cross-check: a reported feature count of -1 is skipped; anything else must match.
    if (transformer instanceof PredictionModel) {
        int expectedFeatures = ((PredictionModel<?, ?>) transformer).numFeatures();
        if (expectedFeatures != -1 && inputFeatures.size() != expectedFeatures) {
            throw new IllegalArgumentException("Expected " + expectedFeatures + " features, got " + inputFeatures.size() + " features");
        }
    }

    return new Schema(targetLabel, inputFeatures);
}
Also used : Schema(org.jpmml.converter.Schema) ContinuousLabel(org.jpmml.converter.ContinuousLabel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) ArrayList(java.util.ArrayList) PredictionModel(org.apache.spark.ml.PredictionModel) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) HasLabelCol(org.apache.spark.ml.param.shared.HasLabelCol) OutputField(org.dmg.pmml.OutputField) Field(org.dmg.pmml.Field) DataField(org.dmg.pmml.DataField) MiningFunction(org.dmg.pmml.MiningFunction) ClassificationModel(org.apache.spark.ml.classification.ClassificationModel)

Example 5 with Schema

use of org.jpmml.converter.Schema in project jpmml-r by jpmml.

the class GBMConverter method encodeBinaryClassification.

/**
 * Wraps a list of tree models into a binary logistic classification model.
 *
 * <p>A mining model is first created from the trees and {@code initF} against a
 * schema that pairs a double-typed continuous label (constructed with a
 * {@code null} first argument) with the features of {@code schema}; its output
 * is a continuous double field named {@code gbmValue}. That model is then passed
 * to {@code MiningModelUtil.createBinaryLogisticClassification} with the negated
 * {@code coefficient}, {@code 0d}, logit normalization and {@code true}.
 *
 * @param treeModels the tree models to combine
 * @param initF the value forwarded to {@code createMiningModel} alongside the trees
 * @param coefficient the coefficient; passed on with its sign flipped
 * @param schema the schema supplying the features and used for the final classification model
 * @return the model returned by {@code createBinaryLogisticClassification}
 */
private MiningModel encodeBinaryClassification(List<TreeModel> treeModels, Double initF, double coefficient, Schema schema) {
    // The inner ensemble predicts a plain double, so it gets its own continuous-label schema.
    Schema valueSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    MiningModel valueModel = createMiningModel(treeModels, initF, valueSchema)
        .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue"), OpType.CONTINUOUS, DataType.DOUBLE));

    return MiningModelUtil.createBinaryLogisticClassification(
        valueModel,
        -coefficient,
        0d,
        RegressionModel.NormalizationMethod.LOGIT,
        true,
        schema);
}
Also used : MiningModel(org.dmg.pmml.mining.MiningModel) Schema(org.jpmml.converter.Schema) ContinuousLabel(org.jpmml.converter.ContinuousLabel)

Aggregations

Schema (org.jpmml.converter.Schema): 15; ArrayList (java.util.ArrayList): 9; MiningModel (org.dmg.pmml.mining.MiningModel): 9; TreeModel (org.dmg.pmml.tree.TreeModel): 8; ContinuousLabel (org.jpmml.converter.ContinuousLabel): 6; CategoricalLabel (org.jpmml.converter.CategoricalLabel): 4; Model (org.dmg.pmml.Model): 2; OutputField (org.dmg.pmml.OutputField): 2; PMML (org.dmg.pmml.PMML): 2; Feature (org.jpmml.converter.Feature): 2; Label (org.jpmml.converter.Label): 2; List (java.util.List): 1; PipelineModel (org.apache.spark.ml.PipelineModel): 1; PredictionModel (org.apache.spark.ml.PredictionModel): 1; Transformer (org.apache.spark.ml.Transformer): 1; ClassificationModel (org.apache.spark.ml.classification.ClassificationModel): 1; GBTClassificationModel (org.apache.spark.ml.classification.GBTClassificationModel): 1; HasLabelCol (org.apache.spark.ml.param.shared.HasLabelCol): 1; DecisionTreeModel (org.apache.spark.ml.tree.DecisionTreeModel): 1; CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel): 1