use of org.dmg.pmml.MiningFunction in project jpmml-r by jpmml.
the class GLMConverter method encodeModel.
/**
 * Encodes the wrapped {@code glm} object as a PMML {@link GeneralRegressionModel}
 * of type {@code GENERALIZED_LINEAR}.
 *
 * <p>The distribution, link function and link parameter are derived from the
 * {@code family} element of the object; the regression table is populated from its
 * {@code coefficients} element. For classification models the second category of
 * the (binary) label is used as the target category and a probability output is attached.
 *
 * @param schema the schema supplying the label and the features
 * @return the encoded general regression model
 * @throws IllegalArgumentException if the number of coefficients does not match the
 *         number of features (plus one when an intercept is present), or if a
 *         classification label does not have exactly two categories
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector glm = getObject();
    RDoubleVector coefficients = (RDoubleVector) glm.getValue("coefficients");
    RGenericVector family = (RGenericVector) glm.getValue("family");
    // May be null (see the null check below); presumably the "true" flag makes the lookup optional - TODO confirm.
    Double intercept = coefficients.getValue(getInterceptName(), true);
    RStringVector familyFamily = (RStringVector) family.getValue("family");
    RStringVector familyLink = (RStringVector) family.getValue("link");
    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();

    // Every coefficient must be accounted for by either a feature or the intercept.
    int expectedCoefficients = features.size() + (intercept != null ? 1 : 0);
    if (coefficients.size() != expectedCoefficients) {
        throw new IllegalArgumentException("Expected " + expectedCoefficients + " coefficients, got " + coefficients.size() + " coefficients");
    }

    List<Double> featureCoefficients = getFeatureCoefficients(features, coefficients);
    MiningFunction miningFunction = getMiningFunction(familyFamily.asScalar());

    // Stays null for every mining function other than classification.
    String targetCategory = null;
    switch (miningFunction) {
        case CLASSIFICATION:
            {
                CategoricalLabel categoricalLabel = (CategoricalLabel) label;
                if (categoricalLabel.size() != 2) {
                    throw new IllegalArgumentException("Expected 2 target categories, got " + categoricalLabel.size() + " target categories");
                }
                // The category at index 1 is the one the regression table is encoded for.
                targetCategory = categoricalLabel.getValue(1);
            }
            break;
        default:
            break;
    }

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(label), null, null, null)
        .setDistribution(parseFamily(familyFamily.asScalar()))
        .setLinkFunction(parseLinkFunction(familyLink.asScalar()))
        .setLinkParameter(parseLinkParameter(familyLink.asScalar()));

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, targetCategory);

    // Only classification models expose per-category probabilities.
    switch (miningFunction) {
        case CLASSIFICATION:
            generalRegressionModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, (CategoricalLabel) label));
            break;
        default:
            break;
    }

    return generalRegressionModel;
}
use of org.dmg.pmml.MiningFunction in project jpmml-sparkml by jpmml.
the class TreeModelCompactor method visit.
/**
 * Rewrites the strategy attributes of a tree model before delegating to the superclass visitor.
 *
 * <p>The incoming model must use {@code MissingValueStrategy.NONE},
 * {@code NoTrueChildStrategy.RETURN_NULL_PREDICTION} and
 * {@code SplitCharacteristic.BINARY_SPLIT}. It is then switched to
 * {@code MissingValueStrategy.NULL_PREDICTION} and {@code SplitCharacteristic.MULTI_SPLIT};
 * regression models additionally get {@code NoTrueChildStrategy.RETURN_LAST_PREDICTION}.
 *
 * @param treeModel the tree model to update in place
 * @return the result of {@code super.visit(treeModel)}
 * @throws IllegalArgumentException if any of the three strategy attributes has an
 *         unexpected value, or if the mining function is neither regression nor classification
 */
