
Example 16 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.

Source: class ModelConverter, method getLastModel.

protected org.dmg.pmml.Model getLastModel(org.dmg.pmml.Model model) {
    if (model instanceof MiningModel) {
        MiningModel miningModel = (MiningModel) model;
        Segmentation segmentation = miningModel.getSegmentation();
        MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();
        // Only a model chain is unwrapped: the output of a chain's final segment is the chain's result
        switch(multipleModelMethod) {
            case MODEL_CHAIN:
                List<Segment> segments = segmentation.getSegments();
                if (segments.size() > 0) {
                    Segment lastSegment = segments.get(segments.size() - 1);
                    return lastSegment.getModel();
                }
                break;
            default:
                break;
        }
    }
    return model;
}
Also used : MiningModel(org.dmg.pmml.mining.MiningModel) Segmentation(org.dmg.pmml.mining.Segmentation) MultipleModelMethod(org.dmg.pmml.mining.Segmentation.MultipleModelMethod) Segment(org.dmg.pmml.mining.Segment)
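The unwrapping rule in getLastModel can be exercised without the JPMML library. The sketch below uses hypothetical stand-in types (SimpleModel, SimpleMiningModel, ChainMethod); they are not the org.dmg.pmml API, only enough structure to show that a non-empty MODEL_CHAIN is replaced by its last segment while every other model is returned unchanged.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for the PMML classes used in getLastModel.
// They are NOT org.dmg.pmml types; they only model the structure that matters here.
enum ChainMethod { MODEL_CHAIN, WEIGHTED_SUM, MAJORITY_VOTE }

class SimpleModel {
    final String name;
    SimpleModel(String name) { this.name = name; }
}

class SimpleMiningModel extends SimpleModel {
    final ChainMethod method;
    final List<SimpleModel> segments = new ArrayList<>();

    SimpleMiningModel(String name, ChainMethod method) {
        super(name);
        this.method = method;
    }
}

class LastModelDemo {
    // Mirrors getLastModel: only a MODEL_CHAIN with at least one segment is unwrapped;
    // plain models, empty chains and other ensembles are returned as-is.
    static SimpleModel getLastModel(SimpleModel model) {
        if (model instanceof SimpleMiningModel) {
            SimpleMiningModel mining = (SimpleMiningModel) model;
            if (mining.method == ChainMethod.MODEL_CHAIN && !mining.segments.isEmpty()) {
                return mining.segments.get(mining.segments.size() - 1);
            }
        }
        return model;
    }

    public static void main(String[] args) {
        SimpleMiningModel chain = new SimpleMiningModel("chain", ChainMethod.MODEL_CHAIN);
        chain.segments.add(new SimpleModel("first"));
        chain.segments.add(new SimpleModel("last"));

        SimpleMiningModel sum = new SimpleMiningModel("sum", ChainMethod.WEIGHTED_SUM);
        sum.segments.add(new SimpleModel("tree"));

        System.out.println(getLastModel(chain).name); // prints "last"
        System.out.println(getLastModel(sum).name);   // prints "sum"
    }
}
```

Returning the input unchanged for non-chain ensembles matches the original's `default: break;` branch: a WEIGHTED_SUM or voting ensemble has no single "last" model whose output stands for the whole.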

Example 17 with MiningModel

Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.

Source: class GBTClassificationModelConverter, method encodeModel.

@Override
public MiningModel encodeModel(Schema schema) {
    GBTClassificationModel model = getTransformer();
    String lossType = model.getLossType();
    switch(lossType) {
        case "logistic":
            break;
        default:
            throw new IllegalArgumentException("Loss function " + lossType + " is not supported");
    }
    // Each tree in the ensemble is encoded as a regression tree with a continuous (double) target
    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());
    List<TreeModel> treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(model, segmentSchema);
    // Combine the trees as a weighted sum (using Spark's tree weights) and expose the result as "gbtValue"
    MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(segmentSchema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, treeModels, Doubles.asList(model.treeWeights())))
        .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbtValue"), OpType.CONTINUOUS, DataType.DOUBLE));
    // Map gbtValue to binary class probabilities through a logistic (LOGIT) regression step
    return MiningModelUtil.createBinaryLogisticClassification(miningModel, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, false, schema);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) GBTClassificationModel(org.apache.spark.ml.classification.GBTClassificationModel) Schema(org.jpmml.converter.Schema) ContinuousLabel(org.jpmml.converter.ContinuousLabel)
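The PMML produced by encodeModel scores in two stages: a WEIGHTED_SUM over the tree outputs yields gbtValue, then a logistic transform turns it into a probability. The plain-Java sketch below reproduces that arithmetic under stated assumptions: the trees are hypothetical one-feature lambdas, and the arguments `2d, 0d` are read as the coefficient and intercept of the logistic step. The factor 2 is consistent with Spark's logistic-loss GBT, which maps a margin to a probability as 1 / (1 + exp(-2 * margin)).

```java
import java.util.List;
import java.util.function.DoubleUnaryOperator;

// Sketch of GBT binary scoring, assuming p = 1 / (1 + exp(-(2 * gbtValue + 0))).
// The "trees" here are hypothetical functions of a single feature value.
class GbtScoringSketch {

    // WEIGHTED_SUM segmentation: gbtValue = sum_i weight_i * tree_i(x)
    static double gbtValue(List<DoubleUnaryOperator> trees, double[] weights, double x) {
        if (trees.size() != weights.length) {
            throw new IllegalArgumentException("One weight per tree is required");
        }
        double sum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * trees.get(i).applyAsDouble(x);
        }
        return sum;
    }

    // LOGIT normalization of (coefficient * gbtValue + intercept),
    // assuming coefficient = 2 and intercept = 0 as passed in encodeModel.
    static double probabilityOfPositive(double gbtValue) {
        double coefficient = 2.0;
        double intercept = 0.0;
        return 1.0 / (1.0 + Math.exp(-(coefficient * gbtValue + intercept)));
    }

    public static void main(String[] args) {
        List<DoubleUnaryOperator> trees = List.of(
            x -> x < 0.5 ? -1.0 : 1.0,
            x -> x < 0.2 ? -0.5 : 0.5);
        double[] weights = {1.0, 0.1};

        double margin = gbtValue(trees, weights, 0.7);
        System.out.printf("gbtValue=%.2f p(positive)=%.4f%n", margin, probabilityOfPositive(margin));
    }
}
```

Keeping the two stages separate mirrors the converter's design: the ensemble stays a regression MiningModel with a reusable "gbtValue" output, and the classification step is layered on top rather than baked into the trees.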

Aggregations

MiningModel (org.dmg.pmml.mining.MiningModel): 17
TreeModel (org.dmg.pmml.tree.TreeModel): 12
Schema (org.jpmml.converter.Schema): 9
ArrayList (java.util.ArrayList): 6
ContinuousLabel (org.jpmml.converter.ContinuousLabel): 5
Node (org.dmg.pmml.tree.Node): 3
CategoricalLabel (org.jpmml.converter.CategoricalLabel): 3
List (java.util.List): 1
PipelineModel (org.apache.spark.ml.PipelineModel): 1
Transformer (org.apache.spark.ml.Transformer): 1
GBTClassificationModel (org.apache.spark.ml.classification.GBTClassificationModel): 1
RandomForestClassificationModel (org.apache.spark.ml.classification.RandomForestClassificationModel): 1
GBTRegressionModel (org.apache.spark.ml.regression.GBTRegressionModel): 1
RandomForestRegressionModel (org.apache.spark.ml.regression.RandomForestRegressionModel): 1
CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel): 1
TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel): 1
FieldName (org.dmg.pmml.FieldName): 1
FieldRef (org.dmg.pmml.FieldRef): 1
MiningField (org.dmg.pmml.MiningField): 1
MiningSchema (org.dmg.pmml.MiningSchema): 1