Search in sources :

Example 1 with MiningFunction

use of org.dmg.pmml.MiningFunction in project jpmml-r by jpmml.

the class GLMConverter method encodeModel.

/**
 * Encodes the R {@code glm} object held by this converter as a PMML
 * {@link GeneralRegressionModel} of type {@code GENERALIZED_LINEAR}.
 *
 * <p>The "coefficients" and "family" elements of the object are read; the
 * family name selects the mining function and the distribution, and the link
 * name selects the link function and the link parameter. For a classification
 * model, probability output fields are attached as well.
 *
 * @param schema the schema supplying the label and the features.
 * @return the encoded general regression model.
 * @throws IllegalArgumentException if the number of coefficients does not equal
 *         the number of features (plus one when an intercept coefficient is present),
 *         or if a classification label does not have exactly two categories.
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector glm = getObject();
    RDoubleVector coefficients = (RDoubleVector) glm.getValue("coefficients");
    RGenericVector family = (RGenericVector) glm.getValue("family");
    Double intercept = coefficients.getValue(getInterceptName(), true);
    RStringVector familyFamily = (RStringVector) family.getValue("family");
    RStringVector familyLink = (RStringVector) family.getValue("link");
    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();
    // Every feature needs exactly one coefficient; the intercept, when present, takes one more slot.
    int expectedCoefficients = features.size() + (intercept != null ? 1 : 0);
    if (coefficients.size() != expectedCoefficients) {
        throw new IllegalArgumentException("Expected " + expectedCoefficients + " coefficients, got " + coefficients.size() + " coefficients");
    }
    List<Double> featureCoefficients = getFeatureCoefficients(features, coefficients);
    MiningFunction miningFunction = getMiningFunction(familyFamily.asScalar());
    String targetCategory = null;
    switch (miningFunction) {
        case CLASSIFICATION:
            {
                CategoricalLabel categoricalLabel = (CategoricalLabel) label;
                if (categoricalLabel.size() != 2) {
                    throw new IllegalArgumentException("Expected 2 target categories, got " + categoricalLabel.size() + " target categories");
                }
                // The regression table is encoded against the second category (index 1).
                targetCategory = categoricalLabel.getValue(1);
            }
            break;
        default:
            break;
    }
    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(label), null, null, null)
        .setDistribution(parseFamily(familyFamily.asScalar()))
        .setLinkFunction(parseLinkFunction(familyLink.asScalar()))
        .setLinkParameter(parseLinkParameter(familyLink.asScalar()));
    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, targetCategory);
    switch (miningFunction) {
        case CLASSIFICATION:
            generalRegressionModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, (CategoricalLabel) label));
            break;
        default:
            break;
    }
    return generalRegressionModel;
}
Also used: CategoricalLabel (org.jpmml.converter.CategoricalLabel), GeneralRegressionModel (org.dmg.pmml.general_regression.GeneralRegressionModel), Label (org.jpmml.converter.Label), MiningFunction (org.dmg.pmml.MiningFunction)

Example 2 with MiningFunction

use of org.dmg.pmml.MiningFunction in project jpmml-sparkml by jpmml.

the class TreeModelCompactor method visit.

/**
 * Validates and rewrites the strategy attributes of a tree model before
 * delegating to the superclass visitor.
 *
 * <p>The incoming model must declare missing value strategy {@code NONE},
 * no-true-child strategy {@code RETURN_NULL_PREDICTION} and split characteristic
 * {@code BINARY_SPLIT}. The model is then switched to missing value strategy
 * {@code NULL_PREDICTION} and split characteristic {@code MULTI_SPLIT}; regression
 * models additionally get no-true-child strategy {@code RETURN_LAST_PREDICTION}.
 *
 * @param treeModel the tree model to update in place.
 * @return the action returned by the superclass {@code visit} method.
 * @throws IllegalArgumentException if any of the three attributes has an unexpected
 *         value, or if the mining function is neither regression nor classification.
 */
