Usage of org.apache.spark.ml.classification.LogisticRegressionModel in the jpmml-sparkml project (by jpmml).
Class LogisticRegressionModelConverter, method encodeModel.
/**
 * Encodes the fitted {@code LogisticRegressionModel} returned by {@code getTransformer()}
 * as a PMML {@code RegressionModel}.
 *
 * <p>For a two-category label, a binary logistic classification is built from the model's
 * coefficient vector and scalar intercept with LOGIT normalization, and its output is cleared.
 * For a label with more than two categories, one regression table is built per category:
 * row {@code i} of the coefficient matrix and element {@code i} of the intercept vector are
 * paired with the {@code i}-th label value, and the tables are combined with SOFTMAX
 * normalization.
 *
 * @param schema the schema to encode against; its label is cast to {@code CategoricalLabel}
 * @return the encoded PMML regression model
 * @throws IllegalArgumentException if the label has fewer than two categories
 */
@Override
public RegressionModel encodeModel(Schema schema) {
    LogisticRegressionModel model = getTransformer();

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    int numberOfCategories = categoricalLabel.size();

    if (numberOfCategories == 2) {
        // Binary case: a single coefficient vector and a scalar intercept.
        // NOTE(review): the output is explicitly cleared here; presumably output fields are
        // produced elsewhere by the converter - confirm before changing.
        RegressionModel regressionModel = RegressionModelUtil.createBinaryLogisticClassification(
                schema.getFeatures(),
                VectorUtil.toList(model.coefficients()),
                model.intercept(),
                RegressionModel.NormalizationMethod.LOGIT,
                true,
                schema)
            .setOutput(null);

        return regressionModel;
    } else if (numberOfCategories > 2) {
        // Multinomial case: one regression table per category, indexed consistently across the
        // coefficient matrix rows, the intercept vector and the label values.
        Matrix coefficientMatrix = model.coefficientMatrix();
        Vector interceptVector = model.interceptVector();

        List<? extends Feature> features = schema.getFeatures();

        List<RegressionTable> regressionTables = new ArrayList<>(numberOfCategories);
        for (int i = 0; i < numberOfCategories; i++) {
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(
                    features, MatrixUtil.getRow(coefficientMatrix, i), interceptVector.apply(i))
                .setTargetCategory(categoricalLabel.getValue(i));

            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(
                MiningFunction.CLASSIFICATION,
                ModelUtil.createMiningSchema(categoricalLabel),
                regressionTables)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);

        return regressionModel;
    } else {
        // A classifier needs at least two target categories; report what was actually found.
        throw new IllegalArgumentException(
                "Expected a categorical label with 2 or more categories, got " + numberOfCategories);
    }
}
Aggregations