
Example 1 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.

From the class GBMConverter, method encodeModel.

@Override
public MiningModel encodeModel(Schema schema) {
    RGenericVector gbm = getObject();

    // Fields of the fitted R "gbm" object (c.splits holds the categorical split definitions)
    RDoubleVector initF = (RDoubleVector) gbm.getValue("initF");
    RGenericVector trees = (RGenericVector) gbm.getValue("trees");
    RGenericVector c_splits = (RGenericVector) gbm.getValue("c.splits");
    RGenericVector distribution = (RGenericVector) gbm.getValue("distribution");
    RStringVector distributionName = (RStringVector) distribution.getValue("name");

    // Every member tree is a regression tree with an anonymous continuous (double) target
    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    List<TreeModel> treeModels = new ArrayList<>();
    for (int i = 0; i < trees.size(); i++) {
        RGenericVector tree = (RGenericVector) trees.getValue(i);
        TreeModel treeModel = encodeTreeModel(MiningFunction.REGRESSION, tree, c_splits, segmentSchema);
        treeModels.add(treeModel);
    }

    // Assemble the ensemble from the trees and the initial value initF, taking the distribution name into account
    MiningModel miningModel = encodeMiningModel(distributionName, treeModels, initF.asScalar(), schema);
    return miningModel;
}
Also used: TreeModel (org.dmg.pmml.tree.TreeModel), MiningModel (org.dmg.pmml.mining.MiningModel), Schema (org.jpmml.converter.Schema), ArrayList (java.util.ArrayList), ContinuousLabel (org.jpmml.converter.ContinuousLabel)
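The encodeModel method above collects one regression tree per entry of the gbm "trees" list and hands them, together with the initial value initF, to encodeMiningModel. As a rough illustration of the additive rule such a boosted ensemble represents, here is a minimal sketch that assumes a plain sum of tree outputs on top of initF. The class BoostedSumSketch and its score method are hypothetical, not part of jpmml-r, and each tree is modeled as a simple function from a feature vector to a leaf value.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

class BoostedSumSketch {

    // Additive boosting score (assumed rule): start from initF and add every tree's output.
    // Each "tree" is a hypothetical stand-in mapping a feature vector to its leaf value.
    static double score(double initF, List<ToDoubleFunction<double[]>> trees, double[] x) {
        double sum = initF;
        for (ToDoubleFunction<double[]> tree : trees) {
            sum += tree.applyAsDouble(x);
        }
        return sum;
    }

    public static void main(String[] args) {
        // Two toy stumps that split on feature 0 at 0.5
        List<ToDoubleFunction<double[]>> trees = List.of(
            x -> x[0] < 0.5 ? -1.0 : 1.0,
            x -> x[0] < 0.5 ? -0.5 : 0.5
        );
        System.out.println(score(2.0, trees, new double[]{0.7})); // prints 3.5
        System.out.println(score(2.0, trees, new double[]{0.1})); // prints 0.5
    }
}
```

In PMML terms, this is the kind of score a sum-style segmentation over the encoded trees would produce; the sketch leaves out how the real converter encodes initF.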

Example 2 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.

From the class RangerConverter, method encodeClassification.

private MiningModel encodeClassification(RGenericVector ranger, Schema schema) {
    RGenericVector forest = (RGenericVector) ranger.getValue("forest");
    final RStringVector levels = (RStringVector) forest.getValue("levels");

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        @Override
        public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            // A classification leaf stores a 1-based class index in splitValue
            int index = ValueUtil.asInt(splitValue);

            // Terminal class counts are not handled by this encoder
            if (terminalClassCount != null) {
                throw new IllegalArgumentException();
            }

            node.setScore(levels.getValue(index - 1));
        }
    };

    List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

    // Each tree casts one vote; the most frequent class wins
    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels));

    return miningModel;
}
Also used: TreeModel (org.dmg.pmml.tree.TreeModel), MiningModel (org.dmg.pmml.mining.MiningModel), Node (org.dmg.pmml.tree.Node)
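The ScoreEncoder above treats a leaf's splitValue as a 1-based index into the forest's levels, and the segmentation combines the trees with MAJORITY_VOTE. A minimal sketch of those two steps outside of PMML follows. MajorityVoteSketch, levelOf and majorityVote are hypothetical names, and the tie-breaking rule (the first label to reach the top count wins) is a choice made for this sketch, not necessarily what a PMML evaluator does.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class MajorityVoteSketch {

    // Maps a 1-based class index (as the ScoreEncoder above treats it) to its level label
    static String levelOf(List<String> levels, int oneBasedIndex) {
        return levels.get(oneBasedIndex - 1);
    }

    // Returns the label predicted by the most trees; on a tie, the label that reached the top count first wins
    static String majorityVote(List<String> treePredictions) {
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (String prediction : treePredictions) {
            int count = counts.merge(prediction, 1, Integer::sum);
            if (count > bestCount) {
                best = prediction;
                bestCount = count;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> levels = List.of("setosa", "versicolor", "virginica");
        int[] leafIndices = {2, 3, 2, 1, 2}; // one raw leaf value per tree (made-up data)

        List<String> votes = new ArrayList<>();
        for (int index : leafIndices) {
            votes.add(levelOf(levels, index));
        }
        System.out.println(majorityVote(votes)); // prints versicolor
    }
}
```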

Example 3 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.

From the class GBTRegressionModelConverter, method encodeModel.

@Override
public MiningModel encodeModel(Schema schema) {
    GBTRegressionModel model = getTransformer();

    List<TreeModel> treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(model, schema);

    // The prediction is the sum of the tree outputs, each scaled by its entry in model.treeWeights()
    MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, treeModels, Doubles.asList(model.treeWeights())));

    return miningModel;
}
Also used: TreeModel (org.dmg.pmml.tree.TreeModel), MiningModel (org.dmg.pmml.mining.MiningModel), GBTRegressionModel (org.apache.spark.ml.regression.GBTRegressionModel)
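With WEIGHTED_SUM, each tree's output is multiplied by its weight from treeWeights() and the products are added. The sketch below assumes the tree outputs for one record have already been computed. WeightedSumSketch is a hypothetical helper, and the weights in main are made-up values, not ones taken from a trained Spark model.

```java
class WeightedSumSketch {

    // Weighted ensemble score: the sum over trees of weight[t] * output[t]
    static double weightedSum(double[] treeOutputs, double[] treeWeights) {
        if (treeOutputs.length != treeWeights.length) {
            throw new IllegalArgumentException("one weight per tree is required");
        }
        double sum = 0.0;
        for (int t = 0; t < treeOutputs.length; t++) {
            sum += treeWeights[t] * treeOutputs[t];
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] outputs = {4.0, 2.0, -1.0};
        double[] weights = {1.0, 0.5, 0.5};
        System.out.println(weightedSum(outputs, weights)); // prints 4.5
    }
}
```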

Example 4 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.

From the class GBMConverter, method encodeBinaryClassification.

private MiningModel encodeBinaryClassification(List<TreeModel> treeModels, Double initF, double coefficient, Schema schema) {
    // The trees form a regression ensemble with an anonymous continuous (double) target
    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    // Expose the raw boosting score as the output field "gbmValue"
    MiningModel miningModel = createMiningModel(treeModels, initF, segmentSchema)
        .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue"), OpType.CONTINUOUS, DataType.DOUBLE));

    // Turn the raw score into class probabilities with a logistic (LOGIT) normalization
    return MiningModelUtil.createBinaryLogisticClassification(miningModel, -coefficient, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
}
Also used: MiningModel (org.dmg.pmml.mining.MiningModel), Schema (org.jpmml.converter.Schema), ContinuousLabel (org.jpmml.converter.ContinuousLabel)
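encodeBinaryClassification exposes the summed score as gbmValue and then applies a logistic (LOGIT) normalization, passing -coefficient as the coefficient and 0 as the intercept. The sketch below shows one plausible reading of those arguments. It assumes the scaled score feeds the first target category and that the second category receives the complementary probability. That mapping, and the names BinaryLogisticSketch, logistic and probabilities, are assumptions for illustration, not documented jpmml-converter behavior.

```java
import java.util.Arrays;

class BinaryLogisticSketch {

    // Standard logistic function: 1 / (1 + e^(-y))
    static double logistic(double y) {
        return 1.0 / (1.0 + Math.exp(-y));
    }

    // Returns {p(first category), p(second category)} for a raw boosting score f.
    // Assumption: the first category's linear predictor is (-coefficient) * f + 0,
    // mirroring the arguments above, and the second category gets the complement.
    static double[] probabilities(double f, double coefficient) {
        double pFirst = logistic(-coefficient * f);
        return new double[]{pFirst, 1.0 - pFirst};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(probabilities(0.0, 1.0))); // prints [0.5, 0.5]
        System.out.println(Arrays.toString(probabilities(2.0, 1.0)));
    }
}
```

Under this reading, a larger raw score with a positive coefficient shifts probability toward the second category.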

Example 5 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.

From the class GBMConverter, method encodeMultinomialClassification.

private MiningModel encodeMultinomialClassification(List<TreeModel> treeModels, Double initF, Schema schema) {
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    List<Model> miningModels = new ArrayList<>();

    // Build one score ensemble per class from the matching column of the tree list
    for (int i = 0, columns = categoricalLabel.size(), rows = (treeModels.size() / columns); i < columns; i++) {
        MiningModel miningModel = createMiningModel(CMatrixUtil.getColumn(treeModels, rows, columns, i), initF, segmentSchema)
            .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue(" + categoricalLabel.getValue(i) + ")"), OpType.CONTINUOUS, DataType.DOUBLE));

        miningModels.add(miningModel);
    }

    // Normalize the per-class scores into probabilities with softmax
    return MiningModelUtil.createClassification(miningModels, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
}
Also used: MiningModel (org.dmg.pmml.mining.MiningModel), CategoricalLabel (org.jpmml.converter.CategoricalLabel), Schema (org.jpmml.converter.Schema), Model (org.dmg.pmml.Model), RegressionModel (org.dmg.pmml.regression.RegressionModel), TreeModel (org.dmg.pmml.tree.TreeModel), ArrayList (java.util.ArrayList), ContinuousLabel (org.jpmml.converter.ContinuousLabel)
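encodeMultinomialClassification builds one score ensemble per class by taking a column of the flat tree list with CMatrixUtil.getColumn, then normalizes the per-class scores with SOFTMAX. The sketch below assumes the flat list is a row-major (C-order) matrix with one row per boosting iteration and one column per class. MultinomialSketch, getColumn and softmax here are hypothetical reimplementations for illustration, not the jpmml-converter code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class MultinomialSketch {

    // Returns column `column` of a rows x columns matrix stored row-major in a flat list
    // (assumed layout: element (row, col) sits at index row * columns + col)
    static <E> List<E> getColumn(List<E> values, int rows, int columns, int column) {
        if (values.size() != rows * columns) {
            throw new IllegalArgumentException("size must equal rows * columns");
        }
        List<E> result = new ArrayList<>(rows);
        for (int row = 0; row < rows; row++) {
            result.add(values.get(row * columns + column));
        }
        return result;
    }

    // Softmax: p_i = exp(s_i) / sum_k exp(s_k); subtracting the max avoids overflow without changing the result
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] result = new double[scores.length];
        double total = 0.0;
        for (int i = 0; i < scores.length; i++) {
            result[i] = Math.exp(scores[i] - max);
            total += result[i];
        }
        for (int i = 0; i < result.length; i++) {
            result[i] /= total;
        }
        return result;
    }

    public static void main(String[] args) {
        // 2 boosting iterations x 3 classes, labeled "iteration:class"
        List<String> trees = List.of("1:a", "1:b", "1:c", "2:a", "2:b", "2:c");
        System.out.println(getColumn(trees, 2, 3, 1)); // prints [1:b, 2:b]
        System.out.println(Arrays.toString(softmax(new double[]{0.0, 0.0}))); // prints [0.5, 0.5]
    }
}
```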

Aggregations

MiningModel (org.dmg.pmml.mining.MiningModel): 17
TreeModel (org.dmg.pmml.tree.TreeModel): 12
Schema (org.jpmml.converter.Schema): 9
ArrayList (java.util.ArrayList): 6
ContinuousLabel (org.jpmml.converter.ContinuousLabel): 5
Node (org.dmg.pmml.tree.Node): 3
CategoricalLabel (org.jpmml.converter.CategoricalLabel): 3
List (java.util.List): 1
PipelineModel (org.apache.spark.ml.PipelineModel): 1
Transformer (org.apache.spark.ml.Transformer): 1
GBTClassificationModel (org.apache.spark.ml.classification.GBTClassificationModel): 1
RandomForestClassificationModel (org.apache.spark.ml.classification.RandomForestClassificationModel): 1
GBTRegressionModel (org.apache.spark.ml.regression.GBTRegressionModel): 1
RandomForestRegressionModel (org.apache.spark.ml.regression.RandomForestRegressionModel): 1
CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel): 1
TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel): 1
FieldName (org.dmg.pmml.FieldName): 1
FieldRef (org.dmg.pmml.FieldRef): 1
MiningField (org.dmg.pmml.MiningField): 1
MiningSchema (org.dmg.pmml.MiningSchema): 1