@Override
public VisitorAction visit(TreeModel treeModel) {
    TreeModel.MissingValueStrategy missingValueStrategy = treeModel.getMissingValueStrategy();
    TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
    TreeModel.SplitCharacteristic splitCharacteristic = treeModel.getSplitCharacteristic();
    // Each precondition is checked separately so that the failure names the offending attribute.
    if (!(TreeModel.MissingValueStrategy.NONE).equals(missingValueStrategy)) {
        throw new IllegalArgumentException("Expected missing value strategy " + TreeModel.MissingValueStrategy.NONE + ", got " + missingValueStrategy);
    }
    if (!(TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(noTrueChildStrategy)) {
        throw new IllegalArgumentException("Expected no true child strategy " + TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION + ", got " + noTrueChildStrategy);
    }
    if (!(TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(splitCharacteristic)) {
        throw new IllegalArgumentException("Expected split characteristic " + TreeModel.SplitCharacteristic.BINARY_SPLIT + ", got " + splitCharacteristic);
    }
    treeModel.setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION).setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
    MiningFunction miningFunction = treeModel.getMiningFunction();
    switch (miningFunction) {
        case REGRESSION:
            treeModel.setNoTrueChildStrategy(TreeModel.NoTrueChildStrategy.RETURN_LAST_PREDICTION);
            break;
        case CLASSIFICATION:
            // The no-true-child strategy is left unchanged for classification models.
            break;
        default:
            throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
    }
    return super.visit(treeModel);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningFunction(org.dmg.pmml.MiningFunction)

Example 3 with MiningFunction

use of org.dmg.pmml.MiningFunction in project jpmml-sparkml by jpmml.

the class ModelConverter method encodeSchema.

/**
 * Builds the {@link Schema} (label plus features) for the transformer wrapped by this converter.
 *
 * <p>When the transformer exposes a label column, the label is derived from the single
 * feature registered for that column: a categorical label for classification (a continuous
 * feature is first converted to a categorical field whose categories are "0" .. "numClasses - 1"),
 * or a continuous double label for regression. The class count and the feature count are
 * cross-checked against the transformer where it reports them.
 *
 * @param encoder the encoder holding the features and fields registered so far.
 * @return a schema made of the resolved label (possibly {@code null}) and the features
 *         of the transformer's features column.
 * @throws IllegalArgumentException if the label feature has an unsupported type, the mining
 *         function is unsupported, or the class or feature counts do not match.
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T transformer = getTransformer();

    Label label = null;

    if (transformer instanceof HasLabelCol) {
        String labelCol = ((HasLabelCol) transformer).getLabelCol();
        Feature labelFeature = encoder.getOnlyFeature(labelCol);

        MiningFunction miningFunction = getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
                if (labelFeature instanceof CategoricalFeature) {
                    DataField dataField = encoder.getDataField(((CategoricalFeature) labelFeature).getName());
                    label = new CategoricalLabel(dataField);
                } else if (labelFeature instanceof ContinuousFeature) {
                    // Falls back to two classes when the transformer does not report a class count.
                    int classCount = (transformer instanceof ClassificationModel)
                        ? ((ClassificationModel<?, ?>) transformer).numClasses()
                        : 2;

                    List<String> categories = new ArrayList<>();
                    int classIndex = 0;
                    while (classIndex < classCount) {
                        categories.add(Integer.toString(classIndex));
                        classIndex++;
                    }

                    Field<?> categoricalField = encoder.toCategorical(((ContinuousFeature) labelFeature).getName(), categories);
                    // Re-register the label column so that it now resolves to the categorical feature.
                    encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, categoricalField, categories));
                    label = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), categories);
                } else {
                    throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                }
                break;
            case REGRESSION:
                Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                continuousField.setDataType(DataType.DOUBLE);
                label = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                break;
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    if (transformer instanceof ClassificationModel) {
        ClassificationModel<?, ?> classifier = (ClassificationModel<?, ?>) transformer;
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        int expectedClasses = classifier.numClasses();
        if (expectedClasses != categoricalLabel.size()) {
            throw new IllegalArgumentException("Expected " + expectedClasses + " target categories, got " + categoricalLabel.size() + " target categories");
        }
    }

    List<Feature> features = encoder.getFeatures(transformer.getFeaturesCol());

    if (transformer instanceof PredictionModel) {
        int expectedFeatures = ((PredictionModel<?, ?>) transformer).numFeatures();
        // A reported count of -1 is not validated.
        if (expectedFeatures != -1 && features.size() != expectedFeatures) {
            throw new IllegalArgumentException("Expected " + expectedFeatures + " features, got " + features.size() + " features");
        }
    }

    return new Schema(label, features);
}
Also used: Schema (org.jpmml.converter.Schema), ContinuousLabel (org.jpmml.converter.ContinuousLabel), CategoricalLabel (org.jpmml.converter.CategoricalLabel), Label (org.jpmml.converter.Label), ArrayList (java.util.ArrayList), PredictionModel (org.apache.spark.ml.PredictionModel), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), CategoricalFeature (org.jpmml.converter.CategoricalFeature), HasLabelCol (org.apache.spark.ml.param.shared.HasLabelCol), OutputField (org.dmg.pmml.OutputField), Field (org.dmg.pmml.Field), DataField (org.dmg.pmml.DataField), MiningFunction (org.dmg.pmml.MiningFunction), ClassificationModel (org.apache.spark.ml.classification.ClassificationModel)

Example 4 with MiningFunction

use of org.dmg.pmml.MiningFunction in project jpmml-r by jpmml.

the class GLMConverter method encodeSchema.

/**
 * Encodes the schema via the superclass and, for classification models,
 * replaces the label with a categorical field.
 *
 * <p>The categories are the factor levels of the label's column inside the
 * "model" element of the R {@code glm} object.
 *
 * @param encoder the encoder whose label is read and, for classification, replaced.
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector glm = getObject();

    RGenericVector familyElement = (RGenericVector) glm.getValue("family");
    RGenericVector modelElement = (RGenericVector) glm.getValue("model");
    RStringVector familyName = (RStringVector) familyElement.getValue("family");

    super.encodeSchema(encoder);

    switch (getMiningFunction(familyName.asScalar())) {
        case CLASSIFICATION: {
            Label currentLabel = encoder.getLabel();
            // Look up the label's column by its plain string name.
            RIntegerVector labelColumn = (RIntegerVector) modelElement.getValue((currentLabel.getName()).getValue());
            encoder.setLabel((DataField) encoder.toCategorical(currentLabel.getName(), RExpUtil.getFactorLevels(labelColumn)));
            break;
        }
        default:
            // Other mining functions keep the label produced by the superclass.
            break;
    }
}
Also used : DataField(org.dmg.pmml.DataField) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) MiningFunction(org.dmg.pmml.MiningFunction)

Example 5 with MiningFunction

use of org.dmg.pmml.MiningFunction in project jpmml-sparkml by jpmml.

the class GeneralizedLinearRegressionModelConverter method registerOutputFields.

/**
 * Registers the output fields of the superclass and, for classification
 * models, appends one probability field per label category.
 *
 * @param label the model label; cast to {@link CategoricalLabel} for classification.
 * @param encoder the encoder passed through to the superclass.
 * @return the superclass output fields, or a new list holding them followed by the
 *         probability fields when the mining function is classification.
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    List<OutputField> outputFields = super.registerOutputFields(label, encoder);

    switch (getMiningFunction()) {
        case CLASSIFICATION: {
            CategoricalLabel categoricalLabel = (CategoricalLabel) label;
            // Copy first, so that the list returned by the superclass is never modified.
            List<OutputField> extendedFields = new ArrayList<>(outputFields);
            extendedFields.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()));
            outputFields = extendedFields;
            break;
        }
        default:
            break;
    }

    return outputFields;
}
Also used : CategoricalLabel(org.jpmml.converter.CategoricalLabel) OutputField(org.dmg.pmml.OutputField) MiningFunction(org.dmg.pmml.MiningFunction)

Aggregations

MiningFunction (org.dmg.pmml.MiningFunction): 6, CategoricalLabel (org.jpmml.converter.CategoricalLabel): 5, Label (org.jpmml.converter.Label): 3, DataField (org.dmg.pmml.DataField): 2, OutputField (org.dmg.pmml.OutputField): 2, GeneralRegressionModel (org.dmg.pmml.general_regression.GeneralRegressionModel): 2, ArrayList (java.util.ArrayList): 1, PredictionModel (org.apache.spark.ml.PredictionModel): 1, ClassificationModel (org.apache.spark.ml.classification.ClassificationModel): 1, HasLabelCol (org.apache.spark.ml.param.shared.HasLabelCol): 1, GeneralizedLinearRegressionModel (org.apache.spark.ml.regression.GeneralizedLinearRegressionModel): 1, Field (org.dmg.pmml.Field): 1, TreeModel (org.dmg.pmml.tree.TreeModel): 1, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 1, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 1, ContinuousLabel (org.jpmml.converter.ContinuousLabel): 1, Feature (org.jpmml.converter.Feature): 1, Schema (org.jpmml.converter.Schema): 1