Use of org.jpmml.converter.CategoricalLabel in project jpmml-sparkml by jpmml.
Class ModelConverter, method encodeSchema.
/**
 * Assembles the {@link Schema} (target label plus input features) for the
 * transformer returned by {@code getTransformer()}.
 *
 * <p>When the transformer exposes a label column, the label is derived from
 * the single feature registered under that column and from the mining
 * function reported by {@code getMiningFunction()}. Otherwise the label stays
 * {@code null}.
 *
 * @param encoder the encoder that holds the features and fields registered so far
 * @return a schema built from the resolved label and the features of the features column
 * @throws IllegalArgumentException if the label feature type or the mining function
 *         is not supported, or if the class/feature counts reported by the model
 *         disagree with the encoder state
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T transformer = getTransformer();

    Label targetLabel = null;

    if (transformer instanceof HasLabelCol) {
        String labelCol = ((HasLabelCol) transformer).getLabelCol();
        Feature labelFeature = encoder.getOnlyFeature(labelCol);

        MiningFunction miningFunction = getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                if (labelFeature instanceof CategoricalFeature) {
                    // The label column is already categorical: reuse its data field as-is.
                    CategoricalFeature categoricalLabelFeature = (CategoricalFeature) labelFeature;
                    targetLabel = new CategoricalLabel(encoder.getDataField(categoricalLabelFeature.getName()));
                    break;
                }

                if (!(labelFeature instanceof ContinuousFeature)) {
                    throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                }

                ContinuousFeature continuousLabelFeature = (ContinuousFeature) labelFeature;

                // Defaults to two classes unless the model reports its own class count.
                int classCount = (transformer instanceof ClassificationModel)
                    ? ((ClassificationModel<?, ?>) transformer).numClasses()
                    : 2;

                // Category values are the class indices rendered as strings: "0", "1", ...
                List<String> categories = new ArrayList<>();
                for (int classIndex = 0; classIndex < classCount; classIndex++) {
                    categories.add(Integer.toString(classIndex));
                }

                Field<?> categoricalField = encoder.toCategorical(continuousLabelFeature.getName(), categories);

                // Re-register the label column so that it now maps to the categorical feature.
                encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, categoricalField, categories));

                targetLabel = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), categories);
                break;
            }
            case REGRESSION: {
                Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                continuousField.setDataType(DataType.DOUBLE);

                targetLabel = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                break;
            }
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    if (transformer instanceof ClassificationModel) {
        ClassificationModel<?, ?> classifier = (ClassificationModel<?, ?>) transformer;
        CategoricalLabel categoricalLabel = (CategoricalLabel) targetLabel;

        // The number of target categories must agree with what the model itself reports.
        int expectedClasses = classifier.numClasses();
        if (expectedClasses != categoricalLabel.size()) {
            throw new IllegalArgumentException("Expected " + expectedClasses + " target categories, got " + categoricalLabel.size() + " target categories");
        }
    }

    List<Feature> inputFeatures = encoder.getFeatures(transformer.getFeaturesCol());

    if (transformer instanceof PredictionModel) {
        int expectedFeatures = ((PredictionModel<?, ?>) transformer).numFeatures();

        // A reported feature count of -1 skips the check.
        if (expectedFeatures != -1) {
            if (inputFeatures.size() != expectedFeatures) {
                throw new IllegalArgumentException("Expected " + expectedFeatures + " features, got " + inputFeatures.size() + " features");
            }
        }
    }

    return new Schema(targetLabel, inputFeatures);
}
Use of org.jpmml.converter.CategoricalLabel in project jpmml-r by jpmml.
Class GBMConverter, method encodeMultinomialClassification.
/**
 * Encodes a multinomial classifier as one mining model per target category,
 * combined into a single classification model with softmax normalization.
 *
 * <p>The tree models are read through {@code CMatrixUtil.getColumn} as a
 * {@code rows x columns} matrix, where the column count is the number of
 * target categories and the row count is {@code treeModels.size() / columns}.
 *
 * @param treeModels the tree models of all categories
 * @param initF value passed through to {@code createMiningModel} for every category
 * @param schema schema whose label must be a {@link CategoricalLabel}
 * @return the combined classification model
 */
