Use of org.jpmml.converter.BinaryFeature in project jpmml-sparkml by jpmml:
class TreeModelUtil, method encodeNode.
/**
 * Recursively converts a Spark ML decision tree node into a PMML {@code Node}.
 *
 * <p>An {@code InternalNode} becomes a PMML node with exactly two children. Its split is
 * translated into a pair of complementary predicates, which are attached to the left and right
 * child respectively. A {@code LeafNode} becomes a PMML node whose score is the Spark prediction.
 * For classification, the leaf also carries a record count and one score distribution per class.
 *
 * @param node the Spark ML node to convert; must be an {@code InternalNode} or a {@code LeafNode}
 * @param predicateManager factory used to create the split predicates
 * @param parentFieldValues for each categorical field already split on, the set of values still
 *     reachable at this node; categorical splits below this node only emit values from this set
 * @param miningFunction {@code REGRESSION} or {@code CLASSIFICATION}
 * @param schema supplies the split features (by feature index) and, for classification, the label
 * @return the encoded PMML node. Its own predicate is not set here; recursive calls set the
 *     predicates of the child nodes.
 * @throws IllegalArgumentException if the node, split or feature kind is unsupported, or the
 *     split is inconsistent with the feature it refers to
 * @throws UnsupportedOperationException if the mining function is neither regression nor
 *     classification
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {

	if (node instanceof InternalNode) {
		InternalNode internalNode = (InternalNode) node;

		// Both children inherit this node's value restrictions unless a categorical split
		// below replaces them with narrowed copies.
		Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
		Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;

		Predicate leftPredicate;
		Predicate rightPredicate;

		Split split = internalNode.split();

		Feature feature = schema.getFeature(split.featureIndex());

		if (split instanceof ContinuousSplit) {
			ContinuousSplit continuousSplit = (ContinuousSplit) split;

			double threshold = continuousSplit.threshold();

			if (feature instanceof BooleanFeature) {
				BooleanFeature booleanFeature = (BooleanFeature) feature;

				// Only a threshold of 0.5 maps cleanly onto the two boolean values:
				// getValue(0) goes left, getValue(1) goes right (presumably false/true encoded as 0/1).
				if (threshold != 0.5d) {
					throw new IllegalArgumentException("Expected a split threshold of 0.5 for boolean feature " + booleanFeature.getName() + ", got " + threshold);
				}

				leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
				rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
			} else {
				ContinuousFeature continuousFeature = feature.toContinuousFeature();

				String value = ValueUtil.formatValue(threshold);

				// Left child: feature <= threshold; right child: feature > threshold
				leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
				rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
			}
		} else if (split instanceof CategoricalSplit) {
			CategoricalSplit categoricalSplit = (CategoricalSplit) split;

			double[] leftCategories = categoricalSplit.leftCategories();
			double[] rightCategories = categoricalSplit.rightCategories();

			if (feature instanceof BinaryFeature) {
				BinaryFeature binaryFeature = (BinaryFeature) feature;

				SimplePredicate.Operator leftOperator;
				SimplePredicate.Operator rightOperator;

				// TRUE and FALSE are category arrays declared outside this method. The side that
				// holds the TRUE category tests the indicator value for equality, and the other
				// side tests for inequality.
				if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
					leftOperator = SimplePredicate.Operator.EQUAL;
					rightOperator = SimplePredicate.Operator.NOT_EQUAL;
				} else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
					leftOperator = SimplePredicate.Operator.NOT_EQUAL;
					rightOperator = SimplePredicate.Operator.EQUAL;
				} else {
					throw new IllegalArgumentException("Unsupported categorical split for binary feature " + binaryFeature.getName() + ": left categories " + Arrays.toString(leftCategories) + ", right categories " + Arrays.toString(rightCategories));
				}

				String value = ValueUtil.formatValue(binaryFeature.getValue());

				leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
				rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
			} else if (feature instanceof CategoricalFeature) {
				CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

				FieldName name = categoricalFeature.getName();

				List<String> values = categoricalFeature.getValues();

				// Every category of the feature must be assigned to exactly one side of the split
				if (values.size() != (leftCategories.length + rightCategories.length)) {
					throw new IllegalArgumentException("Expected " + values.size() + " categories for feature " + name + ", got " + (leftCategories.length + rightCategories.length));
				}

				final Set<String> parentValues = parentFieldValues.get(name);

				// Keep only values still reachable from the parent; if this field has not been
				// split on yet (no entry), every value is accepted.
				com.google.common.base.Predicate<String> valueFilter = new com.google.common.base.Predicate<String>() {

					@Override
					public boolean apply(String value) {

						if (parentValues != null) {
							return parentValues.contains(value);
						}

						return true;
					}
				};

				// selectValues is declared outside this method; presumably it maps category
				// indices to values and keeps those accepted by the filter.
				List<String> leftValues = selectValues(values, leftCategories, valueFilter);
				List<String> rightValues = selectValues(values, rightCategories, valueFilter);

				// Copy before narrowing so that sibling subtrees do not share restrictions
				leftFieldValues = new HashMap<>(parentFieldValues);
				leftFieldValues.put(name, new HashSet<>(leftValues));

				rightFieldValues = new HashMap<>(parentFieldValues);
				rightFieldValues.put(name, new HashSet<>(rightValues));

				leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
				rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
			} else {
				throw new IllegalArgumentException("Unsupported feature for a categorical split: " + feature);
			}
		} else {
			throw new IllegalArgumentException("Unsupported split type: " + split.getClass().getName());
		}

		Node result = new Node();

		Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftFieldValues, miningFunction, schema).setPredicate(leftPredicate);
		Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightFieldValues, miningFunction, schema).setPredicate(rightPredicate);

		result.addNodes(leftChild, rightChild);

		return result;
	} else if (node instanceof LeafNode) {
		Node result = new Node();

		switch (miningFunction) {
			case REGRESSION:
				{
					String score = ValueUtil.formatValue(node.prediction());

					result.setScore(score);
				}
				break;
			case CLASSIFICATION:
				{
					CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

					// The prediction is a class index into the label's values
					int index = ValueUtil.asInt(node.prediction());

					result.setScore(categoricalLabel.getValue(index));

					ImpurityCalculator impurityCalculator = node.impurityStats();

					result.setRecordCount((double) impurityCalculator.count());

					// One score distribution per entry of stats. This assumes stats is indexed
					// by class index, like the label values (presumably per-class counts).
					double[] stats = impurityCalculator.stats();
					for (int i = 0; i < stats.length; i++) {
						ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), stats[i]);

						result.addScoreDistributions(scoreDistribution);
					}
				}
				break;
			default:
				throw new UnsupportedOperationException("Unsupported mining function: " + miningFunction);
		}

		return result;
	} else {
		throw new IllegalArgumentException("Unsupported node: " + node);
	}
}
Use of org.jpmml.converter.BinaryFeature in project jpmml-sparkml by jpmml:
class OneHotEncoderModelConverter, method encodeFeatures.
/**
 * Expands each input column's categorical feature into a binarized categorical feature.
 *
 * <p>For every input column, one {@code BinaryFeature} (data type STRING) is created per encoded
 * category. When the model's {@code dropLast} flag is set, the final category is excluded. The
 * indicators are wrapped in a {@code BinarizedCategoricalFeature} that keeps the original data type.
 *
 * @param encoder the encoder holding the upstream features for the input columns
 * @return one feature per input column, in input-column order
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
	OneHotEncoderModel model = getTransformer();

	String[] columns = model.getInputCols();
	boolean excludeLast = model.getDropLast();

	List<Feature> features = new ArrayList<>();

	for (String column : columns) {
		// The upstream feature is expected to be categorical (a ClassCastException otherwise)
		CategoricalFeature source = (CategoricalFeature) encoder.getOnlyFeature(column);

		List<String> allValues = source.getValues();
		List<String> encodedValues = excludeLast ? allValues.subList(0, allValues.size() - 1) : allValues;

		List<BinaryFeature> indicators = new ArrayList<>();
		for (String encodedValue : encodedValues) {
			indicators.add(new BinaryFeature(encoder, source.getName(), DataType.STRING, encodedValue));
		}

		features.add(new BinarizedCategoricalFeature(encoder, source.getName(), source.getDataType(), indicators));
	}

	return features;
}
Use of org.jpmml.converter.BinaryFeature in project jpmml-r by jpmml:
class Formula, method addField.
/**
 * Registers a factor field together with one binary indicator feature per category.
 *
 * <p>The field itself is registered under its own name, either as a {@code BooleanFeature} (when it
 * is a BOOLEAN field whose values are exactly "false", "true" in that order) or as a
 * {@code CategoricalFeature}. Each category is additionally registered as a {@code BinaryFeature}
 * under the field name concatenated with the category name.
 *
 * @param field the field to register
 * @param categoryNames the name suffix of each category, parallel to {@code categoryValues}
 * @param categoryValues the value of each category
 * @throws IllegalArgumentException if the two category lists differ in size
 */
