Use of org.dmg.pmml.DataField in project jpmml-r by jpmml.
The class RangerConverter, method encodeSchema.
/**
 * Encodes the label and the input features of a ranger model into the given encoder.
 *
 * <p>The label is a data field named {@code _target}. It is continuous for the
 * {@code "Regression"} tree type. It is categorical, with the values of {@code forest$levels},
 * for the {@code "Classification"} and {@code "Probability estimation"} tree types. Each name
 * in {@code forest$independent.variable.names} becomes a feature: a categorical string field
 * when {@code variable.levels} has an entry for it, a continuous double field otherwise.
 *
 * @param encoder the encoder that receives the data fields, the label and the features
 * @throws IllegalArgumentException if the {@code forest} or {@code variable.levels} element
 *     cannot be retrieved, if the tree type is not supported, or if an independent variable
 *     is flagged as not ordered in {@code forest$is.ordered}
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector ranger = getObject();

    RGenericVector forest;
    try {
        forest = (RGenericVector) ranger.getValue("forest");
    } catch (IllegalArgumentException iae) {
        // Re-thrown with a hint on how to fix the input; the original cause is preserved
        throw new IllegalArgumentException("No forest information. Please initialize the \'forest\' element", iae);
    }

    RGenericVector variableLevels;
    try {
        variableLevels = (RGenericVector) ranger.getValue("variable.levels");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No variable levels information. Please initialize the \'variable.levels\' element", iae);
    }

    RStringVector treeType = (RStringVector) ranger.getValue("treetype");

    // Dependent variable
    {
        FieldName name = FieldName.create("_target");
        String type = treeType.asScalar();

        DataField dataField;
        switch (type) {
            case "Regression":
                dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
                break;
            case "Classification":
            case "Probability estimation": {
                RStringVector levels = (RStringVector) forest.getValue("levels");
                // NOTE(review): no explicit data type is passed here; presumably the encoder
                // derives one from the level values - confirm
                dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, levels.getValues());
                break;
            }
            default:
                // Name the offending value so unsupported models are easy to diagnose
                throw new IllegalArgumentException("Unsupported tree type: " + type);
        }

        encoder.setLabel(dataField);
    }

    RBooleanVector isOrdered = (RBooleanVector) forest.getValue("is.ordered");
    RStringVector independentVariableNames = (RStringVector) forest.getValue("independent.variable.names");

    // Independent variables
    for (int i = 0; i < independentVariableNames.size(); i++) {
        String independentVariableName = independentVariableNames.getValue(i);

        // NOTE(review): is.ordered is read at offset i + 1, presumably because its first element
        // describes the dependent variable - confirm against the ranger version in use
        if (!isOrdered.getValue(i + 1)) {
            throw new IllegalArgumentException("Expected an ordered independent variable, got unordered variable " + independentVariableName);
        }

        FieldName name = FieldName.create(independentVariableName);

        DataField dataField;
        if (variableLevels.hasValue(independentVariableName)) {
            RStringVector levels = (RStringVector) variableLevels.getValue(independentVariableName);
            dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING, levels.getValues());
        } else {
            dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
        }

        encoder.addFeature(dataField);
    }
}
Use of org.dmg.pmml.DataField in project jpmml-r by jpmml.
The class BinaryTreeConverter, method encodeResponse.
/**
 * Registers the single response variable of the tree model as the encoder's label.
 *
 * <p>The {@code is_nominal} flag of the response decides the outcome. When it is
 * {@code TRUE}, the mining function becomes classification and the label is a categorical
 * field. The field's data type is derived from the variable's {@code class} attribute and its
 * values come from {@code levels}. When it is {@code FALSE}, the mining function becomes
 * regression and the label is a continuous double field.
 *
 * @param responses the S4 object holding the {@code variables}, {@code is_nominal} and
 *     {@code levels} attributes
 * @param encoder the encoder that receives the label field
 * @throws IllegalArgumentException if the {@code is_nominal} flag is neither true nor false
 */
