use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.
the class StringIndexerModelConverter method encodeFeatures.
/**
 * Encodes the input column of the fitted StringIndexerModel as a categorical feature
 * whose categories are the model labels.
 *
 * When the encoder resolves the input to a DataField, an invalid value decorator is
 * registered for it: "keep" maps to AS_IS and "error" maps to RETURN_INVALID; any other
 * handleInvalid setting is rejected. A DerivedField is left undecorated.
 *
 * In "keep" mode the returned feature is backed by a derived field that passes the input
 * through when it is one of the labels and yields LABEL_UNKNOWN otherwise; LABEL_UNKNOWN
 * is appended to the category list.
 *
 * @param encoder the encoder used to look up the input feature and to create fields
 * @return a singleton list holding the resulting CategoricalFeature
 * @throws IllegalArgumentException if the resolved field is neither a DataField nor a
 *         DerivedField, or if a DataField is combined with an unsupported handleInvalid value
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StringIndexerModel transformer = getTransformer();

    Feature inputFeature = encoder.getOnlyFeature(transformer.getInputCol());

    // Mutable copy of the labels: "keep" mode appends the unknown label further below.
    List<String> labels = new ArrayList<>(Arrays.asList(transformer.labels()));

    String handleInvalid = transformer.getHandleInvalid();

    Field<?> field = encoder.toCategorical(inputFeature.getName(), labels);

    if (field instanceof DataField) {
        InvalidValueTreatmentMethod treatment;

        // Deliberately dereferences handleInvalid (rather than the literal) on each comparison.
        if (handleInvalid.equals("keep")) {
            treatment = InvalidValueTreatmentMethod.AS_IS;
        } else if (handleInvalid.equals("error")) {
            treatment = InvalidValueTreatmentMethod.RETURN_INVALID;
        } else {
            throw new IllegalArgumentException(handleInvalid);
        }

        InvalidValueDecorator decorator = new InvalidValueDecorator().setInvalidValueTreatment(treatment);

        encoder.addDecorator(((DataField) field).getName(), decorator);
    } else if (!(field instanceof DerivedField)) {
        // Only data fields and derived fields are expected here; derived fields need no decorator.
        throw new IllegalArgumentException();
    }

    if (handleInvalid.equals("keep")) {
        // isIn(input, label_1, ..., label_n)
        Apply isKnownLabel = PMMLUtil.createApply("isIn", inputFeature.ref());
        for (String label : labels) {
            isKnownLabel.addExpressions(PMMLUtil.createConstant(label, inputFeature.getDataType()));
        }

        // The unknown label joins the categories only after the membership test has been built.
        labels.add(StringIndexerModelConverter.LABEL_UNKNOWN);

        // if(isIn(...), input, LABEL_UNKNOWN)
        Apply labelOrUnknown = PMMLUtil.createApply("if", isKnownLabel, inputFeature.ref(), PMMLUtil.createConstant(StringIndexerModelConverter.LABEL_UNKNOWN, DataType.STRING));

        field = encoder.createDerivedField(FeatureUtil.createName("handleInvalid", inputFeature), OpType.CATEGORICAL, inputFeature.getDataType(), labelOrUnknown);
    }

    return Collections.<Feature>singletonList(new CategoricalFeature(encoder, field, labels));
}
use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.
the class TokenizerConverter method encodeFeatures.
/**
 * Encodes the Tokenizer input column as a document feature.
 *
 * The input is first wrapped in a derived string field that applies the "lowercase"
 * function; the resulting DocumentFeature is constructed with the pattern "\\s+".
 *
 * @param encoder the encoder used to look up the input feature and to create the derived field
 * @return a singleton list holding the DocumentFeature
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    Tokenizer transformer = getTransformer();

    Feature inputFeature = encoder.getOnlyFeature(transformer.getInputCol());

    Apply lowercaseCall = PMMLUtil.createApply("lowercase", inputFeature.ref());

    DerivedField lowercased = encoder.createDerivedField(
        FeatureUtil.createName("lowercase", inputFeature),
        OpType.CATEGORICAL,
        DataType.STRING,
        lowercaseCall
    );

    Feature documentFeature = new DocumentFeature(encoder, lowercased, "\\s+");

    return Collections.<Feature>singletonList(documentFeature);
}
use of org.jpmml.converter.Feature in project jpmml-r by jpmml.
the class NaiveBayesConverter method encodeModel.
/**
 * Builds a PMML NaiveBayesModel from the "apriori" and "tables" elements of the R object.
 *
 * For every schema feature the table stored under the feature name is read:
 * categorical features contribute one PairCounts per table column, with a count per
 * table row computed as apriori[row] * table[row, column]; continuous features contribute
 * one Gaussian TargetValueStat per table row. The BayesOutput carries the raw apriori counts.
 *
 * @param schema the schema; its label is cast to CategoricalLabel
 * @return the encoded model with a probability output attached
 * @throws IllegalArgumentException if a feature is neither categorical nor continuous
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector naiveBayes = getObject();

    RIntegerVector apriori = naiveBayes.getIntegerElement("apriori");
    RGenericVector tables = naiveBayes.getGenericElement("tables");

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();

    BayesInputs bayesInputs = new BayesInputs();

    for (Feature feature : features) {
        String name = feature.getName();

        RDoubleVector table = tables.getDoubleElement(name);
        RStringVector classNames = table.dimnames(0);
        RStringVector levelNames = table.dimnames(1);

        BayesInput bayesInput = new BayesInput(name, null, null);

        if (feature instanceof CategoricalFeature) {
            // One PairCounts per table column; each holds a count for every table row.
            for (int column = 0; column < levelNames.size(); column++) {
                TargetValueCounts levelCounts = new TargetValueCounts();

                List<Double> probabilities = FortranMatrixUtil.getColumn(table.getValues(), classNames.size(), levelNames.size(), column);

                for (int row = 0; row < classNames.size(); row++) {
                    // Scale the table entry by the apriori count of the same row.
                    double count = apriori.getValue(row) * probabilities.get(row);

                    levelCounts.addTargetValueCounts(new TargetValueCount(classNames.getValue(row), count));
                }

                bayesInput.addPairCounts(new PairCounts(levelNames.getValue(column), levelCounts));
            }
        } else if (feature instanceof ContinuousFeature) {
            TargetValueStats classStats = new TargetValueStats();

            for (int row = 0; row < classNames.size(); row++) {
                // The table is read as a two-column matrix: first entry is the mean, the second
                // entry is squared to obtain the variance (presumably a standard deviation).
                List<Double> stats = FortranMatrixUtil.getRow(table.getValues(), classNames.size(), 2, row);

                double mean = stats.get(0);
                double variance = Math.pow(stats.get(1), 2);

                classStats.addTargetValueStats(new TargetValueStat(classNames.getValue(row), new GaussianDistribution(mean, variance)));
            }

            bayesInput.setTargetValueStats(classStats);
        } else {
            throw new IllegalArgumentException();
        }

        bayesInputs.addBayesInputs(bayesInput);
    }

    BayesOutput bayesOutput = new BayesOutput().setField(categoricalLabel.getName());

    // The output counts are the apriori counts themselves.
    TargetValueCounts priorCounts = new TargetValueCounts();
    RStringVector aprioriNames = apriori.dimnames(0);

    for (int row = 0; row < aprioriNames.size(); row++) {
        int count = apriori.getValue(row);

        priorCounts.addTargetValueCounts(new TargetValueCount(aprioriNames.getValue(row), count));
    }

    bayesOutput.setTargetValueCounts(priorCounts);

    return new NaiveBayesModel(0d, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), bayesInputs, bayesOutput)
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
}
use of org.jpmml.converter.Feature in project jpmml-r by jpmml.
the class FormulaUtil method addFeatures.
/**
 * Resolves each name against the formula and registers the resulting feature with the encoder,
 * in list order.
 *
 * @param formula the formula used to resolve names to features
 * @param names the names to resolve
 * @param allowInteractions when true names are resolved with resolveComplexFeature,
 *        otherwise with resolveFeature
 * @param encoder the encoder that receives every resolved feature
 */
