use of org.jpmml.converter.Schema in project jpmml-r by jpmml.
the class GBMConverter method encodeModel.
/**
 * Encodes the GBM object returned by {@code getObject()} as a mining model.
 *
 * <p>Every element of the "trees" component is encoded as a regression tree model against
 * a segment schema that pairs an unnamed double label with the features of {@code schema}.
 * The tree models, the distribution name and the scalar "initF" value are then handed to
 * {@code encodeMiningModel} together with the original schema.</p>
 *
 * @param schema the schema whose features are shared by all tree models.
 * @return the mining model produced by {@code encodeMiningModel}.
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RGenericVector gbm = getObject();

    RDoubleVector initF = (RDoubleVector) gbm.getValue("initF");
    RGenericVector trees = (RGenericVector) gbm.getValue("trees");
    RGenericVector c_splits = (RGenericVector) gbm.getValue("c.splits");
    RGenericVector distribution = (RGenericVector) gbm.getValue("distribution");

    RStringVector distributionName = (RStringVector) distribution.getValue("name");

    // Individual trees predict a plain double value, independent of the final label
    ContinuousLabel treeLabel = new ContinuousLabel(null, DataType.DOUBLE);
    Schema treeSchema = new Schema(treeLabel, schema.getFeatures());

    List<TreeModel> encodedTrees = new ArrayList<>();

    int index = 0;
    while (index < trees.size()) {
        RGenericVector tree = (RGenericVector) trees.getValue(index);
        encodedTrees.add(encodeTreeModel(MiningFunction.REGRESSION, tree, c_splits, treeSchema));
        index++;
    }

    return encodeMiningModel(distributionName, encodedTrees, initF.asScalar(), schema);
}
use of org.jpmml.converter.Schema in project jpmml-r by jpmml.
the class RangerConverter method encodeForest.
/**
 * Encodes every tree of a ranger forest as a tree model.
 *
 * <p>The per-tree data is read from the "child.nodeIDs", "split.varIDs", "split.values" and
 * (if present) "terminal.class.counts" components of {@code forest}; the number of trees is
 * taken from its "num.trees" component.</p>
 *
 * @param forest the forest object holding the per-tree vectors.
 * @param miningFunction the mining function passed on to each tree encoding.
 * @param scoreEncoder the score encoder passed on to each tree encoding.
 * @param schema the schema; its anonymous form is used for all trees.
 * @return the encoded tree models, in tree index order.
 */
private List<TreeModel> encodeForest(RGenericVector forest, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema) {
    RNumberVector<?> numTrees = (RNumberVector<?>) forest.getValue("num.trees");
    RGenericVector childNodeIDs = (RGenericVector) forest.getValue("child.nodeIDs");
    RGenericVector splitVarIDs = (RGenericVector) forest.getValue("split.varIDs");
    RGenericVector splitValues = (RGenericVector) forest.getValue("split.values");
    // Looked up with the extra boolean argument; the null check below covers the absent case
    RGenericVector terminalClassCounts = (RGenericVector) forest.getValue("terminal.class.counts", true);

    Schema segmentSchema = schema.toAnonymousSchema();

    // The tree count is loop-invariant, so convert it once instead of on every iteration
    int treeCount = ValueUtil.asInt(numTrees.asScalar());

    List<TreeModel> treeModels = new ArrayList<>();

    for (int i = 0; i < treeCount; i++) {
        RGenericVector treeChildNodeIDs = (RGenericVector) childNodeIDs.getValue(i);
        RNumberVector<?> treeSplitVarIDs = (RNumberVector<?>) splitVarIDs.getValue(i);
        RNumberVector<?> treeSplitValues = (RNumberVector<?>) splitValues.getValue(i);
        RGenericVector treeTerminalClassCounts = (terminalClassCounts != null ? (RGenericVector) terminalClassCounts.getValue(i) : null);

        TreeModel treeModel = encodeTreeModel(miningFunction, scoreEncoder, treeChildNodeIDs, treeSplitVarIDs, treeSplitValues, treeTerminalClassCounts, segmentSchema);

        treeModels.add(treeModel);
    }

    return treeModels;
}
use of org.jpmml.converter.Schema in project jpmml-r by jpmml.
the class PreProcessEncoder method createSchema.
/**
 * Creates the schema via the superclass and passes it through {@code filter}.
 *
 * @return the filtered schema.
 */
