Search in sources:

Example 6 with ContinuousLabel

Use of org.jpmml.converter.ContinuousLabel in the project jpmml-r by jpmml.

The method createMiningModel of the class GBMConverter.

/**
 * Combines the given trees into a single regression ensemble whose prediction is the SUM of the
 * member tree predictions.
 *
 * @param treeModels the member trees of the ensemble
 * @param initF value passed as the second argument of {@code ModelUtil.createRescaleTargets}
 *        (presumably the initial/intercept score added to the summed tree output — confirm)
 * @param schema schema whose label is expected to be a {@link ContinuousLabel} (cast below)
 * @return the assembled mining model
 */
private static MiningModel createMiningModel(List<TreeModel> treeModels, Double initF, Schema schema) {
    ContinuousLabel label = (ContinuousLabel) schema.getLabel();

    return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, treeModels))
        .setTargets(ModelUtil.createRescaleTargets(null, initF, label));
}
Also used: MiningModel (org.dmg.pmml.mining.MiningModel), ContinuousLabel (org.jpmml.converter.ContinuousLabel)

Example 7 with ContinuousLabel

Use of org.jpmml.converter.ContinuousLabel in the project jpmml-r by jpmml.

The method encodeNonFormula of the class SVMConverter.

/**
 * Encodes the target field and the input features of an SVM model that was not fitted with a
 * formula.
 *
 * <p>The target field is named {@code "_target"}:
 * <ul>
 *   <li>classification types get a categorical string field whose values are the model's
 *       {@code levels};</li>
 *   <li>one-class models get an anonymous (unnamed) continuous double label;</li>
 *   <li>regression types get a continuous double field.</li>
 * </ul>
 * Every column name of the support vector matrix {@code SV} becomes a continuous double input
 * field. The resulting features are passed through {@code scale(...)} before being registered.
 *
 * @param encoder the encoder that receives the label, data fields and features
 * @throws IllegalArgumentException if the SVM type is not one of the handled types
 */
private void encodeNonFormula(RExpEncoder encoder) {
    RGenericVector svm = getObject();

    RDoubleVector type = (RDoubleVector) svm.getValue("type");
    RDoubleVector sv = (RDoubleVector) svm.getValue("SV");
    RVector<?> levels = (RVector<?>) svm.getValue("levels");

    // The R "type" element is used as an index into the Type enum constants
    Type svmType = Type.values()[ValueUtil.asInt(type.asScalar())];

    // Only the column names (the input variable names) are needed; the row names were never used
    RStringVector columnNames = sv.dimnames(1);

    // Dependent variable
    {
        FieldName name = FieldName.create("_target");

        switch (svmType) {
            case C_CLASSIFICATION:
            case NU_CLASSIFICATION:
                {
                    RStringVector stringLevels = (RStringVector) levels;

                    DataField dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING, stringLevels.getValues());
                    encoder.setLabel(dataField);
                }
                break;
            case ONE_CLASSIFICATION:
                // One-class models have no target field in the data; use an unnamed continuous label
                encoder.setLabel(new ContinuousLabel(null, DataType.DOUBLE));
                break;
            case EPS_REGRESSION:
            case NU_REGRESSION:
                {
                    DataField dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
                    encoder.setLabel(dataField);
                }
                break;
            default:
                // Fail fast instead of silently leaving the encoder without a label
                throw new IllegalArgumentException("SVM type " + svmType + " is not supported");
        }
    }

    // Independent variables
    List<Feature> features = new ArrayList<>();
    for (String columnName : columnNames.getValues()) {
        DataField dataField = encoder.createDataField(FieldName.create(columnName), OpType.CONTINUOUS, DataType.DOUBLE);

        features.add(new ContinuousFeature(encoder, dataField));
    }

    features = scale(features, encoder);

    for (Feature feature : features) {
        encoder.addFeature(feature);
    }
}
Also used: ArrayList (java.util.ArrayList), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), OpType (org.dmg.pmml.OpType), DataType (org.dmg.pmml.DataType), DataField (org.dmg.pmml.DataField), FieldName (org.dmg.pmml.FieldName), ContinuousLabel (org.jpmml.converter.ContinuousLabel)

Example 8 with ContinuousLabel

Use of org.jpmml.converter.ContinuousLabel in the project jpmml-r by jpmml.

