Search in sources :

Example 1 with RegressionTable

use of org.dmg.pmml.regression.RegressionTable in project jpmml-sparkml by jpmml.

In class LogisticRegressionModelConverter, method encodeModel.

/**
 * Encodes the wrapped logistic regression transformer as a PMML regression model.
 *
 * <p>A label with exactly two categories is delegated to
 * {@code RegressionModelUtil.createBinaryLogisticClassification} using logit
 * normalization, and the output of the resulting model is cleared. A label with more
 * than two categories yields one regression table per category (coefficients from row
 * {@code i} of the coefficient matrix, intercept from element {@code i} of the
 * intercept vector), combined under softmax normalization.
 *
 * @param schema
 *            the schema supplying the features and the label; the label is cast to
 *            {@code CategoricalLabel}
 * @return the encoded regression model
 * @throws IllegalArgumentException
 *             if the label has fewer than two categories
 */
@Override
public RegressionModel encodeModel(Schema schema) {
    LogisticRegressionModel model = getTransformer();
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    int numCategories = categoricalLabel.size();
    // Reject degenerate labels up front, and say why, instead of failing silently at the end
    if (numCategories < 2) {
        throw new IllegalArgumentException("Expected a categorical label with two or more categories, got " + numCategories);
    }
    if (numCategories == 2) {
        // Binary case: single coefficient vector and scalar intercept
        return RegressionModelUtil.createBinaryLogisticClassification(schema.getFeatures(), VectorUtil.toList(model.coefficients()), model.intercept(), RegressionModel.NormalizationMethod.LOGIT, true, schema).setOutput(null);
    }
    // Multinomial case: one regression table per target category
    Matrix coefficientMatrix = model.coefficientMatrix();
    Vector interceptVector = model.interceptVector();
    List<? extends Feature> features = schema.getFeatures();
    List<RegressionTable> regressionTables = new ArrayList<>();
    for (int i = 0; i < numCategories; i++) {
        RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(features, MatrixUtil.getRow(coefficientMatrix, i), interceptVector.apply(i)).setTargetCategory(categoricalLabel.getValue(i));
        regressionTables.add(regressionTable);
    }
    return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables).setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
}
Also used : Matrix(org.apache.spark.ml.linalg.Matrix) CategoricalLabel(org.jpmml.converter.CategoricalLabel) LogisticRegressionModel(org.apache.spark.ml.classification.LogisticRegressionModel) ArrayList(java.util.ArrayList) List(java.util.List) Vector(org.apache.spark.ml.linalg.Vector) Feature(org.jpmml.converter.Feature) RegressionTable(org.dmg.pmml.regression.RegressionTable) LogisticRegressionModel(org.apache.spark.ml.classification.LogisticRegressionModel) RegressionModel(org.dmg.pmml.regression.RegressionModel)

Example 2 with RegressionTable

use of org.dmg.pmml.regression.RegressionTable in project pyramid by cheng-li.

In class MiningModelUtil, method createClassification.

/**
 * Builds a classification model chain: the given member models followed by a final
 * {@code RegressionModel} that has one regression table per target category. Table
 * {@code i} takes the prediction of model {@code i} as its single feature with a
 * coefficient of 1 and targets category {@code i} of the label.
 *
 * @param models
 *            the member models, one per target category, paired with categories by index
 * @param normalizationMethod
 *            normalization applied by the final regression model; may be {@code null},
 *            otherwise must be {@code NONE}, {@code SIMPLEMAX} or {@code SOFTMAX}
 * @param hasProbabilityDistribution
 *            whether the final regression model gets a probability output
 * @param schema
 *            the schema; its label is cast to {@code CategoricalLabel}
 * @return the model chain produced by {@code createModelChain}
 * @throws IllegalArgumentException
 *             if the number of models differs from the number of categories, if the
 *             normalization method is unsupported, or if the member models do not all
 *             share the same math context
 */
