Search in sources:

Example 1 with GeneralRegressionModel

use of org.dmg.pmml.general_regression.GeneralRegressionModel in project jpmml-r by jpmml.

Method encodeModel of the class GLMConverter.

/**
 * Encodes a fitted R {@code glm} object as a PMML {@code GeneralRegressionModel}.
 *
 * <p>The distribution, link function and link parameter are parsed from the
 * {@code family$family} and {@code family$link} string components. When the family
 * maps to {@link MiningFunction#CLASSIFICATION}, the label must have exactly two
 * categories; the second one becomes the target category of the regression table
 * and a probability output is attached to the model.
 *
 * @param schema schema whose features pair up with the non-intercept coefficients
 * @return the encoded model
 * @throws IllegalArgumentException if the number of coefficients differs from the number
 *     of features (plus one when an intercept is present), or if a classification label
 *     does not have exactly two categories
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector glm = getObject();
    RDoubleVector coefficients = (RDoubleVector) glm.getValue("coefficients");
    RGenericVector family = (RGenericVector) glm.getValue("family");

    // May be null: the size check below explicitly allows a model without an intercept.
    Double intercept = coefficients.getValue(getInterceptName(), true);

    RStringVector familyFamily = (RStringVector) family.getValue("family");
    RStringVector familyLink = (RStringVector) family.getValue("link");

    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();

    int expectedSize = features.size() + (intercept != null ? 1 : 0);
    if (coefficients.size() != expectedSize) {
        throw new IllegalArgumentException("Expected " + expectedSize + " coefficient(s), got " + coefficients.size());
    }

    List<Double> featureCoefficients = getFeatureCoefficients(features, coefficients);

    MiningFunction miningFunction = getMiningFunction(familyFamily.asScalar());

    // Both stay null unless this is a classification model; a non-null categoricalLabel
    // later signals that a probability output must be attached.
    CategoricalLabel categoricalLabel = null;
    String targetCategory = null;

    switch (miningFunction) {
        case CLASSIFICATION:
            categoricalLabel = (CategoricalLabel) label;
            if (categoricalLabel.size() != 2) {
                throw new IllegalArgumentException("Expected a binary label, got " + categoricalLabel.size() + " categories");
            }
            // The regression table is encoded against the second category.
            targetCategory = categoricalLabel.getValue(1);
            break;
        default:
            break;
    }

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(label), null, null, null)
        .setDistribution(parseFamily(familyFamily.asScalar()))
        .setLinkFunction(parseLinkFunction(familyLink.asScalar()))
        .setLinkParameter(parseLinkParameter(familyLink.asScalar()));

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, targetCategory);

    if (categoricalLabel != null) {
        generalRegressionModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    return generalRegressionModel;
}
Also used: CategoricalLabel (org.jpmml.converter.CategoricalLabel), GeneralRegressionModel (org.dmg.pmml.general_regression.GeneralRegressionModel), Label (org.jpmml.converter.Label), MiningFunction (org.dmg.pmml.MiningFunction)

Example 2 with GeneralRegressionModel

use of org.dmg.pmml.general_regression.GeneralRegressionModel in project jpmml-r by jpmml.

Method encodeModel of the class EarthConverter.

/**
 * Encodes a fitted R {@code earth} object as a regression {@code GeneralRegressionModel}
 * with an identity link function.
 *
 * <p>The first entry of {@code coefficients} is used as the intercept; the remaining
 * entries are paired positionally with the schema features.
 *
 * @param schema schema providing the label and the features
 * @return the encoded model
 * @throws IllegalArgumentException if there is not exactly one coefficient per feature
 *     plus the leading intercept
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
    RGenericVector fit = getObject();
    RDoubleVector coefficients = (RDoubleVector) fit.getValue("coefficients");

    Double intercept = coefficients.getValue(0);

    List<? extends Feature> features = schema.getFeatures();
    int featureCount = features.size();

    boolean sizesMatch = coefficients.size() == featureCount + 1;
    if (!sizesMatch) {
        throw new IllegalArgumentException();
    }

    // Index 0 is the intercept, so per-feature coefficients occupy [1, featureCount].
    List<Double> featureCoefficients = coefficients.getValues().subList(1, 1 + featureCount);

    GeneralRegressionModel pmmlModel = new GeneralRegressionModel(
        GeneralRegressionModel.ModelType.GENERALIZED_LINEAR,
        MiningFunction.REGRESSION,
        ModelUtil.createMiningSchema(schema.getLabel()),
        null, null, null);
    pmmlModel.setLinkFunction(GeneralRegressionModel.LinkFunction.IDENTITY);

    GeneralRegressionModelUtil.encodeRegressionTable(pmmlModel, features, intercept, featureCoefficients, null);
    return pmmlModel;
}
Also used : GeneralRegressionModel(org.dmg.pmml.general_regression.GeneralRegressionModel)

Example 3 with GeneralRegressionModel

use of org.dmg.pmml.general_regression.GeneralRegressionModel in project jpmml-r by jpmml.

Method encodeModel of the class LRMConverter.

/**
 * Encodes a fitted R {@code lrm} object as a binary classification
 * {@code GeneralRegressionModel} with a logit link and a probability output.
 *
 * @param schema schema whose label must be categorical with exactly two categories
 * @return the encoded model
 * @throws IllegalArgumentException if the label is not binary, or if the number of
 *     coefficients differs from the number of features (plus one when an intercept is present)
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector fit = getObject();
    RDoubleVector coefficients = (RDoubleVector) fit.getValue("coefficients");

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    boolean binary = categoricalLabel.size() == 2;
    if (!binary) {
        throw new IllegalArgumentException();
    }
    // The second category is passed as the regression table's target category.
    String targetCategory = categoricalLabel.getValue(1);

    // Null when the model has no intercept term; the size check accounts for that.
    Double intercept = coefficients.getValue(getInterceptName(), true);
    List<? extends Feature> features = schema.getFeatures();

    int interceptCount = (intercept == null) ? 0 : 1;
    if (features.size() + interceptCount != coefficients.size()) {
        throw new IllegalArgumentException();
    }

    List<Double> featureCoefficients = getFeatureCoefficients(features, coefficients);

    GeneralRegressionModel pmmlModel = new GeneralRegressionModel(
        GeneralRegressionModel.ModelType.GENERALIZED_LINEAR,
        MiningFunction.CLASSIFICATION,
        ModelUtil.createMiningSchema(categoricalLabel),
        null, null, null);
    pmmlModel.setLinkFunction(GeneralRegressionModel.LinkFunction.LOGIT);
    pmmlModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

    GeneralRegressionModelUtil.encodeRegressionTable(pmmlModel, features, intercept, featureCoefficients, targetCategory);
    return pmmlModel;
}
Also used : CategoricalLabel(org.jpmml.converter.CategoricalLabel) GeneralRegressionModel(org.dmg.pmml.general_regression.GeneralRegressionModel)

Example 4 with GeneralRegressionModel

use of org.dmg.pmml.general_regression.GeneralRegressionModel in project jpmml-r by jpmml.

Method encodeModel of the class MVRConverter.

/**
 * Encodes a fitted R {@code mvr} object as a regression {@code GeneralRegressionModel}
 * with an identity link function.
 *
 * <p>{@code coefficients} is a three-dimensional array whose dimensions are named by
 * {@code dimnames(0..2)} (rows x columns x components). The coefficients for the first
 * column (presumably the first response — TODO confirm) at {@code ncomp} components are
 * extracted. The {@code Xmeans}/{@code Ymeans} vectors are folded into the intercept.
 *
 * @param schema schema providing the label and one feature per coefficient row
 * @return the encoded model
 * @throws IllegalArgumentException if the number of features differs from the number of
 *     coefficient rows
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
    RGenericVector mvr = getObject();
    RDoubleVector coefficients = (RDoubleVector) mvr.getValue("coefficients");
    RDoubleVector xMeans = (RDoubleVector) mvr.getValue("Xmeans");
    RDoubleVector yMeans = (RDoubleVector) mvr.getValue("Ymeans");
    RNumberVector<?> ncomp = (RNumberVector<?>) mvr.getValue("ncomp");

    RStringVector rowNames = coefficients.dimnames(0);
    RStringVector columnNames = coefficients.dimnames(1);
    RStringVector compNames = coefficients.dimnames(2);

    int rows = rowNames.size();
    int columns = columnNames.size();
    int components = compNames.size();

    List<? extends Feature> features = schema.getFeatures();
    if (features.size() != rows) {
        throw new IllegalArgumentException("Expected " + rows + " feature(s), got " + features.size());
    }

    // In Fortran (column-major) order, element (i, j, k) of a rows x columns x components
    // array sits at i + rows * (j + columns * k). Viewed as a rows x (columns * components)
    // matrix, slice (j, k) is therefore matrix column j + k * columns. The previous index
    // (j + k) was only correct when columns == 1.
    int columnIndex = 0;
    int componentIndex = ValueUtil.asInt(ncomp.asScalar()) - 1;
    List<Double> featureCoefficients = FortranMatrixUtil.getColumn(coefficients.getValues(), rows, (columns * components), columnIndex + componentIndex * columns);

    // Re-express the model for raw inputs:
    //   intercept = mean(y) - sum_j b_j * mean(x_j)
    // The formula implies the coefficients apply to mean-centered predictors.
    double intercept = yMeans.getValue(columnIndex);
    for (int j = 0; j < rows; j++) {
        intercept -= featureCoefficients.get(j) * xMeans.getValue(j);
    }

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
        .setLinkFunction(GeneralRegressionModel.LinkFunction.IDENTITY);

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, null);

    return generalRegressionModel;
}
Also used : GeneralRegressionModel(org.dmg.pmml.general_regression.GeneralRegressionModel)

Example 5 with GeneralRegressionModel

use of org.dmg.pmml.general_regression.GeneralRegressionModel in project jpmml-sparkml by jpmml.

Method encodeModel of the class GeneralizedLinearRegressionModelConverter.

/**
 * Encodes a Spark ML {@code GeneralizedLinearRegressionModel} as a PMML
 * {@code GeneralRegressionModel}.
 *
 * <p>The distribution, link function and link parameter are parsed from the Spark model's
 * family and link settings. For classification, the label must have exactly two
 * categories and the second one is used as the regression table's target category.
 *
 * @param schema schema providing the label and the features
 * @return the encoded model
 * @throws IllegalArgumentException if a classification label does not have exactly two categories
 */
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
    GeneralizedLinearRegressionModel model = getTransformer();
    MiningFunction miningFunction = getMiningFunction();

    String targetCategory;
    switch (miningFunction) {
        case CLASSIFICATION: {
            // Braced so the cast label stays scoped to this case.
            CategoricalLabel binaryLabel = (CategoricalLabel) schema.getLabel();
            if (binaryLabel.size() != 2) {
                throw new IllegalArgumentException();
            }
            targetCategory = binaryLabel.getValue(1);
            break;
        }
        default:
            targetCategory = null;
            break;
    }

    GeneralRegressionModel pmmlModel = new GeneralRegressionModel(
        GeneralRegressionModel.ModelType.GENERALIZED_LINEAR,
        miningFunction,
        ModelUtil.createMiningSchema(schema.getLabel()),
        null, null, null);
    pmmlModel.setDistribution(parseFamily(model.getFamily()));
    pmmlModel.setLinkFunction(parseLinkFunction(model.getLink()));
    pmmlModel.setLinkParameter(parseLinkParameter(model.getLink()));

    GeneralRegressionModelUtil.encodeRegressionTable(
        pmmlModel,
        schema.getFeatures(),
        model.intercept(),
        VectorUtil.toList(model.coefficients()),
        targetCategory);

    return pmmlModel;
}
Also used : GeneralizedLinearRegressionModel(org.apache.spark.ml.regression.GeneralizedLinearRegressionModel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) GeneralRegressionModel(org.dmg.pmml.general_regression.GeneralRegressionModel) MiningFunction(org.dmg.pmml.MiningFunction)

Aggregations

GeneralRegressionModel (org.dmg.pmml.general_regression.GeneralRegressionModel): 5, CategoricalLabel (org.jpmml.converter.CategoricalLabel): 3, MiningFunction (org.dmg.pmml.MiningFunction): 2, GeneralizedLinearRegressionModel (org.apache.spark.ml.regression.GeneralizedLinearRegressionModel): 1, Label (org.jpmml.converter.Label): 1