Usage of `org.dmg.pmml.general_regression.GeneralRegressionModel` in project jpmml-r (by jpmml).
Class `GLMConverter`, method `encodeModel`:
/**
 * Encodes an R {@code glm} object as a generalized linear PMML
 * {@code GeneralRegressionModel}.
 *
 * <p>The distribution, link function and link parameter are derived from the
 * {@code family$family} and {@code family$link} string elements. When the family
 * maps to {@code MiningFunction.CLASSIFICATION}, the label must be a
 * {@code CategoricalLabel} with exactly two categories. Its second category is passed
 * to the regression table encoder as the target category, and a probability output
 * is attached.
 *
 * @param schema label and features of the model
 * @return the encoded model
 * @throws IllegalArgumentException if the number of coefficients does not equal the
 *         number of features (plus one when an intercept is present), or if a
 *         classification label does not have exactly two categories
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector glm = getObject();

    RDoubleVector coefficients = (RDoubleVector) glm.getValue("coefficients");
    RGenericVector family = (RGenericVector) glm.getValue("family");

    // May be null when there is no intercept term; the size check below accounts for that
    Double intercept = coefficients.getValue(getInterceptName(), true);

    RStringVector familyFamily = (RStringVector) family.getValue("family");
    RStringVector familyLink = (RStringVector) family.getValue("link");

    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();

    int expectedSize = features.size() + (intercept != null ? 1 : 0);
    if (coefficients.size() != expectedSize) {
        throw new IllegalArgumentException("Expected " + expectedSize + " coefficient(s) for " + features.size() + " feature(s), got " + coefficients.size());
    }

    List<Double> featureCoefficients = getFeatureCoefficients(features, coefficients);

    MiningFunction miningFunction = getMiningFunction(familyFamily.asScalar());

    // Both stay null unless the family maps to classification
    CategoricalLabel categoricalLabel = null;
    String targetCategory = null;

    if (miningFunction == MiningFunction.CLASSIFICATION) {
        categoricalLabel = (CategoricalLabel) label;
        if (categoricalLabel.size() != 2) {
            throw new IllegalArgumentException("Expected a binary label, got " + categoricalLabel.size() + " categories");
        }
        targetCategory = categoricalLabel.getValue(1);
    }

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(label), null, null, null)
        .setDistribution(parseFamily(familyFamily.asScalar()))
        .setLinkFunction(parseLinkFunction(familyLink.asScalar()))
        .setLinkParameter(parseLinkParameter(familyLink.asScalar()));

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, targetCategory);

    // Classification models additionally expose per-category probabilities
    if (categoricalLabel != null) {
        generalRegressionModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    return generalRegressionModel;
}
Usage of `org.dmg.pmml.general_regression.GeneralRegressionModel` in project jpmml-r (by jpmml).
Class `EarthConverter`, method `encodeModel`:
/**
 * Encodes an R {@code earth} object as an identity-link regression
 * {@code GeneralRegressionModel}.
 *
 * <p>The first element of {@code coefficients} is the intercept. The remaining
 * elements pair up, in order, with the schema features.
 *
 * @param schema label and features of the model
 * @return the encoded model
 * @throws IllegalArgumentException if the number of coefficients is not the number
 *         of features plus one
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
    RGenericVector earth = getObject();

    RDoubleVector coefficients = (RDoubleVector) earth.getValue("coefficients");

    List<? extends Feature> features = schema.getFeatures();

    // Validate before reading element 0, so a malformed (e.g. empty) vector is
    // reported as an IllegalArgumentException rather than an indexing failure
    int expectedSize = features.size() + 1;
    if (coefficients.size() != expectedSize) {
        throw new IllegalArgumentException("Expected " + expectedSize + " coefficient(s) (intercept plus " + features.size() + " feature(s)), got " + coefficients.size());
    }

    Double intercept = coefficients.getValue(0);
    List<Double> featureCoefficients = coefficients.getValues().subList(1, expectedSize);

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
        .setLinkFunction(GeneralRegressionModel.LinkFunction.IDENTITY);

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, null);

    return generalRegressionModel;
}
Usage of `org.dmg.pmml.general_regression.GeneralRegressionModel` in project jpmml-r (by jpmml).
Class `LRMConverter`, method `encodeModel`:
/**
 * Encodes an R {@code lrm} object as a binary classification
 * {@code GeneralRegressionModel} with a logit link and a probability output.
 *
 * <p>The label's second category is passed to the regression table encoder as the
 * target category.
 *
 * @param schema label and features of the model; the label must be a
 *        {@code CategoricalLabel} with exactly two categories
 * @return the encoded model
 * @throws IllegalArgumentException if the label does not have exactly two categories,
 *         or if the number of coefficients does not equal the number of features
 *         (plus one when an intercept is present)
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector lrm = getObject();

    RDoubleVector coefficients = (RDoubleVector) lrm.getValue("coefficients");

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    if (categoricalLabel.size() != 2) {
        throw new IllegalArgumentException("Expected a binary label, got " + categoricalLabel.size() + " categories");
    }

    String targetCategory = categoricalLabel.getValue(1);

    // May be null when there is no intercept term; the size check below accounts for that
    Double intercept = coefficients.getValue(getInterceptName(), true);

    List<? extends Feature> features = schema.getFeatures();

    int expectedSize = features.size() + (intercept != null ? 1 : 0);
    if (coefficients.size() != expectedSize) {
        throw new IllegalArgumentException("Expected " + expectedSize + " coefficient(s) for " + features.size() + " feature(s), got " + coefficients.size());
    }

    List<Double> featureCoefficients = getFeatureCoefficients(features, coefficients);

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), null, null, null)
        .setLinkFunction(GeneralRegressionModel.LinkFunction.LOGIT)
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, targetCategory);

    return generalRegressionModel;
}
Usage of `org.dmg.pmml.general_regression.GeneralRegressionModel` in project jpmml-r (by jpmml).
Class `MVRConverter`, method `encodeModel`:
/**
 * Encodes an R {@code mvr} object as an identity-link regression
 * {@code GeneralRegressionModel}.
 *
 * <p>The {@code coefficients} element is a three-dimensional array with dimnames for
 * predictors (rows), responses (columns) and components. The method extracts the
 * coefficients of the first response at {@code ncomp} components. It then
 * reconstructs the intercept from the centering means as
 * {@code Ymeans[0] - sum_j(coef[j] * Xmeans[j])}.
 *
 * @param schema label and features of the model
 * @return the encoded model
 * @throws IllegalArgumentException if the number of features does not match the number
 *         of coefficient rows, or if {@code ncomp} is outside {@code [1, components]}
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
    RGenericVector mvr = getObject();

    RDoubleVector coefficients = (RDoubleVector) mvr.getValue("coefficients");
    RDoubleVector xMeans = (RDoubleVector) mvr.getValue("Xmeans");
    RDoubleVector yMeans = (RDoubleVector) mvr.getValue("Ymeans");
    RNumberVector<?> ncomp = (RNumberVector<?>) mvr.getValue("ncomp");

    RStringVector rowNames = coefficients.dimnames(0);
    RStringVector columnNames = coefficients.dimnames(1);
    RStringVector compNames = coefficients.dimnames(2);

    int rows = rowNames.size();
    int columns = columnNames.size();
    int components = compNames.size();

    List<? extends Feature> features = schema.getFeatures();
    if (features.size() != rows) {
        throw new IllegalArgumentException("Expected " + rows + " feature(s), got " + features.size());
    }

    int component = ValueUtil.asInt(ncomp.asScalar());
    if (component < 1 || component > components) {
        throw new IllegalArgumentException("Expected ncomp in [1, " + components + "], got " + component);
    }

    // R stores arrays in column-major order, so the array can be viewed as a
    // rows x (columns * components) matrix. In that view, response j (0-based) of
    // component k (0-based) is column (j + columns * k). Select the first response
    // (j = 0) at component ncomp (k = ncomp - 1). An index of (ncomp - 1) alone
    // would only be correct for single-response models.
    List<Double> featureCoefficients = FortranMatrixUtil.getColumn(coefficients.getValues(), rows, (columns * components), (component - 1) * columns);

    // Fold the predictor/response means back into the intercept
    double intercept = yMeans.getValue(0);
    for (int j = 0; j < rows; j++) {
        intercept -= (featureCoefficients.get(j) * xMeans.getValue(j));
    }

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
        .setLinkFunction(GeneralRegressionModel.LinkFunction.IDENTITY);

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, null);

    return generalRegressionModel;
}
Usage of `org.dmg.pmml.general_regression.GeneralRegressionModel` in project jpmml-sparkml (by jpmml).
Class `GeneralizedLinearRegressionModelConverter`, method `encodeModel`:
/**
 * Encodes a Spark ML {@code GeneralizedLinearRegressionModel} as a generalized linear
 * PMML {@code GeneralRegressionModel}.
 *
 * <p>The distribution, link function and link parameter are derived from the model's
 * family and link. For {@code MiningFunction.CLASSIFICATION}, the label must be a
 * {@code CategoricalLabel} with exactly two categories. Its second category is passed
 * to the regression table encoder as the target category.
 *
 * @param schema label and features of the model
 * @return the encoded model
 * @throws IllegalArgumentException if a classification label does not have exactly
 *         two categories
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
    GeneralizedLinearRegressionModel model = getTransformer();

    MiningFunction miningFunction = getMiningFunction();

    // Stays null unless this is a classification model
    String targetCategory = null;

    if (miningFunction == MiningFunction.CLASSIFICATION) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        if (categoricalLabel.size() != 2) {
            throw new IllegalArgumentException("Expected a binary label, got " + categoricalLabel.size() + " categories");
        }
        targetCategory = categoricalLabel.getValue(1);
    }

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
        .setDistribution(parseFamily(model.getFamily()))
        .setLinkFunction(parseLinkFunction(model.getLink()))
        .setLinkParameter(parseLinkParameter(model.getLink()));

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, schema.getFeatures(), model.intercept(), VectorUtil.toList(model.coefficients()), targetCategory);

    return generalRegressionModel;
}
Aggregations