Use of org.jpmml.converter.CategoricalLabel in project jpmml-sparkml by jpmml.
Class LogisticRegressionModelConverter, method encodeModel.
/**
 * Encodes the logistic regression transformer as a PMML {@link RegressionModel}.
 *
 * <p>
 * A label with exactly two categories is delegated to
 * {@code RegressionModelUtil.createBinaryLogisticClassification} using LOGIT normalization,
 * after which the output of the returned model is reset to {@code null}.
 * A label with more than two categories yields one regression table per category,
 * built from the matching row of the coefficient matrix and the matching element of the intercept vector,
 * and the tables are combined under SOFTMAX normalization.
 * </p>
 *
 * @param schema The schema that supplies the label and the features. The label is cast to {@link CategoricalLabel}.
 * @return The encoded regression model.
 * @throws IllegalArgumentException If the label has fewer than two categories.
 */
@Override
public RegressionModel encodeModel(Schema schema) {
    LogisticRegressionModel model = getTransformer();

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    if (categoricalLabel.size() == 2) {
        RegressionModel regressionModel = RegressionModelUtil.createBinaryLogisticClassification(schema.getFeatures(), VectorUtil.toList(model.coefficients()), model.intercept(), RegressionModel.NormalizationMethod.LOGIT, true, schema).setOutput(null);

        return regressionModel;
    } else if (categoricalLabel.size() > 2) {
        Matrix coefficientMatrix = model.coefficientMatrix();
        Vector interceptVector = model.interceptVector();

        List<? extends Feature> features = schema.getFeatures();

        List<RegressionTable> regressionTables = new ArrayList<>();

        // Row i of the coefficient matrix and element i of the intercept vector belong to target category i
        for (int i = 0; i < categoricalLabel.size(); i++) {
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(features, MatrixUtil.getRow(coefficientMatrix, i), interceptVector.apply(i)).setTargetCategory(categoricalLabel.getValue(i));

            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables).setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);

        return regressionModel;
    } else {
        throw new IllegalArgumentException("Expected 2 or more target categories, got " + categoricalLabel.size() + " target categories");
    }
}
Use of org.jpmml.converter.CategoricalLabel in project jpmml-sparkml by jpmml.
Class MultilayerPerceptronClassificationModelConverter, method registerOutputFields.
/**
 * Registers the output fields of the multilayer perceptron classifier.
 *
 * <p>
 * Starts from the fields registered by the superclass. When the transformer is not an instance of
 * {@code HasProbabilityCol}, a copy of that list is extended with DOUBLE-typed probability fields,
 * one set created by {@code ModelUtil.createProbabilityFields} from the label's category values.
 * </p>
 *
 * @param label The model label. It is cast to {@link CategoricalLabel} only when probability fields are added.
 * @param encoder The encoder, passed through to the superclass.
 * @return The superclass result, or an extended copy of it.
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    MultilayerPerceptronClassificationModel model = getTransformer();

    List<OutputField> inheritedFields = super.registerOutputFields(label, encoder);

    // Nothing to add when the transformer is a HasProbabilityCol
    if (model instanceof HasProbabilityCol) {
        return inheritedFields;
    }

    CategoricalLabel categoricalLabel = (CategoricalLabel) label;

    // Work on a copy, so that the list returned by the superclass is not modified
    List<OutputField> extendedFields = new ArrayList<>(inheritedFields);
    extendedFields.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()));

    return extendedFields;
}
Use of org.jpmml.converter.CategoricalLabel in project jpmml-sparkml by jpmml.
Class TreeModelUtil, method encodeNode.
/**
 * Recursively converts a Spark ML decision tree node into a PMML {@link Node}.
 *
 * <p>
 * For an internal node, a left and a right predicate are derived from the split, both children are
 * encoded recursively, and each child receives its predicate. The returned parent node carries no
 * predicate of its own; the caller is expected to set one.
 * For a leaf node, the score (and for classification also the record count and score distributions)
 * is taken from the node's prediction and impurity statistics.
 * </p>
 *
 * @param node The Spark ML tree node to convert.
 * @param predicateManager Factory used for creating all predicates.
 * @param parentFieldValues Values of categorical fields that are still admissible at this node, keyed by field name.
 * A field without an entry is treated as unrestricted.
 * @param miningFunction Determines how leaf nodes are scored. Only REGRESSION and CLASSIFICATION are handled.
 * @param schema Supplies the features (looked up by split feature index) and the label.
 * @return The encoded PMML node.
 * @throws IllegalArgumentException If the node type, the split type, or the combination of split and feature is not supported.
 * @throws UnsupportedOperationException If a leaf node is encountered with any other mining function.
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {

    if (node instanceof InternalNode) {
        InternalNode internalNode = (InternalNode) node;

        // Both children inherit the parent's restrictions unchanged, unless a categorical split narrows them below
        Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
        Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;

        Predicate leftPredicate;
        Predicate rightPredicate;

        Split split = internalNode.split();

        Feature feature = schema.getFeature(split.featureIndex());

        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;

            double threshold = continuousSplit.threshold();

            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;

                // A boolean feature can only be split halfway between its two values
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Expected threshold 0.5 for a boolean feature, got " + threshold);
                }

                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();

                String value = ValueUtil.formatValue(threshold);

                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;

            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();

            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;

                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;

                // The two sides must be exactly the TRUE and FALSE category arrays, in either order
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Unexpected categories for a binary feature split: left=" + Arrays.toString(leftCategories) + ", right=" + Arrays.toString(rightCategories));
                }

                String value = ValueUtil.formatValue(binaryFeature.getValue());

                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

                FieldName name = categoricalFeature.getName();

                List<String> values = categoricalFeature.getValues();

                // The split must account for every category value of the feature
                if (values.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException("Expected " + values.size() + " split categories, got " + (leftCategories.length + rightCategories.length) + " split categories");
                }

                final Set<String> parentValues = parentFieldValues.get(name);

                // Accepts every value when the field is unrestricted, otherwise only the values admitted by the ancestors
                com.google.common.base.Predicate<String> valueFilter = new com.google.common.base.Predicate<String>() {

                    @Override
                    public boolean apply(String value) {

                        if (parentValues != null) {
                            return parentValues.contains(value);
                        }

                        return true;
                    }
                };

                // NOTE(review): selectValues is defined elsewhere; presumably it maps category indices to values and applies the filter - confirm
                List<String> leftValues = selectValues(values, leftCategories, valueFilter);
                List<String> rightValues = selectValues(values, rightCategories, valueFilter);

                // Copy before narrowing, so that the parent's map is not modified
                leftFieldValues = new HashMap<>(parentFieldValues);
                leftFieldValues.put(name, new HashSet<>(leftValues));

                rightFieldValues = new HashMap<>(parentFieldValues);
                rightFieldValues.put(name, new HashSet<>(rightValues));

                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Expected a binary or categorical feature for a categorical split, got " + feature);
            }
        } else {
            throw new IllegalArgumentException("Split " + split + " is not supported");
        }

        Node result = new Node();

        Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftFieldValues, miningFunction, schema).setPredicate(leftPredicate);
        Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightFieldValues, miningFunction, schema).setPredicate(rightPredicate);

        result.addNodes(leftChild, rightChild);

        return result;
    } else if (node instanceof LeafNode) {
        Node result = new Node();

        switch (miningFunction) {
            case REGRESSION:
                {
                    String score = ValueUtil.formatValue(node.prediction());

                    result.setScore(score);
                }
                break;
            case CLASSIFICATION:
                {
                    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

                    // The prediction is used as an index into the label's category values
                    int index = ValueUtil.asInt(node.prediction());

                    result.setScore(categoricalLabel.getValue(index));

                    ImpurityCalculator impurityCalculator = node.impurityStats();

                    result.setRecordCount((double) impurityCalculator.count());

                    // Element i of the impurity statistics is recorded against category value i
                    double[] stats = impurityCalculator.stats();
                    for (int i = 0; i < stats.length; i++) {
                        ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), stats[i]);

                        result.addScoreDistributions(scoreDistribution);
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException("Mining function " + miningFunction + " is not supported");
        }

        return result;
    } else {
        throw new IllegalArgumentException("Node " + node + " is not supported");
    }
}
Use of org.jpmml.converter.CategoricalLabel in project jpmml-sparkml by jpmml.
Class ModelConverter, method encodeSchema.
/**
 * Builds the {@link Schema} (label plus features) of the wrapped transformer.
 *
 * <p>
 * The label is resolved only when the transformer is a {@code HasLabelCol}; otherwise it stays {@code null}.
 * For classification, a categorical label feature is used as is, whereas a continuous label feature is
 * re-registered as a categorical feature whose categories are the strings {@code "0"} to {@code numClasses - 1}
 * (two classes are assumed when the transformer is not a {@code ClassificationModel}).
 * For regression, the label field is converted to a continuous field with DOUBLE data type.
 * </p>
 *
 * @param encoder The encoder that holds the features and fields registered so far. It may be updated by this call.
 * @return A schema made of the resolved label and the features of the transformer's features column.
 * @throws IllegalArgumentException If the label feature or the mining function is not supported,
 * or if the number of target categories or the number of features disagrees with what the transformer reports.
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T model = getTransformer();

    Label label = null;

    if (model instanceof HasLabelCol) {
        String labelCol = ((HasLabelCol) model).getLabelCol();

        Feature labelFeature = encoder.getOnlyFeature(labelCol);

        MiningFunction miningFunction = getMiningFunction();
        switch (miningFunction) {
            case REGRESSION:
                {
                    Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                    continuousField.setDataType(DataType.DOUBLE);

                    label = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                }
                break;
            case CLASSIFICATION:
                {
                    if (labelFeature instanceof CategoricalFeature) {
                        CategoricalFeature categoricalLabelFeature = (CategoricalFeature) labelFeature;

                        label = new CategoricalLabel(encoder.getDataField(categoricalLabelFeature.getName()));
                    } else if (labelFeature instanceof ContinuousFeature) {
                        ContinuousFeature continuousLabelFeature = (ContinuousFeature) labelFeature;

                        // Fall back to a binary target when the transformer cannot report its class count
                        int categoryCount = (model instanceof ClassificationModel) ? ((ClassificationModel<?, ?>) model).numClasses() : 2;

                        // Categories are the class indices rendered as strings
                        List<String> categories = new ArrayList<>();
                        for (int category = 0; category < categoryCount; category++) {
                            categories.add(Integer.toString(category));
                        }

                        Field<?> categoricalField = encoder.toCategorical(continuousLabelFeature.getName(), categories);

                        // Replace the continuous label feature with its categorical counterpart
                        encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, categoricalField, categories));

                        label = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), categories);
                    } else {
                        throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                    }
                }
                break;
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    // The label must have as many categories as the classifier reports classes
    if (model instanceof ClassificationModel) {
        ClassificationModel<?, ?> classifier = (ClassificationModel<?, ?>) model;

        CategoricalLabel categoricalLabel = (CategoricalLabel) label;

        int expectedCategories = classifier.numClasses();
        int actualCategories = categoricalLabel.size();

        if (expectedCategories != actualCategories) {
            throw new IllegalArgumentException("Expected " + expectedCategories + " target categories, got " + actualCategories + " target categories");
        }
    }

    List<Feature> features = encoder.getFeatures(model.getFeaturesCol());

    if (model instanceof PredictionModel) {
        int numFeatures = ((PredictionModel<?, ?>) model).numFeatures();

        // A reported feature count of -1 disables the check
        boolean countIsKnown = (numFeatures != -1);

        if (countIsKnown && numFeatures != features.size()) {
            throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
        }
    }

    return new Schema(label, features);
}
Use of org.jpmml.converter.CategoricalLabel in project pyramid by cheng-li.
Class MiningModelUtil, method createClassification.
/**
 * Chains the given member models with a final classification {@link RegressionModel}.
 *
 * <p>
 * The final regression model holds one regression table per target category. Table {@code i} has a single
 * feature, obtained by applying {@code MODEL_PREDICTION} to member model {@code i}, with coefficient {@code 1}
 * and a {@code null} intercept. All member models must agree on their math context; a model that reports
 * no math context is treated as {@code MathContext.DOUBLE}.
 * </p>
 *
 * @param models The member models, one per target category, in category order.
 * @param normalizationMethod The normalization method of the final regression model.
 * May be {@code null}; otherwise it must be NONE, SIMPLEMAX or SOFTMAX.
 * @param hasProbabilityDistribution Whether the final regression model gets a probability output.
 * @param schema The schema, whose label is cast to {@link CategoricalLabel}.
 * @return The model chain made of the member models followed by the final regression model.
 * @throws IllegalArgumentException If the number of models differs from the number of target categories,
 * if the normalization method is not supported, or if the member models disagree on the math context.
 */
