use of org.dmg.pmml.DataField in project jpmml-r by jpmml.
the class RandomForestConverter method encodeFormula.
/**
 * Builds a formula from the model's {@code terms} element, registers the
 * response field as the encoder's label, and then adds one feature for every
 * name found in the {@code xlevels} element of the {@code forest} element.
 *
 * @param encoder the encoder that receives the label and the features
 * @throws IllegalArgumentException if the {@code response} attribute of the
 *         terms is zero
 */
private void encodeFormula(RExpEncoder encoder) {
    RGenericVector model = getObject();

    RGenericVector forest = (RGenericVector) model.getValue("forest");
    RNumberVector<?> y = (RNumberVector<?>) model.getValue("y", true);
    RExp terms = model.getValue("terms");

    final RNumberVector<?> ncat = (RNumberVector<?>) forest.getValue("ncat");
    final RGenericVector xlevels = (RGenericVector) forest.getValue("xlevels");

    RIntegerVector response = (RIntegerVector) terms.getAttributeValue("response");

    FormulaContext formulaContext = new FormulaContext() {

        /**
         * Returns the levels stored in {@code xlevels} for a variable whose
         * {@code ncat} entry is greater than one, and {@code null} otherwise.
         */
        @Override
        public List<String> getCategories(String variable) {
            boolean hasLevels = ncat != null
                && ncat.hasValue(variable)
                && ncat.getValue(variable).doubleValue() > 1d;

            if (!hasLevels) {
                return null;
            }

            RStringVector variableLevels = (RStringVector) xlevels.getValue(variable);

            return variableLevels.getValues();
        }

        /** No data is supplied to the formula by this context. */
        @Override
        public RGenericVector getData() {
            return null;
        }
    };

    Formula formula = FormulaUtil.createFormula(terms, formulaContext, encoder);

    // Dependent variable: the response attribute is a one-based index into the formula fields
    int responseIndex = response.asScalar();
    if (responseIndex == 0) {
        throw new IllegalArgumentException();
    }

    DataField label = (DataField) formula.getField(responseIndex - 1);
    if (y instanceof RIntegerVector) {
        // An integer-valued y is converted to a categorical field using its factor levels
        label = (DataField) encoder.toCategorical(label.getName(), RExpUtil.getFactorLevels(y));
    }
    encoder.setLabel(label);

    // Independent variables: one feature per named xlevels entry, in order
    RStringVector predictorNames = xlevels.names();
    for (int index = 0; index < predictorNames.size(); index++) {
        String predictorName = predictorNames.getValue(index);

        Feature predictor = formula.resolveFeature(FieldName.create(predictorName));
        encoder.addFeature(predictor);
    }
}
use of org.dmg.pmml.DataField in project jpmml-r by jpmml.
the class RExpEncoder method addFeature.
/**
 * Wraps the field in a feature that matches its operational type and passes
 * the result on to {@code addFeature(Feature)}.
 *
 * @param field the field to wrap; a categorical field is cast to {@code DataField}
 * @throws IllegalArgumentException if the operational type is neither
 *         {@code CATEGORICAL} nor {@code CONTINUOUS}
 */