public void addField(Field<?> field, List<String> categoryNames, List<String> categoryValues) {
	RExpEncoder encoder = getEncoder();

	int categoryCount = categoryNames.size();
	if (categoryCount != categoryValues.size()) {
		throw new IllegalArgumentException();
	}

	boolean booleanLevels = (DataType.BOOLEAN).equals(field.getDataType())
		&& (categoryCount == 2)
		&& ("false").equals(categoryValues.get(0))
		&& ("true").equals(categoryValues.get(1));

	CategoricalFeature fieldFeature = booleanLevels
		? new BooleanFeature(encoder, field)
		: new CategoricalFeature(encoder, field, categoryValues);

	putFeature(field.getName(), fieldFeature);

	String namePrefix = (field.getName()).getValue();

	for (int index = 0; index < categoryCount; index++) {
		BinaryFeature indicator = new BinaryFeature(encoder, field, categoryValues.get(index));

		putFeature(FieldName.create(namePrefix + categoryNames.get(index)), indicator);
	}

	this.fields.add(field);
}
Use of org.jpmml.converter.BinaryFeature in project jpmml-r by jpmml:
class ScorecardConverter, method encodeModel.
/**
 * Converts a fitted R GLM with scorecard configuration ("sc.conf") into a PMML regression
 * {@code Scorecard}.
 *
 * <p>Every schema feature must be a {@code BinaryFeature}. Features are grouped into one
 * {@code Characteristic} per field name. Each feature contributes an attribute that tests
 * {@code field == value} and scores {@code -coefficient * pdo / ln(2)} points. An always-true
 * attribute worth 0 points is appended last to each characteristic.
 *
 * <p>The initial score is {@code base_points - ln(odds) * pdo / ln(2) - intercept * pdo / ln(2)},
 * where the intercept term is omitted when the model has no intercept.
 *
 * @param schema the model schema; supplies the label and the (binary) features
 * @return the scorecard, with reason codes disabled
 * @throws IllegalArgumentException if "sc.conf" is missing, if the coefficient count does not
 *     match the features (plus intercept), or if a feature is not binary
 */