The method encodeFormula of the class SVMConverter.

/**
 * Encodes the target field and the input features of an SVM model that was fitted with a formula.
 *
 * <p>The formula is rebuilt from the model's {@code terms}. Categorical levels come from the
 * model's {@code xlevels} element, which therefore must be present. The {@code "response"}
 * attribute of {@code terms} selects the target field: a value of 0 means the formula has no
 * response, which is accepted only for one-class models. Input features are resolved from the
 * formula by the column names of the support vector matrix {@code SV}. They are then passed
 * through {@code scale(...)} before being registered.
 *
 * @param encoder the encoder that receives the label and features
 * @throws IllegalArgumentException if {@code xlevels} is missing; if a one-class model has a
 *         non-continuous response; or if a non-one-class model has no response variable
 */
private void encodeFormula(RExpEncoder encoder) {
    RGenericVector svm = getObject();

    RDoubleVector type = (RDoubleVector) svm.getValue("type");
    RDoubleVector sv = (RDoubleVector) svm.getValue("SV");
    RVector<?> levels = (RVector<?>) svm.getValue("levels");
    RExp terms = svm.getValue("terms");

    final RGenericVector xlevels;
    try {
        xlevels = (RGenericVector) svm.getValue("xlevels");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No variable levels information. Please initialize the \'xlevels\' element", iae);
    }

    // The R "type" element is used as an index into the Type enum constants
    Type svmType = Type.values()[ValueUtil.asInt(type.asScalar())];

    // Only the column names (the input variable names) are needed; the row names were never used
    RStringVector columnNames = sv.dimnames(1);

    RIntegerVector response = (RIntegerVector) terms.getAttributeValue("response");

    FormulaContext context = new FormulaContext() {

        @Override
        public List<String> getCategories(String variable) {

            if (xlevels.hasValue(variable)) {
                // Named distinctly to avoid shadowing the outer "levels" (the target levels)
                RStringVector variableLevels = (RStringVector) xlevels.getValue(variable);

                return variableLevels.getValues();
            }

            return null;
        }

        @Override
        public RGenericVector getData() {
            return null;
        }
    };

    Formula formula = FormulaUtil.createFormula(terms, context, encoder);

    // Dependent variable
    int responseIndex = response.asScalar();
    if (responseIndex != 0) {
        // The "response" attribute is one-based; formula fields are zero-based
        DataField dataField = (DataField) formula.getField(responseIndex - 1);

        switch (svmType) {
            case C_CLASSIFICATION:
            case NU_CLASSIFICATION:
                {
                    RStringVector stringLevels = (RStringVector) levels;

                    dataField = (DataField) encoder.toCategorical(dataField.getName(), stringLevels.getValues());
                }
                break;
            case ONE_CLASSIFICATION:
                {
                    OpType opType = dataField.getOpType();

                    if (!(OpType.CONTINUOUS).equals(opType)) {
                        throw new IllegalArgumentException("Expected a continuous response field for a one-class SVM, got " + opType);
                    }
                }
                break;
            default:
                // Regression types use the response field as-is
                break;
        }

        encoder.setLabel(dataField);
    } else {
        // A formula without a response term is only meaningful for one-class models
        if (svmType != Type.ONE_CLASSIFICATION) {
            throw new IllegalArgumentException("SVM type " + svmType + " requires a response variable");
        }

        encoder.setLabel(new ContinuousLabel(null, DataType.DOUBLE));
    }

    // Independent variables
    List<Feature> features = new ArrayList<>();
    for (String columnName : columnNames.getValues()) {
        Feature feature = formula.resolveFeature(columnName);

        features.add(feature);
    }

    features = scale(features, encoder);

    for (Feature feature : features) {
        encoder.addFeature(feature);
    }
}
Also used: ArrayList (java.util.ArrayList), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), OpType (org.dmg.pmml.OpType), DataType (org.dmg.pmml.DataType), DataField (org.dmg.pmml.DataField), ContinuousLabel (org.jpmml.converter.ContinuousLabel)

Example 9 with ContinuousLabel

Use of org.jpmml.converter.ContinuousLabel in the project jpmml-r by jpmml.

The method setLabel of the class RExpEncoder.

