Usage of `org.dmg.pmml.tree.TreeModel` in project jpmml-r (by jpmml).
Class `IForestConverter`, method `encodeTreeModel`:
/**
 * Encodes one tree of the forest as a binary-split regression {@link TreeModel}.
 *
 * <p>Every per-node attribute ({@code nodeStatus}, {@code nSam}, {@code lDaughter},
 * {@code rDaughter}, {@code splitAtt}, {@code splitPoint}) is stored as a flat matrix with
 * {@code nrnodes} rows and {@code ntree} columns. This method slices out column {@code index}
 * of each matrix via {@code FortranMatrixUtil.getColumn} (presumably column-major layout,
 * given the helper's name; confirm) and hands the slices to {@code encodeNode}.
 *
 * @param index zero-based column (tree) to encode
 * @param trees the R list holding the per-node matrices and their dimensions
 * @param schema the schema whose label drives the mining schema
 * @return a regression tree model rooted at the node built by {@code encodeNode}
 */
private TreeModel encodeTreeModel(int index, RGenericVector trees, Schema schema) {
    RIntegerVector nodeCountVector = trees.getIntegerElement("nrnodes");
    RIntegerVector treeCountVector = trees.getIntegerElement("ntree");
    RIntegerVector statusMatrix = trees.getIntegerElement("nodeStatus");
    RIntegerVector leftChildMatrix = trees.getIntegerElement("lDaughter");
    RIntegerVector rightChildMatrix = trees.getIntegerElement("rDaughter");
    RIntegerVector splitAttributeMatrix = trees.getIntegerElement("splitAtt");
    RDoubleVector splitValueMatrix = trees.getDoubleElement("splitPoint");
    RIntegerVector sampleCountMatrix = trees.getIntegerElement("nSam");

    int nodeCount = nodeCountVector.asScalar();
    int treeCount = treeCountVector.asScalar();

    // The leading (0, True.INSTANCE, 0) arguments presumably denote the root node index,
    // its always-true predicate and the starting depth -- confirm against encodeNode.
    Node root = encodeNode(
        0,
        True.INSTANCE,
        0,
        FortranMatrixUtil.getColumn(statusMatrix.getValues(), nodeCount, treeCount, index),
        FortranMatrixUtil.getColumn(sampleCountMatrix.getValues(), nodeCount, treeCount, index),
        FortranMatrixUtil.getColumn(leftChildMatrix.getValues(), nodeCount, treeCount, index),
        FortranMatrixUtil.getColumn(rightChildMatrix.getValues(), nodeCount, treeCount, index),
        FortranMatrixUtil.getColumn(splitAttributeMatrix.getValues(), nodeCount, treeCount, index),
        FortranMatrixUtil.getColumn(splitValueMatrix.getValues(), nodeCount, treeCount, index),
        schema);

    return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Usage of `org.dmg.pmml.tree.TreeModel` in project jpmml-r (by jpmml).
Class `PartyConverter`, method `encodeModel`:
/**
 * Encodes the wrapped {@code party} object as a PMML {@link TreeModel}.
 *
 * <p>The tree structure comes from the {@code node} element. The decorated {@code predicted}
 * element supplies the {@code (response)} vector and an optional {@code (prob)} vector, both
 * of which are passed to {@code encodeNode}. A factor-typed response yields a classification
 * model with a probability output; any other response yields a regression model.
 *
 * @param schema the schema; its label must be a {@code CategoricalLabel} when the response
 *     is a factor
 * @return the encoded tree model
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector partyObject = getObject();

    RGenericVector rootNodeVector = partyObject.getGenericElement("node");
    RGenericVector predictedVector = DecorationUtil.getGenericElement(partyObject, "predicted");

    RVector<?> responseVector = predictedVector.getVectorElement("(response)");
    RDoubleVector probVector = predictedVector.getDoubleElement("(prob)", false);

    Node root = encodeNode(rootNodeVector, True.INSTANCE, responseVector, probVector, schema);

    // Non-factor responses are treated as regression targets.
    if (!(responseVector instanceof RFactorVector)) {
        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root);
    }

    CategoricalLabel label = (CategoricalLabel) schema.getLabel();

    return new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label), root)
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, label));
}
Usage of `org.dmg.pmml.tree.TreeModel` in project jpmml-r (by jpmml).
Class `BinaryTreeConverter`, method `encodeModel`:
/**
 * Encodes the wrapped {@code BinaryTree} S4 object as a PMML {@link TreeModel}.
 *
 * <p>The output depends on {@code this.miningFunction}. Regression starts from an empty
 * {@link Output}. Classification starts from a probability output over the categorical
 * label. In both cases a {@code nodeId} entity-id field of type string is appended.
 *
 * @param schema the schema; its label must be a {@code CategoricalLabel} for classification
 * @return the encoded tree model with its output attached
 * @throws IllegalArgumentException if the mining function is neither regression nor
 *     classification
 */
@Override
public TreeModel encodeModel(Schema schema) {
    S4Object binaryTree = getObject();

    RGenericVector tree = binaryTree.getGenericAttribute("tree");

    Output output;

    switch(this.miningFunction) {
        case REGRESSION:
            output = new Output();
            break;
        case CLASSIFICATION:
            CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
            output = ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel);
            break;
        default:
            // Report the offending value instead of throwing a bare, message-less exception.
            throw new IllegalArgumentException("Unsupported mining function: " + this.miningFunction);
    }

    // Expose the id of the terminal node that produced each prediction.
    output.addOutputFields(ModelUtil.createEntityIdField("nodeId", DataType.STRING));

    TreeModel treeModel = encodeTreeModel(tree, schema).setOutput(output);

    return treeModel;
}
Usage of `org.dmg.pmml.tree.TreeModel` in project jpmml-r (by jpmml).
Class `RPartEnsembleConverter`, method `encodeTreeModels`:
/**
 * Encodes every rpart tree of the ensemble as a {@link TreeModel}.
 *
 * <p>Tree {@code i} is paired with {@code this.schemas.get(i)} and encoded by the
 * {@code RPartConverter} registered for that tree in {@code this.converters}. Each tree is
 * encoded against the anonymous form of its schema ({@code Schema.toAnonymousSchema()}).
 *
 * @param trees the R list of rpart tree objects
 * @return the encoded tree models, in the same order as {@code trees}
 * @throws IllegalArgumentException if the number of trees differs from the number of
 *     schemas, or if a tree has no registered converter
 */