@Override
public Scorecard encodeModel(Schema schema) {
	RGenericVector glm = getObject();

	RDoubleVector coefficients = (RDoubleVector) glm.getValue("coefficients");
	// NOTE(review): 'family' is never read below. The lookup is kept because getValue may throw
	// when the element is missing; confirm whether that implicit check is intended.
	RGenericVector family = (RGenericVector) glm.getValue("family");

	RGenericVector scConf;
	try {
		scConf = (RGenericVector) glm.getValue("sc.conf");
	} catch (IllegalArgumentException iae) {
		throw new IllegalArgumentException("No scorecard configuration information. Please initialize the \'sc.conf\' element", iae);
	}

	Double intercept = coefficients.getValue(LMConverter.INTERCEPT, true);

	List<? extends Feature> features = schema.getFeatures();

	int expectedCoefficientCount = features.size() + (intercept != null ? 1 : 0);
	if (coefficients.size() != expectedCoefficientCount) {
		throw new IllegalArgumentException();
	}

	RNumberVector<?> odds = (RNumberVector<?>) scConf.getValue("odds");
	RNumberVector<?> basePoints = (RNumberVector<?>) scConf.getValue("base_points");
	RNumberVector<?> pdo = (RNumberVector<?>) scConf.getValue("pdo");

	// Points per unit of natural-log odds ('pdo' is presumably "points to double the odds")
	double pdoFactor = (pdo.asScalar()).doubleValue() / Math.log(2);

	// Insertion-ordered so that characteristics appear in first-seen field order
	Map<FieldName, Characteristic> characteristicsByField = new LinkedHashMap<>();

	for (Feature feature : features) {
		FieldName fieldName = feature.getName();

		if (!(feature instanceof BinaryFeature)) {
			throw new IllegalArgumentException();
		}

		BinaryFeature indicator = (BinaryFeature) feature;

		Double coefficient = getFeatureCoefficient(feature, coefficients);

		Characteristic characteristic = characteristicsByField.get(fieldName);
		if (characteristic == null) {
			characteristic = new Characteristic().setName(FeatureUtil.createName("score", feature));

			characteristicsByField.put(fieldName, characteristic);
		}

		SimplePredicate isCategory = new SimplePredicate().setField(indicator.getName()).setOperator(SimplePredicate.Operator.EQUAL).setValue(indicator.getValue());

		characteristic.addAttributes(new Attribute().setPartialScore(formatScore(-1d * coefficient * pdoFactor)).setPredicate(isCategory));
	}

	Characteristics characteristics = new Characteristics();

	for (Characteristic characteristic : characteristicsByField.values()) {
		// Always-true attribute worth 0 points, appended after the field's category attributes
		characteristic.addAttributes(new Attribute().setPartialScore(0d).setPredicate(new True()));

		characteristics.addCharacteristics(characteristic);
	}

	double basePointsValue = (basePoints.asScalar()).doubleValue();
	double oddsValue = (odds.asScalar()).doubleValue();
	double interceptPoints = (intercept != null) ? intercept * pdoFactor : 0d;
	double initialScore = basePointsValue - Math.log(oddsValue) * pdoFactor - interceptPoints;

	return new Scorecard(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), characteristics)
		.setInitialScore(formatScore(initialScore))
		.setUseReasonCodes(false);
}
Use of org.jpmml.converter.BinaryFeature in project jpmml-sparkml by jpmml:
class OneHotEncoderConverter, method encodeFeatures.
/**
 * One-hot encodes the single categorical input feature into binary indicator features.
 *
 * <p>One {@code BinaryFeature} (data type STRING) is created per encoded category. The last
 * category is excluded when {@code dropLast} is true. If the {@code dropLast} param has no value
 * returned by {@code get}, dropping is the default.
 *
 * @param encoder the encoder holding the upstream feature for the input column
 * @return the binary indicator features, in category order
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
	OneHotEncoder transformer = getTransformer();

	// Unboxing the option value behaves exactly as a direct (Boolean) cast would
	Option<Object> dropLastOption = transformer.get(transformer.dropLast());
	boolean dropLast = !dropLastOption.isDefined() || (Boolean) dropLastOption.get();

	CategoricalFeature source = (CategoricalFeature) encoder.getOnlyFeature(transformer.getInputCol());

	List<String> allValues = source.getValues();
	List<String> encodedValues = dropLast ? allValues.subList(0, allValues.size() - 1) : allValues;

	List<Feature> indicators = new ArrayList<>();
	for (String encodedValue : encodedValues) {
		indicators.add(new BinaryFeature(encoder, source.getName(), DataType.STRING, encodedValue));
	}

	return indicators;
}
Aggregations