
Example 1 with GBTClassificationModel

Use of org.apache.spark.ml.classification.GBTClassificationModel in project jpmml-sparkml by jpmml.

Source: class GBTClassificationModelConverter, method encodeModel.

@Override
public MiningModel encodeModel(Schema schema) {
    GBTClassificationModel model = getTransformer();

    // Only the logistic loss (Spark's binary log loss) can be converted.
    String lossType = model.getLossType();
    switch(lossType) {
        case "logistic":
            break;
        default:
            throw new IllegalArgumentException("Loss function " + lossType + " is not supported");
    }

    // Encode each member tree as a regression tree with an anonymous continuous label.
    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());
    List<TreeModel> treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(model, segmentSchema);

    // Sum the tree outputs, weighted by treeWeights(), into the raw margin "gbtValue".
    MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(segmentSchema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, treeModels, Doubles.asList(model.treeWeights())))
        .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbtValue"), OpType.CONTINUOUS, DataType.DOUBLE));

    // Spark's log loss gives P(1) = 1 / (1 + exp(-2 * margin)),
    // so the margin is scaled by 2 (intercept 0) before LOGIT normalization.
    return MiningModelUtil.createBinaryLogisticClassification(miningModel, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, false, schema);
}
Also used: TreeModel (org.dmg.pmml.tree.TreeModel), MiningModel (org.dmg.pmml.mining.MiningModel), GBTClassificationModel (org.apache.spark.ml.classification.GBTClassificationModel), Schema (org.jpmml.converter.Schema), ContinuousLabel (org.jpmml.converter.ContinuousLabel)
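To make the two-stage encoding concrete, here is a minimal standalone sketch of the scoring that the converted PMML is meant to reproduce: a weighted sum of tree outputs (the "gbtValue" margin), followed by a regression with coefficient 2 and intercept 0 under LOGIT normalization. It is not jpmml-sparkml or Spark code. The class name GbtScoringSketch, the methods margin and probabilityOfPositive, and the modelling of trees as ToDoubleFunction<double[]> are hypothetical illustrations, and it assumes Spark's logistic-loss probability 1 / (1 + exp(-2 * margin)).

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

/**
 * Hypothetical sketch (not jpmml-sparkml or Spark code): mirrors the scoring
 * performed by the PMML produced for a binary GBT classifier with logistic loss.
 */
public class GbtScoringSketch {

    /** Stage 1: WEIGHTED_SUM segmentation, i.e. the raw "gbtValue" margin. */
    public static double margin(List<ToDoubleFunction<double[]>> trees, double[] weights, double[] x) {
        if (trees.size() != weights.length) {
            throw new IllegalArgumentException("One weight per tree is required");
        }
        double sum = 0.0;
        for (int i = 0; i < trees.size(); i++) {
            // Each tree's prediction is scaled by its weight (Spark's treeWeights()).
            sum += weights[i] * trees.get(i).applyAsDouble(x);
        }
        return sum;
    }

    /** Stage 2: regression with coefficient 2 and intercept 0, then LOGIT normalization. */
    public static double probabilityOfPositive(double margin) {
        double y = 2.0 * margin + 0.0;
        return 1.0 / (1.0 + Math.exp(-y));
    }

    public static void main(String[] args) {
        // Two toy "trees" as stumps on feature 0; the split points and leaf values are made up.
        List<ToDoubleFunction<double[]>> trees = List.of(
            x -> x[0] < 0.5 ? -1.0 : 1.0,
            x -> x[0] < 1.5 ? -0.5 : 0.5);
        double[] weights = {1.0, 0.1};
        double m = margin(trees, weights, new double[] {1.0});
        System.out.println("margin = " + m + ", P(1) = " + probabilityOfPositive(m));
    }
}
```

Keeping the margin and the probability as separate steps matches the converter's design: the tree ensemble is a plain regression MiningModel, and the classification semantics live entirely in the final logistic wrapper.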