@Override
public VisitorAction visit(TreeModel treeModel) {
    TreeModel.MissingValueStrategy missingValueStrategy = treeModel.getMissingValueStrategy();
    TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
    TreeModel.SplitCharacteristic splitCharacteristic = treeModel.getSplitCharacteristic();

    // The checks run in the same order as the original short-circuit condition; each reports its own failure.
    if (!(TreeModel.MissingValueStrategy.NONE).equals(missingValueStrategy)) {
        throw new IllegalArgumentException("Expected missing value strategy " + TreeModel.MissingValueStrategy.NONE + ", got " + missingValueStrategy);
    }
    if (!(TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(noTrueChildStrategy)) {
        throw new IllegalArgumentException("Expected no true child strategy " + TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION + ", got " + noTrueChildStrategy);
    }
    if (!(TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(splitCharacteristic)) {
        throw new IllegalArgumentException("Expected split characteristic " + TreeModel.SplitCharacteristic.BINARY_SPLIT + ", got " + splitCharacteristic);
    }

    // Note: these attributes are updated before the mining function is validated below.
    treeModel.setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION).setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);

    MiningFunction miningFunction = treeModel.getMiningFunction();
    switch (miningFunction) {
        case REGRESSION:
            treeModel.setNoTrueChildStrategy(TreeModel.NoTrueChildStrategy.RETURN_LAST_PREDICTION);
            break;
        case CLASSIFICATION:
            // The no-true-child strategy is left unchanged for classification models.
            break;
        default:
            throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
    }

    return super.visit(treeModel);
}
use of org.dmg.pmml.MiningFunction in project jpmml-sparkml by jpmml.
the class ModelConverter method encodeSchema.
/**
 * Builds the {@link Schema} (label plus features) for the wrapped transformer.
 *
 * <p>A label is only derived when the transformer implements {@code HasLabelCol};
 * otherwise the label stays {@code null}. Classification labels are categorical: an
 * existing categorical feature is used as-is, whereas a continuous feature is converted
 * to a categorical field whose categories are the strings {@code "0"} to
 * {@code numClasses - 1}. Regression labels are continuous fields with data type
 * {@code DOUBLE}.
 *
 * @param encoder the encoder holding the fields and features registered so far
 * @return a new schema made of the derived label and the features of the features column
 * @throws IllegalArgumentException if the label feature or the mining function is not
 *         supported, or if the number of classes or features reported by the
 *         transformer does not match the schema
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T transformer = getTransformer();

    Label label = null;

    if (transformer instanceof HasLabelCol) {
        String labelCol = ((HasLabelCol) transformer).getLabelCol();
        Feature labelFeature = encoder.getOnlyFeature(labelCol);
        MiningFunction miningFunction = getMiningFunction();

        switch (miningFunction) {
            case REGRESSION: {
                Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                continuousField.setDataType(DataType.DOUBLE);
                label = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                break;
            }
            case CLASSIFICATION: {
                if (labelFeature instanceof CategoricalFeature) {
                    CategoricalFeature categorical = (CategoricalFeature) labelFeature;
                    label = new CategoricalLabel(encoder.getDataField(categorical.getName()));
                } else if (labelFeature instanceof ContinuousFeature) {
                    ContinuousFeature continuous = (ContinuousFeature) labelFeature;

                    // Two classes unless the transformer reports its own class count.
                    int numClasses = (transformer instanceof ClassificationModel)
                        ? ((ClassificationModel<?, ?>) transformer).numClasses()
                        : 2;

                    // Categories are the class indices rendered as strings.
                    List<String> categories = new ArrayList<>();
                    for (int index = 0; index < numClasses; index++) {
                        categories.add(Integer.toString(index));
                    }

                    Field<?> categoricalField = encoder.toCategorical(continuous.getName(), categories);
                    // Replace the registered label column feature with its categorical counterpart.
                    encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, categoricalField, categories));
                    label = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), categories);
                } else {
                    throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                }
                break;
            }
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    // Cross-check the label against the class count reported by the transformer.
    if (transformer instanceof ClassificationModel) {
        ClassificationModel<?, ?> classifier = (ClassificationModel<?, ?>) transformer;
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        int numClasses = classifier.numClasses();
        if (categoricalLabel.size() != numClasses) {
            throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + categoricalLabel.size() + " target categories");
        }
    }

    List<Feature> features = encoder.getFeatures(transformer.getFeaturesCol());

    // A feature count of -1 skips the check.
    if (transformer instanceof PredictionModel) {
        int numFeatures = ((PredictionModel<?, ?>) transformer).numFeatures();
        if (numFeatures != -1) {
            if (features.size() != numFeatures) {
                throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
            }
        }
    }

    return new Schema(label, features);
}
use of org.dmg.pmml.MiningFunction in project jpmml-r by jpmml.
the class GLMConverter method encodeSchema.
/**
 * Runs the inherited schema encoding and then, for classification models, converts
 * the label to a categorical field.
 *
 * <p>The categories are the factor levels of the entry of the {@code model} element
 * that is stored under the label's name.
 *
 * @param encoder the encoder whose label is inspected and, for classification, replaced
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector glm = getObject();
    RGenericVector glmFamily = (RGenericVector) glm.getValue("family");
    RGenericVector glmModel = (RGenericVector) glm.getValue("model");
    RStringVector familyName = (RStringVector) glmFamily.getValue("family");

    super.encodeSchema(encoder);

    switch (getMiningFunction(familyName.asScalar())) {
        case CLASSIFICATION: {
            Label label = encoder.getLabel();
            String labelName = (label.getName()).getValue();
            // Cast assumes the label variable is stored as an integer vector (presumably an R factor) - TODO confirm.
            RIntegerVector labelValues = (RIntegerVector) glmModel.getValue(labelName);
            encoder.setLabel((DataField) encoder.toCategorical(label.getName(), RExpUtil.getFactorLevels(labelValues)));
            break;
        }
        default:
            // All other mining functions keep the label set up by the superclass.
            break;
    }
}
use of org.dmg.pmml.MiningFunction in project jpmml-sparkml by jpmml.
the class GeneralizedLinearRegressionModelConverter method registerOutputFields.
/**
 * Registers the inherited output fields and, for classification models, appends
 * one probability field per label category.
 *
 * @param label the model label; cast to {@code CategoricalLabel} for classification models
 * @param encoder the encoder passed through to the superclass
 * @return the inherited output fields, or for classification a new list holding them
 *         followed by the probability fields
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    List<OutputField> outputFields = super.registerOutputFields(label, encoder);

    switch (getMiningFunction()) {
        case CLASSIFICATION: {
            CategoricalLabel categoricalLabel = (CategoricalLabel) label;
            // Copy first so that the list returned by the superclass is never modified.
            List<OutputField> extendedFields = new ArrayList<>(outputFields);
            extendedFields.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()));
            outputFields = extendedFields;
            break;
        }
        default:
            break;
    }

    return outputFields;
}
Aggregations