Usage of org.dmg.pmml.MiningFunction in the jpmml-sparkml project (by jpmml):
the encodeModel method of class GeneralizedLinearRegressionModelConverter.
/**
 * Encodes the converter's Spark {@code GeneralizedLinearRegressionModel} as a PMML
 * {@code GeneralRegressionModel} of type {@code GENERALIZED_LINEAR}.
 *
 * <p>The model's family, link function and link parameter are translated via
 * {@code parseFamily}, {@code parseLinkFunction} and {@code parseLinkParameter}. The intercept
 * and coefficients are written through {@code GeneralRegressionModelUtil.encodeRegressionTable}.
 *
 * <p>For {@code CLASSIFICATION}, the schema label is cast to {@code CategoricalLabel}
 * (NOTE(review): assumes the label is categorical; a non-categorical label surfaces as a
 * {@code ClassCastException}). The label must have exactly two values. The second value
 * (index 1) is passed to {@code encodeRegressionTable} as the target category. For any other
 * mining function, the target category is {@code null}.
 *
 * @param schema the label and feature schema for this model
 * @return the encoded PMML general regression model
 * @throws IllegalArgumentException if the mining function is {@code CLASSIFICATION} and the
 *     label does not have exactly two categories
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
    GeneralizedLinearRegressionModel model = getTransformer();

    String targetCategory = null;

    MiningFunction miningFunction = getMiningFunction();
    switch (miningFunction) {
        case CLASSIFICATION:
            CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

            // Only binary classification is supported by this encoding.
            int categoryCount = categoricalLabel.size();
            if (categoryCount != 2) {
                throw new IllegalArgumentException(
                        "Expected a binary label with 2 categories, got " + categoryCount);
            }

            targetCategory = categoricalLabel.getValue(1);
            break;
        default:
            break;
    }

    GeneralRegressionModel generalRegressionModel =
            new GeneralRegressionModel(
                            GeneralRegressionModel.ModelType.GENERALIZED_LINEAR,
                            miningFunction,
                            ModelUtil.createMiningSchema(schema.getLabel()),
                            null,
                            null,
                            null)
                    .setDistribution(parseFamily(model.getFamily()))
                    .setLinkFunction(parseLinkFunction(model.getLink()))
                    .setLinkParameter(parseLinkParameter(model.getLink()));

    GeneralRegressionModelUtil.encodeRegressionTable(
            generalRegressionModel,
            schema.getFeatures(),
            model.intercept(),
            VectorUtil.toList(model.coefficients()),
            targetCategory);

    return generalRegressionModel;
}
Aggregations