public static MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    // Models and categories are paired by index below, so the counts must match exactly
    if (categoricalLabel.size() != models.size()) {
        throw new IllegalArgumentException("Expected " + categoricalLabel.size() + " models (one per target category), got " + models.size());
    }
    if (normalizationMethod != null) {
        switch(normalizationMethod) {
            case NONE:
            case SIMPLEMAX:
            case SOFTMAX:
                break;
            default:
                throw new IllegalArgumentException("Unsupported normalization method: " + normalizationMethod);
        }
    }
    MathContext mathContext = null;
    List<RegressionTable> regressionTables = new ArrayList<>();
    for (int i = 0; i < categoricalLabel.size(); i++) {
        Model model = models.get(i);
        // A model without an explicit math context is treated as DOUBLE
        MathContext modelMathContext = model.getMathContext();
        if (modelMathContext == null) {
            modelMathContext = MathContext.DOUBLE;
        }
        if (mathContext == null) {
            // The first model fixes the math context every later model must agree with
            mathContext = modelMathContext;
        } else {
            if (!Objects.equals(mathContext, modelMathContext)) {
                throw new IllegalArgumentException("Model at index " + i + " uses math context " + modelMathContext + ", expected " + mathContext);
            }
        }
        Feature feature = MODEL_PREDICTION.apply(model);
        RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), Collections.singletonList(1d), null).setTargetCategory(categoricalLabel.getValue(i));
        regressionTables.add(regressionTable);
    }
    RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables).setNormalizationMethod(normalizationMethod).setMathContext(ModelUtil.simplifyMathContext(mathContext)).setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, categoricalLabel) : null);
    // Copy before appending so the caller's list is left untouched
    List<Model> segmentationModels = new ArrayList<>(models);
    segmentationModels.add(regressionModel);
    return createModelChain(segmentationModels, schema);
}
Also used : CategoricalLabel(org.jpmml.converter.CategoricalLabel) ArrayList(java.util.ArrayList) Model(org.dmg.pmml.Model) MiningModel(org.dmg.pmml.mining.MiningModel) RegressionModel(org.dmg.pmml.regression.RegressionModel) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) MathContext(org.dmg.pmml.MathContext) RegressionTable(org.dmg.pmml.regression.RegressionTable) RegressionModel(org.dmg.pmml.regression.RegressionModel)

Example 3 with RegressionTable

use of org.dmg.pmml.regression.RegressionTable in project shifu by ShifuML.

In class PMMLAdapterCommonUtil, method getRegressionTable.

/**
 * Builds a regression table from the weight list and intercept and adds it to the
 * given partial PMML model.
 *
 * <p>The model is configured as a regression with logit normalization. Its target
 * field, and the table's target category, are taken from the first entry returned by
 * {@code getSchemaFieldViaUsageType(schema, UsageType.TARGET)}. One numeric predictor
 * is added for every derived field whose expression is a {@code NormContinuous} over a
 * field returned for {@code UsageType.ACTIVE}; weights are consumed in the iteration
 * order of the derived fields.
 *
 * @param weights
 *            weight list for the regression table; must hold at least one weight per
 *            matching derived field
 * @param intercept
 *            the intercept of the regression table
 * @param pmmlModel
 *            partial PMML model; it is modified in place
 * @return the same {@code pmmlModel} instance, with the regression table added
 */