private void encodeResponse(S4Object responses, RExpEncoder encoder) {
    RGenericVector responseVariables = (RGenericVector) responses.getAttributeValue("variables");
    RBooleanVector nominalFlags = (RBooleanVector) responses.getAttributeValue("is_nominal");
    RGenericVector responseLevels = (RGenericVector) responses.getAttributeValue("levels");

    // Exactly one response variable is expected (asScalar)
    String responseName = responseVariables.names().asScalar();

    Boolean nominal = nominalFlags.getValue(responseName);

    DataField labelField;
    if (Boolean.TRUE.equals(nominal)) {
        this.miningFunction = MiningFunction.CLASSIFICATION;

        RExp responseVariable = responseVariables.getValue(responseName);
        RStringVector responseClass = (RStringVector) responseVariable.getAttributeValue("class");
        RStringVector categories = (RStringVector) responseLevels.getValue(responseName);

        DataType dataType = RExpUtil.getDataType(responseClass.asScalar());
        labelField = encoder.createDataField(FieldName.create(responseName), OpType.CATEGORICAL, dataType, categories.getValues());
    } else if (Boolean.FALSE.equals(nominal)) {
        this.miningFunction = MiningFunction.REGRESSION;

        labelField = encoder.createDataField(FieldName.create(responseName), OpType.CONTINUOUS, DataType.DOUBLE);
    } else {
        // A missing (null) flag is rejected
        throw new IllegalArgumentException();
    }

    encoder.setLabel(labelField);
}
Use of org.dmg.pmml.DataField in project jpmml-r by jpmml.
The class ElmNNConverter, method encodeSchema.
/**
 * Encodes the label and the input features of an elmNN model from its model frame.
 *
 * <p>A formula is built from the {@code terms} attribute of the model frame. When the
 * {@code response} attribute is non-zero, the formula field at {@code response - 1} becomes the
 * label. Every entry of the {@code columns} attribute is resolved to a feature, except a
 * leading {@code "(Intercept)"} column.
 *
 * @param encoder the encoder that receives the label and the features
 * @throws IllegalArgumentException if the {@code model} element cannot be retrieved
 */
@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector elmNN = getObject();

    RGenericVector modelFrame;
    try {
        modelFrame = (RGenericVector) elmNN.getValue("model");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No model frame information. Please initialize the \'model\' element", iae);
    }

    RExp terms = modelFrame.getAttributeValue("terms");
    RIntegerVector responseAttr = (RIntegerVector) terms.getAttributeValue("response");
    RStringVector columnNames = (RStringVector) terms.getAttributeValue("columns");

    FormulaContext formulaContext = new ModelFrameFormulaContext(modelFrame);
    Formula formula = FormulaUtil.createFormula(terms, formulaContext, encoder);

    // Dependent variable; the response index is 1-based and 0 means "no response"
    int responsePosition = responseAttr.asScalar();
    if (responsePosition != 0) {
        encoder.setLabel((DataField) formula.getField(responsePosition - 1));
    }

    // Independent variables; start past a leading intercept column if there is one
    int columnCount = columnNames.size();
    int start = (columnCount > 0 && "(Intercept)".equals(columnNames.getValue(0))) ? 1 : 0;
    for (int index = start; index < columnCount; index++) {
        encoder.addFeature(formula.resolveFeature(columnNames.getValue(index)));
    }
}
Use of org.dmg.pmml.DataField in project jpmml-r by jpmml.
The class RandomForestConverter, method encodeNonFormula.
/**
 * Encodes the label and the input features of a randomForest model fitted without a formula.
 *
 * <p>The label is a data field named {@code _target}. It is categorical, with the factor
 * levels of {@code y}, when {@code y} is an integer vector, and continuous double-typed
 * otherwise (including when {@code y} is absent). Feature names come from {@code xNames}, or
 * from the names of {@code forest$xlevels} when {@code xNames} is absent. A feature is
 * categorical, with its {@code xlevels} entry as values, when its {@code ncat} value exceeds 1,
 * and continuous double-typed otherwise.
 *
 * @param encoder the encoder that receives the label and the features
 */
