Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.
The class RandomForestConverter, method encodeClassification.
/**
 * Encodes a classification random forest as a PMML mining model whose segments are
 * combined by majority vote.
 *
 * <p>The {@code bestvar}, {@code treemap}, {@code nodepred}, {@code xbestsplit},
 * {@code nrnodes} and {@code ntree} elements of {@code forest} are read. The per-tree data is
 * sliced column by column with {@link FortranMatrixUtil}; {@code treemap} holds two columns
 * (left and right daughters) per tree.
 *
 * @param forest the forest component of the R model object
 * @param schema the model schema; its label is cast to {@link CategoricalLabel}
 * @return a CLASSIFICATION mining model with a MAJORITY_VOTE segmentation and a
 *     probability output over the label's categories
 */
private MiningModel encodeClassification(RGenericVector forest, final Schema schema) {
    RNumberVector<?> bestvar = (RNumberVector<?>) forest.getValue("bestvar");
    RNumberVector<?> treemap = (RNumberVector<?>) forest.getValue("treemap");
    RIntegerVector nodepred = (RIntegerVector) forest.getValue("nodepred");
    RDoubleVector xbestsplit = (RDoubleVector) forest.getValue("xbestsplit");
    RIntegerVector nrnodes = (RIntegerVector) forest.getValue("nrnodes");
    RDoubleVector ntree = (RDoubleVector) forest.getValue("ntree");

    // Row count (node slots per tree) and column count (number of trees) of the per-tree matrices
    int nodeCount = nrnodes.asScalar();
    int treeCount = ValueUtil.asInt(ntree.asScalar());

    final CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    // Predictions are looked up with a 1-based index, hence the "- 1" shift into the label's values
    ScoreEncoder<Integer> scoreEncoder = new ScoreEncoder<Integer>() {

        @Override
        public String encode(Integer categoryIndex) {
            return categoricalLabel.getValue(categoryIndex - 1);
        }
    };

    Schema segmentSchema = schema.toAnonymousSchema();

    List<TreeModel> treeModels = new ArrayList<>();
    for (int tree = 0; tree < treeCount; tree++) {
        // This tree's slice of treemap: 2 * nodeCount values, split below into two columns
        List<? extends Number> daughters = FortranMatrixUtil.getColumn(treemap.getValues(), 2 * nodeCount, treeCount, tree);

        treeModels.add(encodeTreeModel(
            MiningFunction.CLASSIFICATION,
            scoreEncoder,
            FortranMatrixUtil.getColumn(daughters, nodeCount, 2, 0),
            FortranMatrixUtil.getColumn(daughters, nodeCount, 2, 1),
            FortranMatrixUtil.getColumn(nodepred.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(bestvar.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(xbestsplit.getValues(), nodeCount, treeCount, tree),
            segmentSchema
        ));
    }

    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel));
    miningModel.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels));
    miningModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

    return miningModel;
}
Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.
The class RandomForestConverter, method encodeRegression.
/**
 * Encodes a regression random forest as a PMML mining model whose segment predictions are
 * averaged.
 *
 * <p>The {@code leftDaughter}, {@code rightDaughter}, {@code nodepred}, {@code bestvar},
 * {@code xbestsplit}, {@code nrnodes} and {@code ntree} elements of {@code forest} are read.
 * Each per-tree matrix is sliced one column per tree with {@link FortranMatrixUtil}.
 *
 * @param forest the forest component of the R model object
 * @param schema the model schema
 * @return a REGRESSION mining model with an AVERAGE segmentation
 */
