Usage of org.jpmml.converter.CategoricalFeature in the jpmml-sparkml project (by jpmml).
Class OneHotEncoderModelConverter, method encodeFeatures:
/**
 * Encodes every input column of a Spark {@code OneHotEncoderModel} as a binarized categorical feature.
 *
 * <p>Each input column must already be represented by a single {@link CategoricalFeature}. One
 * {@link BinaryFeature} is created per category value, and those are grouped into one
 * {@code BinarizedCategoricalFeature} per column. When the transformer's {@code dropLast} flag is set,
 * the last category is omitted, mirroring the Spark transformer.
 *
 * @param encoder the encoder that holds the input features
 * @return one binarized categorical feature per input column, in input-column order
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    OneHotEncoderModel transformer = getTransformer();

    String[] inputCols = transformer.getInputCols();
    boolean dropLast = transformer.getDropLast();

    List<Feature> result = new ArrayList<>(inputCols.length);

    for (String inputCol : inputCols) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) encoder.getOnlyFeature(inputCol);

        // Binary features share the source field's data type so they agree with the enclosing feature.
        DataType dataType = categoricalFeature.getDataType();

        List<String> values = categoricalFeature.getValues();
        // Guard the empty case: subList(0, -1) would throw IllegalArgumentException.
        if (dropLast && !values.isEmpty()) {
            values = values.subList(0, values.size() - 1);
        }

        List<BinaryFeature> binaryFeatures = new ArrayList<>(values.size());
        for (String value : values) {
            binaryFeatures.add(new BinaryFeature(encoder, categoricalFeature.getName(), dataType, value));
        }

        result.add(new BinarizedCategoricalFeature(encoder, categoricalFeature.getName(), dataType, binaryFeatures));
    }

    return result;
}
Usage of org.jpmml.converter.CategoricalFeature in the jpmml-sparkml project (by jpmml).
Class BinarizerConverter, method encodeFeatures:
/**
 * Converts a Spark {@code Binarizer} into a categorical derived field whose value is 0 or 1.
 *
 * <p>The input feature is compared with the transformer's threshold. Values less than or equal to
 * the threshold map to {@code 0}, and all other values map to {@code 1}.
 *
 * @param encoder the encoder that holds the input feature and receives the new derived field
 * @return a single-element list containing the binarized categorical feature
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    Binarizer binarizer = getTransformer();

    ContinuousFeature input = encoder.getOnlyFeature(binarizer.getInputCol()).toContinuousFeature();

    Apply atOrBelowThreshold = PMMLUtil.createApply("lessOrEqual", input.ref(), PMMLUtil.createConstant(binarizer.getThreshold()));

    // if (input <= threshold) then 0 else 1
    Apply binarized = new Apply("if")
        .addExpressions(atOrBelowThreshold, PMMLUtil.createConstant(0d), PMMLUtil.createConstant(1d));

    DerivedField derivedField = encoder.createDerivedField(formatName(binarizer), OpType.CATEGORICAL, DataType.DOUBLE, binarized);

    Feature feature = new CategoricalFeature(encoder, derivedField, Arrays.asList("0", "1"));
    return Collections.singletonList(feature);
}
Usage of org.jpmml.converter.CategoricalFeature in the jpmml-r project (by jpmml).
Class GBMConverter, method encodeNode:
/**
 * Recursively encodes row {@code i} of a gbm tree, and its subtrees, into {@code node}.
 *
 * <p>A split variable of {@code -1} marks a terminal row, which only receives a score. For inner
 * rows, a categorical feature is split with set predicates built from the referenced {@code c_splits}
 * entry. Any other feature is split at the threshold into {@code <} (left) and {@code >=} (right).
 * Children are attached in the order missing, left, right, and a child index of {@code -1} means
 * that child is absent. Child node ids are the child's row index plus one.
 *
 * @param node the PMML node to populate
 * @param i the 0-based row index of this node within {@code tree}
 * @param tree the gbm tree vectors (split variable, split code/value, left, right, missing, ..., prediction)
 * @param c_splits the categorical split table referenced by categorical split codes
 * @param schema the schema whose features are indexed by split variable
 */