public static void addFeatures(Formula formula, List<String> names, boolean allowInteractions, RExpEncoder encoder) {
    for (String name : names) {
        Feature resolved = allowInteractions
            ? formula.resolveComplexFeature(name)
            : formula.resolveFeature(name);

        encoder.addFeature(resolved);
    }
}
use of org.jpmml.converter.Feature in project jpmml-r by jpmml.
the class RandomForestConverter method encodeNode.
/**
 * Recursively encodes the tree node at zero-based index i, together with its descendants.
 *
 * A node whose bestvar entry is 0 becomes a leaf scored from nodepred. Otherwise the node
 * becomes a branch; the predicates handed to its children depend on the split feature:
 * boolean features compare for equality against their first and second value (the split must
 * be 0.5), categorical features use value subsets chosen by selectValues, and every other
 * feature is treated as continuous with a "less or equal" / "greater than" pair.
 *
 * @param predicate the predicate attached to this node
 * @param i zero-based node index; the node id is i + 1
 * @param scoreEncoder converts a leaf prediction into a score
 * @param leftDaughter per-node one-based index of the left child, 0 meaning none
 * @param rightDaughter per-node one-based index of the right child, 0 meaning none
 * @param bestvar per-node one-based index of the split feature, 0 meaning leaf
 * @param xbestsplit per-node split value
 * @param nodepred per-node prediction
 * @param categoryManager value filters accumulated on the path to this node
 * @param schema the schema from which split features are looked up
 * @return the encoded node
 * @throws IllegalArgumentException if a boolean feature has a split other than 0.5
 */
private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
    Integer nodeId = Integer.valueOf(i + 1);

    int splitVar = ValueUtil.asInt(bestvar.get(i));

    // No split variable: this is a terminal node.
    if (splitVar == 0) {
        return new LeafNode(scoreEncoder.encode(nodepred.get(i)), predicate).setId(nodeId);
    }

    Feature feature = schema.getFeature(splitVar - 1);
    Double split = xbestsplit.get(i);

    Predicate leftPredicate;
    Predicate rightPredicate;
    CategoryManager leftCategoryManager;
    CategoryManager rightCategoryManager;

    // BooleanFeature is tested ahead of CategoricalFeature, as in the original ordering.
    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;

        if (split != 0.5d) {
            throw new IllegalArgumentException();
        }

        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));

        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        String name = categoricalFeature.getName();
        List<?> values = categoricalFeature.getValues();

        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

        List<Object> leftValues = selectValues(values, valueFilter, split, true);
        List<Object> rightValues = selectValues(values, valueFilter, split, false);

        // Each side continues with a manager forked on the values it was given.
        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);

        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split);

        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;
    }

    Node branch = new BranchNode(null, predicate).setId(nodeId);
    List<Node> children = branch.getNodes();

    int leftIndex = ValueUtil.asInt(leftDaughter.get(i));
    if (leftIndex != 0) {
        children.add(encodeNode(leftPredicate, leftIndex - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema));
    }

    int rightIndex = ValueUtil.asInt(rightDaughter.get(i));
    if (rightIndex != 0) {
        children.add(encodeNode(rightPredicate, rightIndex - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema));
    }

    return branch;
}
Aggregations