public void addFeature(Field<?> field) {
    final Feature wrapped;

    switch (field.getOpType()) {
        case CATEGORICAL:
            wrapped = new CategoricalFeature(this, (DataField) field);
            break;
        case CONTINUOUS:
            wrapped = new ContinuousFeature(this, field);
            break;
        default:
            throw new IllegalArgumentException();
    }

    addFeature(wrapped);
}
use of org.dmg.pmml.DataField in project jpmml-sparkml by jpmml.
the class ImputerModelConverter method encodeFeatures.
/**
 * Attaches a missing value decorator to the field behind every input column
 * of the imputer and returns the corresponding input features.
 *
 * <p>For each input column the replacement value is read from the single row
 * of the transformer's surrogate data frame, and the treatment method is
 * obtained from the transformer's strategy via {@code parseStrategy}. When the
 * transformer's missing value is not NaN, it is additionally registered on
 * the decorator via {@code addValues}.
 *
 * @param encoder the encoder used to look up features and fields and to
 *        register decorators
 * @return the features of the input columns, in input column order
 * @throws IllegalArgumentException if the numbers of input and output columns
 *         differ, if the surrogate data frame does not hold exactly one row,
 *         or if the field of an input column is not a {@code DataField}
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    ImputerModel transformer = getTransformer();

    Double missingValue = transformer.getMissingValue();
    String strategy = transformer.getStrategy();
    Dataset<Row> surrogateDF = transformer.surrogateDF();

    String[] inputCols = transformer.getInputCols();
    String[] outputCols = transformer.getOutputCols();
    if (inputCols.length != outputCols.length) {
        throw new IllegalArgumentException("Expected the same number of input and output columns, got "
            + inputCols.length + " input column(s) and " + outputCols.length + " output column(s)");
    }

    MissingValueTreatmentMethod missingValueTreatmentMethod = parseStrategy(strategy);

    List<Row> surrogateRows = surrogateDF.collectAsList();
    if (surrogateRows.size() != 1) {
        throw new IllegalArgumentException("Expected exactly one surrogate row, got " + surrogateRows.size());
    }

    Row surrogateRow = surrogateRows.get(0);

    List<Feature> result = new ArrayList<>();

    for (int i = 0; i < inputCols.length; i++) {
        String inputCol = inputCols[i];

        Feature feature = encoder.getOnlyFeature(inputCol);

        Field<?> field = encoder.getField(feature.getName());
        if (!(field instanceof DataField)) {
            throw new IllegalArgumentException("Input column '" + inputCol + "' is not backed by a DataField");
        }

        // The surrogate row is addressed by input column name
        Object surrogate = surrogateRow.getAs(inputCol);

        MissingValueDecorator missingValueDecorator = new MissingValueDecorator()
            .setMissingValueReplacement(ValueUtil.formatValue(surrogate))
            .setMissingValueTreatment(missingValueTreatmentMethod);

        // A NaN missing value is not added to the decorator's values
        if (missingValue != null && !missingValue.isNaN()) {
            missingValueDecorator.addValues(ValueUtil.formatValue(missingValue));
        }

        encoder.addDecorator(feature.getName(), missingValueDecorator);

        result.add(feature);
    }

    return result;
}
use of org.dmg.pmml.DataField in project jpmml-sparkml by jpmml.
the class IndexToStringConverter method encodeFeatures.
/**
 * Creates a categorical string data field whose values are the transformer's
 * labels, and exposes it as a single categorical feature.
 *
 * @param encoder the encoder that creates the data field
 * @return a singleton list holding the categorical feature
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    IndexToString transformer = getTransformer();

    DataField labelField = encoder.createDataField(
        formatName(transformer),
        OpType.CATEGORICAL,
        DataType.STRING,
        Arrays.asList(transformer.getLabels()));

    Feature labelFeature = new CategoricalFeature(encoder, labelField);

    return Collections.singletonList(labelFeature);
}
use of org.dmg.pmml.DataField in project jpmml-sparkml by jpmml.
the class StringIndexerModelConverter method encodeFeatures.
/**
 * Converts the input column into a categorical feature over the indexer's
 * labels, honouring the transformer's {@code handleInvalid} setting.
 *
 * <p>When the categorical field is a {@code DataField}, an invalid value
 * decorator is attached to it: {@code "keep"} maps to {@code AS_IS} and
 * {@code "error"} maps to {@code RETURN_INVALID}. A {@code DerivedField} is
 * left undecorated.
 *
 * <p>In {@code "keep"} mode the returned feature is backed by a derived field
 * that yields the input value when it is one of the labels and
 * {@code LABEL_UNKNOWN} otherwise; {@code LABEL_UNKNOWN} is appended to the
 * category list.
 *
 * @param encoder the encoder used to look up, convert and create fields
 * @return a singleton list holding the categorical feature
 * @throws IllegalArgumentException if the categorical field is neither a
 *         {@code DataField} nor a {@code DerivedField}, or if it is a
 *         {@code DataField} and {@code handleInvalid} is neither
 *         {@code "keep"} nor {@code "error"}
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StringIndexerModel transformer = getTransformer();

    Feature feature = encoder.getOnlyFeature(transformer.getInputCol());

    List<String> categories = new ArrayList<>(Arrays.asList(transformer.labels()));

    String handleInvalid = transformer.getHandleInvalid();

    Field<?> field = encoder.toCategorical(feature.getName(), categories);

    if (field instanceof DataField) {
        DataField dataField = (DataField) field;

        InvalidValueTreatmentMethod treatment;
        switch (handleInvalid) {
            case "keep":
                treatment = InvalidValueTreatmentMethod.AS_IS;
                break;
            case "error":
                treatment = InvalidValueTreatmentMethod.RETURN_INVALID;
                break;
            default:
                throw new IllegalArgumentException(handleInvalid);
        }

        InvalidValueDecorator decorator = new InvalidValueDecorator()
            .setInvalidValueTreatment(treatment);
        encoder.addDecorator(dataField.getName(), decorator);
    } else if (!(field instanceof DerivedField)) {
        // Derived fields are accepted as-is; anything else is rejected
        throw new IllegalArgumentException();
    }

    switch (handleInvalid) {
        case "keep": {
            // isIn(feature, label_1, ..., label_n)
            Apply isKnownLabel = PMMLUtil.createApply("isIn", feature.ref());
            for (String category : categories) {
                isKnownLabel.addExpressions(PMMLUtil.createConstant(category, feature.getDataType()));
            }

            // The unknown label is appended only after the membership test has been built
            categories.add(StringIndexerModelConverter.LABEL_UNKNOWN);

            // if(isIn(...), feature, LABEL_UNKNOWN)
            Apply keepOrUnknown = PMMLUtil.createApply(
                "if",
                isKnownLabel,
                feature.ref(),
                PMMLUtil.createConstant(StringIndexerModelConverter.LABEL_UNKNOWN, DataType.STRING));

            field = encoder.createDerivedField(
                FeatureUtil.createName("handleInvalid", feature),
                OpType.CATEGORICAL,
                feature.getDataType(),
                keepOrUnknown);
            break;
        }
        default:
            break;
    }

    Feature indexed = new CategoricalFeature(encoder, field, categories);

    return Collections.singletonList(indexed);
}
Aggregations