@Override
public Schema createSchema() {
    return filter(super.createSchema());
}
use of org.jpmml.converter.Schema in project jpmml-sparkml by jpmml.
the class ModelConverter method encodeSchema.
/**
 * Builds the schema (label plus features) for the transformer returned by {@code getTransformer()}.
 *
 * <p>The label stays {@code null} unless the transformer is a {@code HasLabelCol}. For the
 * classification function, a categorical label feature is used as-is, whereas a continuous one
 * is converted through {@code encoder.toCategorical} using the categories "0", "1", ... up to
 * the number of classes (two, unless the transformer is a {@code ClassificationModel}); the
 * converted feature is also registered back with the encoder under the label column. For the
 * regression function, the label field is obtained from {@code encoder.toContinuous} and its
 * data type is set to double.</p>
 *
 * @param encoder the encoder that is queried for, and updated with, label and feature information.
 * @return a schema combining the label with the features of the features column.
 * @throws IllegalArgumentException if the label feature type or the mining function is not
 *         supported, or if the number of target categories or features differs from what the
 *         model reports.
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T transformer = getTransformer();

    Label label = null;

    if (transformer instanceof HasLabelCol) {
        String labelCol = ((HasLabelCol) transformer).getLabelCol();
        Feature labelFeature = encoder.getOnlyFeature(labelCol);

        MiningFunction function = getMiningFunction();
        switch (function) {
            case CLASSIFICATION:
                if (labelFeature instanceof CategoricalFeature) {
                    DataField dataField = encoder.getDataField(((CategoricalFeature) labelFeature).getName());

                    label = new CategoricalLabel(dataField);
                } else if (labelFeature instanceof ContinuousFeature) {
                    // Binary by default; a classification model reports its own class count
                    int classCount = 2;
                    if (transformer instanceof ClassificationModel) {
                        classCount = ((ClassificationModel<?, ?>) transformer).numClasses();
                    }

                    List<String> categories = new ArrayList<>();
                    for (int index = 0; index < classCount; index++) {
                        categories.add(String.valueOf(index));
                    }

                    Field<?> categoricalField = encoder.toCategorical(((ContinuousFeature) labelFeature).getName(), categories);

                    // Replace the label column's feature with its categorical counterpart
                    encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, categoricalField, categories));

                    label = new CategoricalLabel(categoricalField.getName(), categoricalField.getDataType(), categories);
                } else {
                    throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + labelFeature);
                }
                break;
            case REGRESSION: {
                Field<?> continuousField = encoder.toContinuous(labelFeature.getName());
                continuousField.setDataType(DataType.DOUBLE);

                label = new ContinuousLabel(continuousField.getName(), continuousField.getDataType());
                break;
            }
            default:
                throw new IllegalArgumentException("Mining function " + function + " is not supported");
        }
    }

    // A classification model must agree with the label on the number of target categories
    if (transformer instanceof ClassificationModel) {
        ClassificationModel<?, ?> classifier = (ClassificationModel<?, ?>) transformer;
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;

        int expectedClasses = classifier.numClasses();
        int actualClasses = categoricalLabel.size();
        if (expectedClasses != actualClasses) {
            throw new IllegalArgumentException("Expected " + expectedClasses + " target categories, got " + actualClasses + " target categories");
        }
    }

    List<Feature> features = encoder.getFeatures(transformer.getFeaturesCol());

    // A prediction model reporting -1 features is exempt from the feature count check
    if (transformer instanceof PredictionModel) {
        int expectedFeatures = ((PredictionModel<?, ?>) transformer).numFeatures();
        int actualFeatures = features.size();
        if (expectedFeatures != -1 && actualFeatures != expectedFeatures) {
            throw new IllegalArgumentException("Expected " + expectedFeatures + " features, got " + actualFeatures + " features");
        }
    }

    return new Schema(label, features);
}
use of org.jpmml.converter.Schema in project jpmml-r by jpmml.
the class GBMConverter method encodeBinaryClassification.
/**
 * Wraps the tree ensemble in a binary logistic classification model.
 *
 * <p>The trees are first combined, via {@code createMiningModel}, into a model over an unnamed
 * double label whose predicted output field is named "gbmValue". That model is then passed to
 * {@code MiningModelUtil.createBinaryLogisticClassification} with the negated coefficient and
 * the logit normalization method.</p>
 *
 * @param treeModels the tree models of the ensemble.
 * @param initF the initial value passed to {@code createMiningModel}.
 * @param coefficient the coefficient; its negation is passed to the classification wrapper.
 * @param schema the schema of the resulting classification model; its features are reused for the inner model.
 * @return the binary classification mining model.
 */
private MiningModel encodeBinaryClassification(List<TreeModel> treeModels, Double initF, double coefficient, Schema schema) {
    ContinuousLabel rawLabel = new ContinuousLabel(null, DataType.DOUBLE);
    Schema rawSchema = new Schema(rawLabel, schema.getFeatures());

    MiningModel rawModel = createMiningModel(treeModels, initF, rawSchema)
        .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue"), OpType.CONTINUOUS, DataType.DOUBLE));

    double negatedCoefficient = -coefficient;

    return MiningModelUtil.createBinaryLogisticClassification(rawModel, negatedCoefficient, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
}
Aggregations