public List<TreeModel> encodeTreeModels(RGenericVector trees) {
    if (trees.size() != this.schemas.size()) {
        throw new IllegalArgumentException("Expected " + this.schemas.size() + " trees, got " + trees.size());
    }

    // The final size is known up front.
    List<TreeModel> result = new ArrayList<>(trees.size());

    for (int i = 0; i < trees.size(); i++) {
        RGenericVector tree = trees.getGenericValue(i);
        Schema schema = this.schemas.get(i);

        RPartConverter converter = this.converters.get(tree);
        if (converter == null) {
            throw new IllegalArgumentException("No converter registered for tree at index " + i);
        }

        Schema segmentSchema = schema.toAnonymousSchema();

        TreeModel treeModel = (TreeModel) converter.encode(segmentSchema);

        result.add(treeModel);
    }

    return result;
}
Usage of `org.dmg.pmml.tree.TreeModel` in project jpmml-r (by jpmml).
Class `RangerConverter`, method `encodeTreeModel`:
/**
 * Encodes a single ranger tree as a binary-split {@link TreeModel}.
 *
 * <p>{@code childNodeIDs} holds two numeric vectors: element 0 is the left-child ids and
 * element 1 is the right-child ids. These, the split variables/values and the terminal class
 * counts are passed to {@code encodeNode}, starting from node 0 with a fresh
 * {@code CategoryManager}.
 *
 * @param miningFunction the mining function of the resulting model
 * @param scoreEncoder the encoder passed through to {@code encodeNode}
 * @param childNodeIDs a two-element list of left/right child id vectors
 * @param splitVarIDs per-node split variable ids
 * @param splitValues per-node split values
 * @param terminalClassCounts per-node terminal class counts, passed through to {@code encodeNode}
 * @param schema the schema whose label drives the mining schema
 * @return the encoded tree model
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, RGenericVector childNodeIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema) {
    RNumberVector<?> leftIds = childNodeIDs.getNumericValue(0);
    RNumberVector<?> rightIds = childNodeIDs.getNumericValue(1);

    Node rootNode = encodeNode(
        True.INSTANCE,
        0,
        scoreEncoder,
        leftIds,
        rightIds,
        splitVarIDs,
        splitValues,
        terminalClassCounts,
        new CategoryManager(),
        schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Aggregations