/**
 * Sets the label of this encoder from a data field, choosing the label kind by the field's
 * operational type.
 *
 * <p>A categorical field becomes a {@link CategoricalLabel}; a continuous field becomes a
 * {@link ContinuousLabel}.
 *
 * @param dataField the target field
 * @throws IllegalArgumentException if the field has no operational type, or if its operational
 *         type is neither categorical nor continuous
 */
public void setLabel(DataField dataField) {
    OpType opType = dataField.getOpType();

    // Switching on a null enum would throw an uninformative NullPointerException
    if (opType == null) {
        throw new IllegalArgumentException("Field " + dataField.getName() + " has no operational type");
    }

    Label label;
    switch (opType) {
        case CATEGORICAL:
            label = new CategoricalLabel(dataField);
            break;
        case CONTINUOUS:
            label = new ContinuousLabel(dataField);
            break;
        default:
            throw new IllegalArgumentException("Field " + dataField.getName() + " has unsupported operational type " + opType);
    }

    setLabel(label);
}
Also used: CategoricalLabel (org.jpmml.converter.CategoricalLabel), ContinuousLabel (org.jpmml.converter.ContinuousLabel), Label (org.jpmml.converter.Label), OpType (org.dmg.pmml.OpType)

Example 10 with ContinuousLabel

Use of org.jpmml.converter.ContinuousLabel in the project jpmml-sparkml by jpmml.

The method encodeModel of the class GBTClassificationModelConverter.

/**
 * Encodes a Spark ML gradient-boosted tree classifier as a PMML model.
 *
 * <p>The member trees are encoded against an unnamed continuous double label and combined with
 * WEIGHTED_SUM using the model's tree weights. The combined score is exposed as the continuous
 * output field {@code "gbtValue"}. That regression model is then handed to
 * {@code MiningModelUtil.createBinaryLogisticClassification} (with arguments {@code 2d, 0d},
 * LOGIT normalization) to produce the final binary classifier.
 *
 * @param schema the classification schema of the Spark model
 * @return the encoded classification model
 * @throws IllegalArgumentException if the model's loss type is not {@code "logistic"}
 */
@Override
public MiningModel encodeModel(Schema schema) {
    GBTClassificationModel model = getTransformer();

    String lossType = model.getLossType();
    if (!lossType.equals("logistic")) {
        throw new IllegalArgumentException("Loss function " + lossType + " is not supported");
    }

    // Member trees score against an unnamed continuous label instead of the categorical target
    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    List<TreeModel> treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(model, segmentSchema);

    MiningModel scoreModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(segmentSchema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, treeModels, Doubles.asList(model.treeWeights())))
        .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbtValue"), OpType.CONTINUOUS, DataType.DOUBLE));

    return MiningModelUtil.createBinaryLogisticClassification(scoreModel, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, false, schema);
}
Also used: TreeModel (org.dmg.pmml.tree.TreeModel), MiningModel (org.dmg.pmml.mining.MiningModel), GBTClassificationModel (org.apache.spark.ml.classification.GBTClassificationModel), Schema (org.jpmml.converter.Schema), ContinuousLabel (org.jpmml.converter.ContinuousLabel)

Aggregations

ContinuousLabel (org.jpmml.converter.ContinuousLabel): 10, ArrayList (java.util.ArrayList): 6, MiningModel (org.dmg.pmml.mining.MiningModel): 5, Schema (org.jpmml.converter.Schema): 5, DataField (org.dmg.pmml.DataField): 3, OpType (org.dmg.pmml.OpType): 3, TreeModel (org.dmg.pmml.tree.TreeModel): 3, CategoricalLabel (org.jpmml.converter.CategoricalLabel): 3, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 3, Feature (org.jpmml.converter.Feature): 3, Label (org.jpmml.converter.Label): 3, DataType (org.dmg.pmml.DataType): 2, PredictionModel (org.apache.spark.ml.PredictionModel): 1, ClassificationModel (org.apache.spark.ml.classification.ClassificationModel): 1, GBTClassificationModel (org.apache.spark.ml.classification.GBTClassificationModel): 1, HasLabelCol (org.apache.spark.ml.param.shared.HasLabelCol): 1, Field (org.dmg.pmml.Field): 1, FieldName (org.dmg.pmml.FieldName): 1, MiningFunction (org.dmg.pmml.MiningFunction): 1, Model (org.dmg.pmml.Model): 1