public static MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    // Exactly one member model is required per target category
    if (categoricalLabel.size() != models.size()) {
        throw new IllegalArgumentException("Expected " + categoricalLabel.size() + " models, got " + models.size() + " models");
    }

    if (normalizationMethod != null) {

        switch (normalizationMethod) {
            case NONE:
            case SIMPLEMAX:
            case SOFTMAX:
                break;
            default:
                throw new IllegalArgumentException("Normalization method " + normalizationMethod + " is not supported");
        }
    }

    MathContext mathContext = null;

    List<RegressionTable> regressionTables = new ArrayList<>();

    for (int i = 0; i < categoricalLabel.size(); i++) {
        Model model = models.get(i);

        MathContext modelMathContext = model.getMathContext();
        if (modelMathContext == null) {
            modelMathContext = MathContext.DOUBLE;
        }

        // The first model fixes the math context; every later model must match it
        if (mathContext == null) {
            mathContext = modelMathContext;
        } else {

            if (!Objects.equals(mathContext, modelMathContext)) {
                throw new IllegalArgumentException("Expected math context " + mathContext + ", got math context " + modelMathContext);
            }
        }

        Feature feature = MODEL_PREDICTION.apply(model);

        RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), Collections.singletonList(1d), null).setTargetCategory(categoricalLabel.getValue(i));

        regressionTables.add(regressionTable);
    }

    RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables).setNormalizationMethod(normalizationMethod).setMathContext(ModelUtil.simplifyMathContext(mathContext)).setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, categoricalLabel) : null);

    // The regression model goes last, so that it can consume the predictions of the member models
    List<Model> segmentationModels = new ArrayList<>(models);
    segmentationModels.add(regressionModel);

    return createModelChain(segmentationModels, schema);
}
Aggregations