private void encodeNonFormula(RExpEncoder encoder) {
    RGenericVector randomForest = getObject();

    RGenericVector forest = (RGenericVector) randomForest.getValue("forest");

    RNumberVector<?> y = (RNumberVector<?>) randomForest.getValue("y", true);
    RStringVector xNames = (RStringVector) randomForest.getValue("xNames", true);

    RNumberVector<?> ncat = (RNumberVector<?>) forest.getValue("ncat");
    RGenericVector xlevels = (RGenericVector) forest.getValue("xlevels");

    // Fall back to the names of xlevels when the optional xNames element is missing
    RStringVector featureNames = (xNames != null) ? xNames : xlevels.names();

    // Dependent variable
    FieldName targetName = FieldName.create("_target");
    DataField targetField = (y instanceof RIntegerVector)
        ? encoder.createDataField(targetName, OpType.CATEGORICAL, null, RExpUtil.getFactorLevels(y))
        : encoder.createDataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE);
    encoder.setLabel(targetField);

    // Independent variables
    for (int col = 0; col < ncat.size(); col++) {
        FieldName featureName = FieldName.create(featureNames.getValue(col));

        double ncatValue = ncat.getValue(col).doubleValue();

        DataField featureField;
        if (ncatValue > 1d) {
            RStringVector featureLevels = (RStringVector) xlevels.getValue(col);
            featureField = encoder.createDataField(featureName, OpType.CATEGORICAL, null, featureLevels.getValues());
        } else {
            featureField = encoder.createDataField(featureName, OpType.CONTINUOUS, DataType.DOUBLE);
        }

        encoder.addFeature(featureField);
    }
}
Use of org.dmg.pmml.DataField in project jpmml-sparkml by jpmml.
The class ModelConverter, method encodeSchema.
/**
 * Builds the schema (label plus features) for the wrapped Spark ML model.
 *
 * <p>If the model has a label column, that column's feature becomes the label, depending on
 * the mining function:
 * <ul>
 *   <li>classification: a categorical feature is used as-is. A continuous feature is converted
 *       to a categorical one with the categories {@code "0"} to {@code numClasses - 1}.
 *       {@code numClasses} defaults to 2 for models that are not classification models. The
 *       converted feature is registered back under the label column.</li>
 *   <li>regression: the feature is converted to a continuous double field.</li>
 * </ul>
 * The features are read from the model's features column.
 *
 * @param encoder the encoder that holds the label and feature columns
 * @return the schema made of the label (possibly {@code null}) and the features
 * @throws IllegalArgumentException if the mining function is not supported, if the label
 *     feature is neither categorical nor continuous, or if the number of target categories or
 *     features disagrees with the model
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T model = getTransformer();

    Label label = null;

    if (model instanceof HasLabelCol) {
        String labelColumn = ((HasLabelCol) model).getLabelCol();
        Feature labelFeature = encoder.getOnlyFeature(labelColumn);

        MiningFunction miningFunction = getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                if (labelFeature instanceof CategoricalFeature) {
                    DataField targetField = encoder.getDataField(((CategoricalFeature) labelFeature).getName());
                    label = new CategoricalLabel(targetField);
                    break;
                }

                if (!(labelFeature instanceof ContinuousFeature)) {
                    throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                }

                ContinuousFeature continuousLabel = (ContinuousFeature) labelFeature;

                // Without a classification model to ask, a binary target is assumed
                int classCount = (model instanceof ClassificationModel) ? ((ClassificationModel<?, ?>) model).numClasses() : 2;

                List<String> classLabels = new ArrayList<>();
                for (int cls = 0; cls < classCount; cls++) {
                    classLabels.add(Integer.toString(cls));
                }

                Field<?> categoricalField = encoder.toCategorical(continuousLabel.getName(), classLabels);

                // Re-register the label column so later lookups see the categorical form
                encoder.putOnlyFeature(labelColumn, new CategoricalFeature(encoder, categoricalField, classLabels));

                label = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), classLabels);
                break;
            }
            case REGRESSION: {
                Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                continuousField.setDataType(DataType.DOUBLE);

                label = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                break;
            }
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    // Classification models must agree with the label on the number of target categories
    if (model instanceof ClassificationModel) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        int expectedClasses = ((ClassificationModel<?, ?>) model).numClasses();

        if (expectedClasses != categoricalLabel.size()) {
            throw new IllegalArgumentException("Expected " + expectedClasses + " target categories, got " + categoricalLabel.size() + " target categories");
        }
    }

    List<Feature> features = encoder.getFeatures(model.getFeaturesCol());

    // Prediction models must agree on the feature count; -1 skips the check
    if (model instanceof PredictionModel) {
        int expectedFeatures = ((PredictionModel<?, ?>) model).numFeatures();

        if (expectedFeatures != -1 && features.size() != expectedFeatures) {
            throw new IllegalArgumentException("Expected " + expectedFeatures + " features, got " + features.size() + " features");
        }
    }

    return new Schema(label, features);
}
Aggregations