public static RegressionModel getRegressionTable(final double[] weights, final double intercept, RegressionModel pmmlModel) {
    RegressionTable table = new RegressionTable();
    // Apply the caller-supplied intercept; without this call the argument is ignored
    table.setIntercept(intercept);
    MiningSchema schema = pmmlModel.getMiningSchema();
    // TODO may not need target field in LRModel
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    List<String> outputFields = getSchemaFieldViaUsageType(schema, UsageType.TARGET);
    // TODO only one outputField, what if we have more than one outputField
    pmmlModel.setTargetFieldName(new FieldName(outputFields.get(0)));
    table.setTargetCategory(outputFields.get(0));
    List<String> activeFields = getSchemaFieldViaUsageType(schema, UsageType.ACTIVE);
    int index = 0;
    for (DerivedField dField : pmmlModel.getLocalTransformations().getDerivedFields()) {
        Expression expression = dField.getExpression();
        if (expression instanceof NormContinuous) {
            NormContinuous norm = (NormContinuous) expression;
            // Only normalized versions of active input fields become predictors
            if (activeFields.contains(norm.getField().getValue())) {
                table.addNumericPredictors(new NumericPredictor(dField.getName(), weights[index++]));
            }
        }
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
Also used : NormContinuous(org.dmg.pmml.NormContinuous) MiningSchema(org.dmg.pmml.MiningSchema) Expression(org.dmg.pmml.Expression) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable)

Example 4 with RegressionTable

use of org.dmg.pmml.regression.RegressionTable in project shifu by ShifuML.

In class PMMLLRModelBuilder, method adaptMLModelToPMML.

/**
 * Fills the given partial PMML regression model from a trained LR model.
 *
 * <p>The model is set to logit normalization and the regression mining function. A
 * single regression table is added whose intercept is {@code lr.getBias()} and which
 * holds one numeric predictor per ACTIVE mining field, in mining-schema order. Each
 * predictor is named after the last derived field reached by following the local
 * transformations from the mining field, and takes the next entry of
 * {@code lr.getWeights()} as its coefficient.
 *
 * @param lr
 *            the trained LR model supplying the bias and the weights
 * @param pmmlModel
 *            partial PMML model; it is modified in place
 * @return the same {@code pmmlModel} instance, with the regression table added
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);

    RegressionTable regressionTable = new RegressionTable();
    regressionTable.setIntercept(lr.getBias());

    // Source field name -> name of the derived field computed from it
    HashMap<FieldName, FieldName> sourceToDerived = new HashMap<FieldName, FieldName>();
    for (DerivedField derived : pmmlModel.getLocalTransformations().getDerivedFields()) {
        // z-scale normalization on numerical variables
        if (derived.getExpression() instanceof NormContinuous) {
            NormContinuous normalization = (NormContinuous) derived.getExpression();
            sourceToDerived.put(normalization.getField(), derived.getName());
            continue;
        }
        // bin map on categorical variables
        if (derived.getExpression() instanceof MapValues) {
            MapValues mapping = (MapValues) derived.getExpression();
            sourceToDerived.put(mapping.getFieldColumnPairs().get(0).getField(), derived.getName());
            continue;
        }
        if (derived.getExpression() instanceof Discretize) {
            Discretize discretization = (Discretize) derived.getExpression();
            sourceToDerived.put(discretization.getField(), derived.getName());
        }
    }

    int weightIndex = 0;
    for (MiningField miningField : pmmlModel.getMiningSchema().getMiningFields()) {
        if (miningField.getUsageType() == UsageType.ACTIVE) {
            // Follow the chain of transformations to the final derived field
            FieldName predictorName = miningField.getName();
            while (sourceToDerived.containsKey(predictorName)) {
                predictorName = sourceToDerived.get(predictorName);
            }
            NumericPredictor predictor = new NumericPredictor();
            predictor.setName(predictorName);
            predictor.setCoefficient(lr.getWeights()[weightIndex]);
            weightIndex++;
            regressionTable.addNumericPredictors(predictor);
        }
    }

    pmmlModel.addRegressionTables(regressionTable);
    return pmmlModel;
}
Also used : NormContinuous(org.dmg.pmml.NormContinuous) MiningField(org.dmg.pmml.MiningField) HashMap(java.util.HashMap) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) LocalTransformations(org.dmg.pmml.LocalTransformations) MapValues(org.dmg.pmml.MapValues) Discretize(org.dmg.pmml.Discretize) DerivedField(org.dmg.pmml.DerivedField) FieldName(org.dmg.pmml.FieldName)

Aggregations

RegressionTable (org.dmg.pmml.regression.RegressionTable)4 ArrayList (java.util.ArrayList)2 DerivedField (org.dmg.pmml.DerivedField)2 FieldName (org.dmg.pmml.FieldName)2 NormContinuous (org.dmg.pmml.NormContinuous)2 NumericPredictor (org.dmg.pmml.regression.NumericPredictor)2 RegressionModel (org.dmg.pmml.regression.RegressionModel)2 CategoricalLabel (org.jpmml.converter.CategoricalLabel)2 Feature (org.jpmml.converter.Feature)2 HashMap (java.util.HashMap)1 List (java.util.List)1 LogisticRegressionModel (org.apache.spark.ml.classification.LogisticRegressionModel)1 Matrix (org.apache.spark.ml.linalg.Matrix)1 Vector (org.apache.spark.ml.linalg.Vector)1 Discretize (org.dmg.pmml.Discretize)1 Expression (org.dmg.pmml.Expression)1 LocalTransformations (org.dmg.pmml.LocalTransformations)1 MapValues (org.dmg.pmml.MapValues)1 MathContext (org.dmg.pmml.MathContext)1 MiningField (org.dmg.pmml.MiningField)1