Use of `org.apache.spark.ml.regression.GeneralizedLinearRegressionModel` in the jpmml-sparkml project (by jpmml).
Class `GeneralizedLinearRegressionModelConverter`, method `getMiningFunction`.
/**
 * Determines the PMML mining function for the wrapped Spark generalized linear model.
 *
 * <p>The {@code "binomial"} family yields a classification model; every other family yields a
 * regression model. The family name is lower-cased with {@link java.util.Locale#ROOT} before
 * matching, mirroring how Spark itself validates and resolves the {@code family} param, so that
 * mixed-case names such as {@code "Binomial"} are recognized.
 *
 * @return {@link MiningFunction#CLASSIFICATION} for the binomial family,
 *     {@link MiningFunction#REGRESSION} otherwise
 */
@Override
public MiningFunction getMiningFunction() {
	GeneralizedLinearRegressionModel model = getTransformer();

	// Spark accepts the family name case-insensitively; normalize before matching.
	String family = model.getFamily().toLowerCase(java.util.Locale.ROOT);

	switch(family) {
		case "binomial":
			return MiningFunction.CLASSIFICATION;
		default:
			return MiningFunction.REGRESSION;
	}
}
Use of `org.apache.spark.ml.regression.GeneralizedLinearRegressionModel` in the jpmml-sparkml project (by jpmml).
Class `GeneralizedLinearRegressionModelConverter`, method `encodeModel`.
/**
 * Encodes the wrapped Spark generalized linear model as a PMML {@code GeneralRegressionModel}
 * of type {@code GENERALIZED_LINEAR}.
 *
 * <p>When the mining function is classification, the schema label must be a
 * {@code CategoricalLabel} with exactly two values. Its second value is passed as the target
 * category when encoding the regression table. For regression, no target category is used.
 *
 * <p>The distribution is derived from the model's family name, lower-cased with
 * {@link java.util.Locale#ROOT} to match Spark's own case-insensitive handling of that param.
 * The link function and link parameter are derived from the model's link name.
 *
 * @param schema the schema whose label and features describe the model's inputs and output
 * @return the encoded PMML general regression model
 * @throws IllegalArgumentException if a classification label does not have exactly two values
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
	GeneralizedLinearRegressionModel model = getTransformer();

	String targetCategory = null;

	MiningFunction miningFunction = getMiningFunction();
	switch(miningFunction) {
		case CLASSIFICATION:
			CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
			if (categoricalLabel.size() != 2) {
				throw new IllegalArgumentException("Expected a binary categorical label, got " + categoricalLabel.size() + " categories");
			}
			targetCategory = categoricalLabel.getValue(1);
			break;
		default:
			break;
	}

	// Spark resolves the family name case-insensitively; pass the canonical lower-case form.
	String family = model.getFamily().toLowerCase(java.util.Locale.ROOT);
	String link = model.getLink();

	GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
		.setDistribution(parseFamily(family))
		.setLinkFunction(parseLinkFunction(link))
		.setLinkParameter(parseLinkParameter(link));

	GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, schema.getFeatures(), model.intercept(), VectorUtil.toList(model.coefficients()), targetCategory);

	return generalRegressionModel;
}
Aggregations