private MiningModel encodeMultinomialClassification(List<TreeModel> treeModels, Double initF, Schema schema) {
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    // Each per-category model predicts an anonymous continuous (double) value.
    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    List<Model> categoryModels = new ArrayList<>();

    int columns = categoricalLabel.size();
    int rows = treeModels.size() / columns;

    for (int column = 0; column < columns; column++) {
        MiningModel categoryModel = createMiningModel(CMatrixUtil.getColumn(treeModels, rows, columns, column), initF, segmentSchema);

        // Expose the raw value of this category as the output field "gbmValue(<category>)".
        FieldName outputName = FieldName.create("gbmValue(" + categoricalLabel.getValue(column) + ")");
        categoryModel = categoryModel.setOutput(ModelUtil.createPredictedOutput(outputName, OpType.CONTINUOUS, DataType.DOUBLE));

        categoryModels.add(categoryModel);
    }

    return MiningModelUtil.createClassification(categoryModels, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
}
Use of org.jpmml.converter.CategoricalLabel in project jpmml-r by jpmml.
Class LRMConverter, method encodeModel.
/**
 * Encodes the wrapped {@code lrm} object as a generalized linear
 * {@link GeneralRegressionModel} with a logit link function.
 *
 * <p>Only binary targets are supported. The regression table is encoded for
 * the second target category (index 1).
 *
 * @param schema schema whose label must be a {@link CategoricalLabel} with exactly two categories
 * @return the encoded general regression model
 * @throws IllegalArgumentException if the label does not have exactly two categories,
 *         or if the number of coefficients does not equal the number of features
 *         (plus one when an intercept is present)
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector lrm = getObject();
    RDoubleVector coefficients = (RDoubleVector) lrm.getValue("coefficients");
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    if (categoricalLabel.size() != 2) {
        throw new IllegalArgumentException("Expected 2 target categories, got " + categoricalLabel.size() + " target categories");
    }
    String targetCategory = categoricalLabel.getValue(1);
    // The intercept lookup may yield null; in that case no slot is reserved for it below.
    Double intercept = coefficients.getValue(getInterceptName(), true);
    List<? extends Feature> features = schema.getFeatures();
    int actualCoefficients = coefficients.size();
    int expectedCoefficients = features.size() + (intercept != null ? 1 : 0);
    if (actualCoefficients != expectedCoefficients) {
        throw new IllegalArgumentException("Expected " + expectedCoefficients + " coefficients, got " + actualCoefficients + " coefficients");
    }
    List<Double> featureCoefficients = getFeatureCoefficients(features, coefficients);
    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), null, null, null)
        .setLinkFunction(GeneralRegressionModel.LinkFunction.LOGIT)
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, intercept, featureCoefficients, targetCategory);
    return generalRegressionModel;
}
Use of org.jpmml.converter.CategoricalLabel in project jpmml-r by jpmml.
Class BinaryTreeConverter, method encodeClassificationScore.
/**
 * Populates a tree node with a classification score and one
 * {@link ScoreDistribution} per target category.
 *
 * <p>The node score is set to the category with the greatest probability;
 * on ties the category with the lowest index wins, because a later value
 * replaces the current maximum only when it is strictly greater.
 *
 * @param node the node to populate
 * @param probabilities one probability per target category, in label order
 * @param schema schema whose label must be a {@link CategoricalLabel}
 * @return the same {@code node} instance
 * @throws IllegalArgumentException if the number of probabilities differs from
 *         the number of target categories
 */
private static Node encodeClassificationScore(Node node, RDoubleVector probabilities, Schema schema) {
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    if (categoricalLabel.size() != probabilities.size()) {
        throw new IllegalArgumentException("Expected " + categoricalLabel.size() + " probabilities, got " + probabilities.size() + " probabilities");
    }
    Double maxProbability = null;
    for (int i = 0; i < categoricalLabel.size(); i++) {
        String value = categoricalLabel.getValue(i);
        Double probability = probabilities.getValue(i);
        // Track the running maximum; the first category always initializes it.
        if (maxProbability == null || maxProbability.compareTo(probability) < 0) {
            node.setScore(value);
            maxProbability = probability;
        }
        ScoreDistribution scoreDistribution = new ScoreDistribution(value, probability);
        node.addScoreDistributions(scoreDistribution);
    }
    return node;
}
Use of org.jpmml.converter.CategoricalLabel in project jpmml-r by jpmml.
Class RangerConverter, method encodeProbabilityForest.
/**
 * Encodes a ranger probability forest as a classification {@link MiningModel}
 * whose segments (one tree model per forest tree) are combined by averaging.
 *
 * <p>Each scored node receives one {@link ScoreDistribution} per entry of the
 * forest's {@code levels} vector; the node score is the level with the
 * greatest value (the lowest index wins ties).
 *
 * @param ranger the ranger object; its {@code forest} element is read
 * @param schema schema whose label must be a {@link CategoricalLabel}
 * @return the encoded mining model with probability output fields
 */
private MiningModel encodeProbabilityForest(RGenericVector ranger, Schema schema) {
    RGenericVector forest = (RGenericVector) ranger.getValue("forest");
    final RStringVector levels = (RStringVector) forest.getValue("levels");
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        /**
         * @throws IllegalArgumentException if the split value is not zero, or if the
         *         class counts are missing or do not match the number of levels
         */
        @Override
        public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            // The three checks run in the same order as a short-circuit "a || b || c" test.
            if (splitValue.doubleValue() != 0d) {
                throw new IllegalArgumentException("Expected split value 0, got " + splitValue);
            }
            if (terminalClassCount == null) {
                throw new IllegalArgumentException("Missing terminal class counts");
            }
            if (terminalClassCount.size() != levels.size()) {
                throw new IllegalArgumentException("Expected " + levels.size() + " terminal class counts, got " + terminalClassCount.size() + " terminal class counts");
            }
            Double maxProbability = null;
            for (int i = 0; i < terminalClassCount.size(); i++) {
                String value = levels.getValue(i);
                Double probability = ValueUtil.asDouble(terminalClassCount.getValue(i));
                // Track the running maximum; the first level always initializes it.
                if (maxProbability == null || maxProbability.compareTo(probability) < 0) {
                    node.setScore(value);
                    maxProbability = probability;
                }
                ScoreDistribution scoreDistribution = new ScoreDistribution(value, probability);
                node.addScoreDistributions(scoreDistribution);
            }
        }
    };
    List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);
    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels))
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    return miningModel;
}
Aggregations