private MiningModel encodeRegression(RGenericVector forest, final Schema schema) {
    RNumberVector<?> leftDaughter = (RNumberVector<?>) forest.getValue("leftDaughter");
    RNumberVector<?> rightDaughter = (RNumberVector<?>) forest.getValue("rightDaughter");
    RDoubleVector nodepred = (RDoubleVector) forest.getValue("nodepred");
    RNumberVector<?> bestvar = (RNumberVector<?>) forest.getValue("bestvar");
    RDoubleVector xbestsplit = (RDoubleVector) forest.getValue("xbestsplit");
    RIntegerVector nrnodes = (RIntegerVector) forest.getValue("nrnodes");
    RDoubleVector ntree = (RDoubleVector) forest.getValue("ntree");

    // Node predictions are written into the PMML as formatted numeric strings
    ScoreEncoder<Double> scoreEncoder = new ScoreEncoder<Double>() {

        @Override
        public String encode(Double prediction) {
            return ValueUtil.formatValue(prediction);
        }
    };

    // Row count (node slots per tree) and column count (number of trees) of the per-tree matrices
    int nodeCount = nrnodes.asScalar();
    int treeCount = ValueUtil.asInt(ntree.asScalar());

    Schema segmentSchema = schema.toAnonymousSchema();

    List<TreeModel> treeModels = new ArrayList<>();
    for (int tree = 0; tree < treeCount; tree++) {
        treeModels.add(encodeTreeModel(
            MiningFunction.REGRESSION,
            scoreEncoder,
            FortranMatrixUtil.getColumn(leftDaughter.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(rightDaughter.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(nodepred.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(bestvar.getValues(), nodeCount, treeCount, tree),
            FortranMatrixUtil.getColumn(xbestsplit.getValues(), nodeCount, treeCount, tree),
            segmentSchema
        ));
    }

    MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));
    miningModel.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels));

    return miningModel;
}
Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.
The class RandomForestClassificationModelConverter, method encodeModel.
/**
 * Encodes the Spark random forest classifier as a PMML mining model.
 *
 * <p>The individual decision trees are encoded by
 * {@code TreeModelUtil.encodeDecisionTreeEnsemble}, and their outputs are combined with the
 * AVERAGE multiple-model method.
 *
 * @param schema the model schema; its label is used to build the mining schema
 * @return a CLASSIFICATION mining model segmented over the encoded trees
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RandomForestClassificationModel forest = getTransformer();

    List<TreeModel> trees = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    MiningModel result = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()));
    result.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees));

    return result;
}
Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.
The class RandomForestRegressionModelConverter, method encodeModel.
/**
 * Encodes the Spark random forest regressor as a PMML mining model.
 *
 * <p>The individual decision trees are encoded by
 * {@code TreeModelUtil.encodeDecisionTreeEnsemble}, and their predictions are combined with
 * the AVERAGE multiple-model method.
 *
 * @param schema the model schema; its label is used to build the mining schema
 * @return a REGRESSION mining model segmented over the encoded trees
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RandomForestRegressionModel forest = getTransformer();

    List<TreeModel> trees = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    MiningModel result = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));
    result.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees));

    return result;
}
Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.
The class ConverterUtil, method toPMML.
/**
 * Converts a fitted Spark ML pipeline into a PMML document.
 *
 * <p>Every transformer of the pipeline is converted in order. Feature converters register
 * their features with a shared {@link SparkMLEncoder}; model converters contribute a PMML
 * model. A single model becomes the root model as-is. Two or more models are combined into a
 * model chain whose mining schema lists the TARGET/PREDICTED fields of the member models, one
 * entry per field name.
 *
 * @param schema the schema of the pipeline's input data
 * @param pipelineModel the fitted pipeline to convert
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a transformer converts to neither a
 *     {@link FeatureConverter} nor a {@link ModelConverter}, or if the pipeline contains no models
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    // NOTE(review): presumably validates the runtime Spark version — confirm against checkVersion()
    checkVersion();

    SparkMLEncoder encoder = new SparkMLEncoder(schema);

    List<org.dmg.pmml.Model> models = new ArrayList<>();

    Iterable<Transformer> transformers = getTransformers(pipelineModel);
    for (Transformer transformer : transformers) {
        TransformerConverter<?> converter = ConverterUtil.createConverter(transformer);

        if (converter instanceof FeatureConverter) {
            FeatureConverter<?> featureConverter = (FeatureConverter<?>) converter;

            featureConverter.registerFeatures(encoder);
        } else if (converter instanceof ModelConverter) {
            ModelConverter<?> modelConverter = (ModelConverter<?>) converter;

            org.dmg.pmml.Model model = modelConverter.registerModel(encoder);

            models.add(model);
        } else {
            throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
        }
    }

    org.dmg.pmml.Model rootModel;

    if (models.size() == 1) {
        rootModel = Iterables.getOnlyElement(models);
    } else if (models.size() > 1) {
        MiningSchema miningSchema = new MiningSchema(collectTargetMiningFields(models));

        MiningModel miningModel = MiningModelUtil.createModelChain(models, new Schema(null, Collections.<Feature>emptyList()))
            .setMiningSchema(miningSchema);

        rootModel = miningModel;
    } else {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }

    PMML pmml = encoder.encodePMML(rootModel);

    return pmml;
}

/**
 * Collects the TARGET and PREDICTED mining fields of the given models, in model order.
 *
 * <p>Several models in one pipeline may share a label column (e.g. two estimators trained on
 * the same label). PMML requires the field names within a single MiningSchema to be unique,
 * so only the first mining field seen for each name is kept.
 *
 * <p>The returned MiningField instances are the same objects held by the member models'
 * mining schemas; they are not copied.
 *
 * @param models the member models of the chain
 * @return the de-duplicated target mining fields
 */
private static List<MiningField> collectTargetMiningFields(List<org.dmg.pmml.Model> models) {
    List<MiningField> targetMiningFields = new ArrayList<>();

    // Typed as Object so that the field-name type (FieldName or String, depending on the
    // JPMML-Model version) does not need to be named here
    java.util.Set<Object> seenNames = new java.util.HashSet<>();

    for (org.dmg.pmml.Model model : models) {
        MiningSchema miningSchema = model.getMiningSchema();

        List<MiningField> miningFields = miningSchema.getMiningFields();
        for (MiningField miningField : miningFields) {
            MiningField.UsageType usageType = miningField.getUsageType();

            switch (usageType) {
                case PREDICTED:
                case TARGET:
                    if (seenNames.add(miningField.getName())) {
                        targetMiningFields.add(miningField);
                    }
                    break;
                default:
                    break;
            }
        }
    }

    return targetMiningFields;
}
Aggregations