Search in sources:

Example 1 with LogisticRegressionModel

Use of org.apache.spark.ml.classification.LogisticRegressionModel in the project jpmml-sparkml by jpmml.

In the class LogisticRegressionModelConverter, method encodeModel.

/**
 * Encodes the fitted Spark {@code LogisticRegressionModel} as a PMML {@code RegressionModel}.
 *
 * <p>For a label with exactly two categories, a binary logistic classification with LOGIT
 * normalization is built from {@code model.coefficients()} and {@code model.intercept()}.
 * For a label with more than two categories, one {@code RegressionTable} is built per
 * category, from the matching row of {@code model.coefficientMatrix()} and the matching entry
 * of {@code model.interceptVector()}. Each table is tagged with that category's value, and the
 * tables are combined under SOFTMAX normalization.
 *
 * @param schema the conversion schema. Its label is cast to {@code CategoricalLabel}, so a
 *     different label type fails with {@code ClassCastException}.
 * @return the encoded PMML regression model
 * @throws IllegalArgumentException if the label has fewer than two categories
 */
@Override
public RegressionModel encodeModel(Schema schema) {
    LogisticRegressionModel model = getTransformer();
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    int numCategories = categoricalLabel.size();

    if (numCategories == 2) {
        // NOTE(review): the meaning of the boolean argument (true) is defined by
        // RegressionModelUtil.createBinaryLogisticClassification; confirm there.
        RegressionModel regressionModel = RegressionModelUtil.createBinaryLogisticClassification(schema.getFeatures(), VectorUtil.toList(model.coefficients()), model.intercept(), RegressionModel.NormalizationMethod.LOGIT, true, schema)
            // Clears any Output element on the returned model. Presumably output fields
            // are produced elsewhere in the converter; TODO confirm.
            .setOutput(null);
        return regressionModel;
    } else if (numCategories > 2) {
        Matrix coefficientMatrix = model.coefficientMatrix();
        Vector interceptVector = model.interceptVector();
        List<? extends Feature> features = schema.getFeatures();

        // One regression table per target category: row i of the coefficient matrix
        // and element i of the intercept vector belong to category i of the label.
        List<RegressionTable> regressionTables = new ArrayList<>(numCategories);
        for (int i = 0; i < numCategories; i++) {
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(features, MatrixUtil.getRow(coefficientMatrix, i), interceptVector.apply(i))
                .setTargetCategory(categoricalLabel.getValue(i));
            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
        return regressionModel;
    } else {
        throw new IllegalArgumentException("Expected a categorical label with at least 2 categories, got " + numCategories);
    }
}
Also used: Matrix (org.apache.spark.ml.linalg.Matrix), CategoricalLabel (org.jpmml.converter.CategoricalLabel), LogisticRegressionModel (org.apache.spark.ml.classification.LogisticRegressionModel), ArrayList (java.util.ArrayList), List (java.util.List), Vector (org.apache.spark.ml.linalg.Vector), Feature (org.jpmml.converter.Feature), RegressionTable (org.dmg.pmml.regression.RegressionTable), RegressionModel (org.dmg.pmml.regression.RegressionModel)

Aggregations

ArrayList (java.util.ArrayList): 1, List (java.util.List): 1, LogisticRegressionModel (org.apache.spark.ml.classification.LogisticRegressionModel): 1, Matrix (org.apache.spark.ml.linalg.Matrix): 1, Vector (org.apache.spark.ml.linalg.Vector): 1, RegressionModel (org.dmg.pmml.regression.RegressionModel): 1, RegressionTable (org.dmg.pmml.regression.RegressionTable): 1, CategoricalLabel (org.jpmml.converter.CategoricalLabel): 1, Feature (org.jpmml.converter.Feature): 1