private void encodeNode(Node node, int i, RGenericVector tree, RGenericVector c_splits, Schema schema) {
    RIntegerVector splitVar = (RIntegerVector) tree.getValue(0);
    RDoubleVector splitCodePred = (RDoubleVector) tree.getValue(1);
    RIntegerVector leftNode = (RIntegerVector) tree.getValue(2);
    RIntegerVector rightNode = (RIntegerVector) tree.getValue(3);
    RIntegerVector missingNode = (RIntegerVector) tree.getValue(4);
    RDoubleVector prediction = (RDoubleVector) tree.getValue(7);

    Integer variable = splitVar.getValue(i);

    // Terminal row: emit the prediction and stop recursing.
    if (variable == -1) {
        node.setScore(ValueUtil.formatValue(prediction.getValue(i)));
        return;
    }

    Feature feature = schema.getFeature(variable);

    Predicate missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);

    Double split = splitCodePred.getValue(i);

    Predicate leftPredicate;
    Predicate rightPredicate;

    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> levels = categoricalFeature.getValues();

        // For categorical splits the split code is an index into c_splits.
        RIntegerVector splitTable = (RIntegerVector) c_splits.getValue(ValueUtil.asInt(split));
        List<Integer> directions = splitTable.getValues();

        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(levels, directions, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(levels, directions, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String threshold = ValueUtil.formatValue(split);

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, threshold);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, threshold);
    }

    // Child order matters: the missing-value branch is attached first, then left, then right.
    Integer[] childRows = {missingNode.getValue(i), leftNode.getValue(i), rightNode.getValue(i)};
    Predicate[] childPredicates = {missingPredicate, leftPredicate, rightPredicate};

    for (int k = 0; k < childRows.length; k++) {
        Integer childRow = childRows[k];
        if (childRow == -1) {
            continue;
        }

        Node child = new Node().setId(String.valueOf(childRow + 1)).setPredicate(childPredicates[k]);
        encodeNode(child, childRow, tree, c_splits, schema);
        node.addNodes(child);
    }
}
Usage of org.jpmml.converter.CategoricalFeature in the jpmml-r project (by jpmml).
Class BinaryTreeConverter, method encodeNode:
/**
 * Recursively encodes one party {@code BinaryTree} node, and its subtrees, into {@code node}.
 *
 * <p>The node id is taken from {@code nodeID}. Terminal nodes are scored via {@code encodeScore}.
 * Inner nodes are split on the variable named in {@code psplit}. Categorical features get
 * set-membership predicates built from the split point values. All other features get a
 * {@code <=} (left) / {@code >} (right) threshold pair.
 *
 * @param node the PMML node to populate
 * @param tree the R list describing this tree node
 * @param schema the schema whose features are looked up through {@code featureIndexes}
 * @throws IllegalArgumentException if the node has surrogate splits ({@code ssplits}), or if the
 *     split variable has no entry in {@code featureIndexes}
 */
private void encodeNode(Node node, RGenericVector tree, Schema schema) {
    RIntegerVector nodeId = (RIntegerVector) tree.getValue("nodeID");
    RBooleanVector terminal = (RBooleanVector) tree.getValue("terminal");
    RGenericVector psplit = (RGenericVector) tree.getValue("psplit");
    RGenericVector ssplits = (RGenericVector) tree.getValue("ssplits");
    RDoubleVector prediction = (RDoubleVector) tree.getValue("prediction");
    RGenericVector left = (RGenericVector) tree.getValue("left");
    RGenericVector right = (RGenericVector) tree.getValue("right");

    node.setId(String.valueOf(nodeId.asScalar()));

    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        // NOTE(review): the returned node is not used; this assumes encodeScore populates `node`
        // in place - confirm.
        encodeScore(node, prediction, schema);
        return;
    }

    RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
    RStringVector variableName = (RStringVector) psplit.getValue("variableName");

    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Surrogate splits (ssplits) are not supported");
    }

    FieldName name = FieldName.create(variableName.asScalar());

    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("No feature index for split variable " + name.getValue());
    }

    Feature feature = schema.getFeature(index);

    Predicate leftPredicate;
    Predicate rightPredicate;

    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> values = categoricalFeature.getValues();

        // Categorical split points are integer vectors (one entry per level).
        @SuppressWarnings("unchecked")
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();

        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        // Widen through Number so that integer-typed split points do not fail a (Double) cast.
        Double threshold = ((Number) splitpoint.asScalar()).doubleValue();
        String value = ValueUtil.formatValue(threshold);

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }

    Node leftChild = new Node().setPredicate(leftPredicate);
    encodeNode(leftChild, left, schema);

    Node rightChild = new Node().setPredicate(rightPredicate);
    encodeNode(rightChild, right, schema);

    node.addNodes(leftChild, rightChild);
}
Usage of org.jpmml.converter.CategoricalFeature in the jpmml-r project (by jpmml).
Class Formula, method addField:
/**
 * Registers a categorical field together with one binary feature per category.
 *
 * <p>A {@code BOOLEAN} field whose category values are exactly {@code ["false", "true"]} is
 * registered as a {@link BooleanFeature}. Any other field is registered as a plain
 * {@link CategoricalFeature}. Each category additionally yields a {@link BinaryFeature}, registered
 * under the field name concatenated with the category name. Finally the field is recorded in
 * {@code fields}.
 *
 * @param field the field being added
 * @param categoryNames the name suffix for each category, parallel to {@code categoryValues}
 * @param categoryValues the category values, parallel to {@code categoryNames}
 * @throws IllegalArgumentException if the two lists differ in size
 */
public void addField(Field<?> field, List<String> categoryNames, List<String> categoryValues) {
    RExpEncoder encoder = getEncoder();

    int categoryCount = categoryValues.size();
    if (categoryNames.size() != categoryCount) {
        throw new IllegalArgumentException();
    }

    boolean falseTrueLevels = (DataType.BOOLEAN).equals(field.getDataType())
        && categoryCount == 2
        && ("false").equals(categoryValues.get(0))
        && ("true").equals(categoryValues.get(1));

    CategoricalFeature categoricalFeature = falseTrueLevels
        ? new BooleanFeature(encoder, field)
        : new CategoricalFeature(encoder, field, categoryValues);

    putFeature(field.getName(), categoricalFeature);

    int position = 0;
    for (String categoryName : categoryNames) {
        String categoryValue = categoryValues.get(position++);

        BinaryFeature binaryFeature = new BinaryFeature(encoder, field, categoryValue);
        putFeature(FieldName.create((field.getName()).getValue() + categoryName), binaryFeature);
    }

    this.fields.add(field);
}
Aggregations