Search in sources:

Example 1 with BranchNode

Use of org.dmg.pmml.tree.BranchNode in project jpmml-r by jpmml.

Class IForestConverter, method encodeNode.

/**
 * Recursively encodes the isolation-tree node at {@code index} (and its whole subtree) as a PMML node.
 *
 * <p>The tree is supplied as parallel per-node lists. Child references in {@code leftDaughter} and
 * {@code rightDaughter}, and attribute numbers in {@code splitAtt}, are 1-based; they are converted to
 * 0-based indices here. The PMML node id is {@code index + 1}.
 *
 * @param index 0-based index of the node to encode
 * @param predicate predicate that routes records into this node
 * @param depth depth of this node; added to the average path length to form the leaf score
 * @param nodeStatus per-node status: {@code -3} for an interior node, {@code -1} for a terminal node
 * @param nodeSize per-node size, passed to {@code avgPathLength} for terminal nodes
 * @param leftDaughter per-node 1-based index of the left child
 * @param rightDaughter per-node 1-based index of the right child
 * @param splitAtt per-node 1-based feature number used for the split
 * @param splitValue per-node split threshold
 * @param schema schema whose features are referenced by {@code splitAtt}
 * @return a {@code BranchNode} with two children for interior nodes, or a {@code LeafNode} for terminal nodes
 * @throws IllegalArgumentException if the node status is neither {@code -3} nor {@code -1}
 */
private Node encodeNode(int index, Predicate predicate, int depth, List<Integer> nodeStatus, List<Integer> nodeSize, List<Integer> leftDaughter, List<Integer> rightDaughter, List<Integer> splitAtt, List<Double> splitValue, Schema schema) {
    Integer id = Integer.valueOf(index + 1);
    int status = nodeStatus.get(index);
    int size = nodeSize.get(index);

    // Interior node: values below the threshold go left, values at or above it go right
    if (status == -3) {
        int att = splitAtt.get(index);
        ContinuousFeature feature = (ContinuousFeature) schema.getFeature(att - 1);
        Double value = splitValue.get(index);
        Predicate leftPredicate = createSimplePredicate(feature, SimplePredicate.Operator.LESS_THAN, value);
        Predicate rightPredicate = createSimplePredicate(feature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);
        Node leftChild = encodeNode(leftDaughter.get(index) - 1, leftPredicate, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);
        Node rightChild = encodeNode(rightDaughter.get(index) - 1, rightPredicate, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);
        Node result = new BranchNode(null, predicate).setId(id).addNodes(leftChild, rightChild);
        return result;
    }

    // Terminal node: score = depth reached + avgPathLength(size)
    // (presumably the standard isolation-forest path-length adjustment for the records left in the node)
    if (status == -1) {
        Node result = new LeafNode(depth + avgPathLength(size), predicate).setId(id);
        return result;
    }

    throw new IllegalArgumentException("Node " + id + " has unsupported status " + status + " (expected -3 for an interior node or -1 for a terminal node)");
}
Also used: BranchNode (org.dmg.pmml.tree.BranchNode), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Node (org.dmg.pmml.tree.Node), LeafNode (org.dmg.pmml.tree.LeafNode), Predicate (org.dmg.pmml.Predicate), SimplePredicate (org.dmg.pmml.SimplePredicate)

Example 2 with BranchNode

Use of org.dmg.pmml.tree.BranchNode in project jpmml-r by jpmml.

Class PartyConverter, method encodeNode.

/**
 * Recursively encodes a partykit tree node (and its subtree) as a PMML node.
 *
 * <p>For a factor response every node becomes a {@code ClassifierNode} that carries a class score and a
 * score distribution. Otherwise a childless node becomes a {@code LeafNode} and any other node a
 * {@code BranchNode}. Scores are looked up by node id, so {@code response} (and the rows of {@code prob})
 * are indexed by {@code id - 1}.
 *
 * @param partyNode R list describing the node; elements read: "id", "split", "kids", "surrogates", "info"
 * @param predicate predicate that routes records into this node
 * @param response per-node response values, indexed by {@code id - 1}
 * @param prob per-node class probabilities, read row-wise via {@code FortranMatrixUtil.getRow} with
 *        {@code response.size()} rows and one column per class (presumably column-major storage — confirm)
 * @param schema schema supplying the label and the split features
 * @return the encoded node, including all of its descendants
 * @throws IllegalArgumentException if the node uses surrogate splits, or if its split definition is malformed or unsupported
 */
private Node encodeNode(RGenericVector partyNode, Predicate predicate, RVector<?> response, RDoubleVector prob, Schema schema) {
    RIntegerVector id = partyNode.getIntegerElement("id");
    RGenericVector split = partyNode.getGenericElement("split");
    RGenericVector kids = partyNode.getGenericElement("kids");
    RGenericVector surrogates = partyNode.getGenericElement("surrogates");
    // NOTE(review): "info" is looked up but never used below; kept so the sequence of element lookups is unchanged
    RGenericVector info = partyNode.getGenericElement("info");
    if (surrogates != null) {
        throw new IllegalArgumentException("Surrogate splits are not supported");
    }
    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();
    Node result;
    if (response instanceof RFactorVector) {
        // Classification: every node (leaf or not) carries a class score plus a probability distribution
        result = new ClassifierNode(null, predicate);
    } else {
        if (kids == null) {
            result = new LeafNode(null, predicate);
        } else {
            result = new BranchNode(null, predicate);
        }
    }
    result.setId(Integer.valueOf(id.asScalar()));
    if (response instanceof RFactorVector) {
        RFactorVector factor = (RFactorVector) response;
        int index = id.asScalar() - 1;
        result.setScore(factor.getFactorValue(index));
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        List<Double> probabilities = FortranMatrixUtil.getRow(prob.getValues(), response.size(), categoricalLabel.size(), index);
        List<ScoreDistribution> scoreDistributions = result.getScoreDistributions();
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Object value = categoricalLabel.getValue(i);
            Double probability = probabilities.get(i);
            ScoreDistribution scoreDistribution = new ScoreDistribution(value, probability);
            scoreDistributions.add(scoreDistribution);
        }
    } else {
        result.setScore(response.getValue(id.asScalar() - 1));
    }
    if (kids == null) {
        return result;
    }
    RIntegerVector varid = split.getIntegerElement("varid");
    RDoubleVector breaks = split.getDoubleElement("breaks");
    RIntegerVector index = split.getIntegerElement("index");
    RBooleanVector right = split.getBooleanElement("right");
    Feature feature = features.get(varid.asScalar() - 1);
    if (breaks != null && index == null) {
        // Continuous split: a single break point divides the node into exactly two children
        ContinuousFeature continuousFeature = (ContinuousFeature) feature;
        if (kids.size() != 2) {
            throw new IllegalArgumentException("Expected 2 child nodes for a continuous split, got " + kids.size());
        }
        if (breaks.size() != 1) {
            throw new IllegalArgumentException("Expected exactly 1 break point for a continuous split, got " + breaks.size());
        }
        Predicate leftPredicate;
        Predicate rightPredicate;
        Double value = breaks.asScalar();
        // "right" says whether the break point belongs to the left interval (closed on the right)
        if (right.asScalar()) {
            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        } else {
            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, value);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);
        }
        Node leftChild = encodeNode(kids.getGenericValue(0), leftPredicate, response, prob, schema);
        Node rightChild = encodeNode(kids.getGenericValue(1), rightPredicate, response, prob, schema);
        result.addNodes(leftChild, rightChild);
    } else if (breaks == null && index != null) {
        // Categorical split: "index" assigns each category to a 1-based child number
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        if (kids.size() < 2) {
            throw new IllegalArgumentException("Expected at least 2 child nodes for a categorical split, got " + kids.size());
        }
        List<?> values = categoricalFeature.getValues();
        // Loop-invariant: only right-closed categorical splits are supported
        if (!right.asScalar()) {
            throw new IllegalArgumentException("Categorical splits with right = FALSE are not supported");
        }
        for (int i = 0; i < kids.size(); i++) {
            Predicate childPredicate = createPredicate(categoricalFeature, selectValues(values, index, i + 1));
            Node child = encodeNode(kids.getGenericValue(i), childPredicate, response, prob, schema);
            result.addNodes(child);
        }
    } else {
        throw new IllegalArgumentException("Expected exactly one of 'breaks' or 'index' in the split definition");
    }
    return result;
}
Also used: BranchNode (org.dmg.pmml.tree.BranchNode), LeafNode (org.dmg.pmml.tree.LeafNode), Node (org.dmg.pmml.tree.Node), ClassifierNode (org.dmg.pmml.tree.ClassifierNode), CategoricalLabel (org.jpmml.converter.CategoricalLabel), Label (org.jpmml.converter.Label), Feature (org.jpmml.converter.Feature), ContinuousFeature (org.jpmml.converter.ContinuousFeature), CategoricalFeature (org.jpmml.converter.CategoricalFeature), Predicate (org.dmg.pmml.Predicate), SimplePredicate (org.dmg.pmml.SimplePredicate), ScoreDistribution (org.dmg.pmml.ScoreDistribution), ArrayList (java.util.ArrayList), List (java.util.List)

Example 3 with BranchNode

Use of org.dmg.pmml.tree.BranchNode in project jpmml-r by jpmml.

Class GBMConverter, method encodeNode.

/**
 * Recursively encodes the gbm tree node at row {@code i} (and its subtree) as a PMML node.
 *
 * <p>The tree is read from the column vectors of {@code tree}:
 * <ul>
 *   <li>element 0 is the split variable;</li>
 *   <li>element 1 is the split code/threshold;</li>
 *   <li>elements 2-4 are the left, right and missing child rows;</li>
 *   <li>element 7 is the node prediction.</li>
 * </ul>
 * These meanings are inferred from the local variable names (NOTE(review): confirm against gbm's tree layout).
 * A split variable of {@code -1} marks a terminal node, and a child row of {@code -1} means that child is absent.
 *
 * <p>{@code flagManager} records, per feature name, whether the feature is already known to be missing
 * ({@code TRUE}) or non-missing ({@code FALSE}) on the path to this node. Children that cannot be reached
 * given that knowledge are not emitted. {@code categoryManager} appears to track which categorical values
 * can still reach this node; it supplies the value filter used when building categorical child predicates.
 *
 * @param i 0-based row of the node; the PMML node id is {@code i + 1}
 * @param predicate predicate that routes records into this node
 * @param tree column vectors describing the tree
 * @param c_splits categorical split definitions, indexed by the integer value of the node's split code
 * @param flagManager known missing/non-missing state of features along the current path
 * @param categoryManager categorical value bookkeeping along the current path
 * @param schema schema whose features are indexed by the split variable
 * @return a {@code LeafNode} for terminal rows; otherwise a {@code BranchNode} whose children are, in order,
 *         the missing-value child, the left child and the right child (each only if present and reachable)
 */
private Node encodeNode(int i, Predicate predicate, RGenericVector tree, RGenericVector c_splits, FlagManager flagManager, CategoryManager categoryManager, Schema schema) {
    Integer id = Integer.valueOf(i + 1);
    RIntegerVector splitVar = tree.getIntegerValue(0);
    RDoubleVector splitCodePred = tree.getDoubleValue(1);
    RIntegerVector leftNode = tree.getIntegerValue(2);
    RIntegerVector rightNode = tree.getIntegerValue(3);
    RIntegerVector missingNode = tree.getIntegerValue(4);
    RDoubleVector prediction = tree.getDoubleValue(7);
    Integer var = splitVar.getValue(i);
    // Terminal node: the node prediction becomes the leaf score
    if (var == -1) {
        Double value = prediction.getValue(i);
        Node result = new LeafNode(value, predicate).setId(id);
        return result;
    }
    // Tri-state: TRUE/FALSE if the feature's missingness is already fixed on this path, null if still unknown
    Boolean isMissing;
    FlagManager missingFlagManager = flagManager;
    FlagManager nonMissingFlagManager = flagManager;
    Predicate missingPredicate;
    Feature feature = schema.getFeature(var);
    {
        String name = feature.getName();
        isMissing = flagManager.getValue(name);
        // Missingness not yet decided: the missing branch records TRUE, the left/right branches record FALSE
        if (isMissing == null) {
            missingFlagManager = missingFlagManager.fork(name, Boolean.TRUE);
            nonMissingFlagManager = nonMissingFlagManager.fork(name, Boolean.FALSE);
        }
        missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
    }
    CategoryManager leftCategoryManager = categoryManager;
    CategoryManager rightCategoryManager = categoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;
    Double split = splitCodePred.getValue(i);
    if (feature instanceof CategoricalFeature) {
        // For categorical splits the split code is an index into c_splits, which holds the per-category assignment
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        String name = categoricalFeature.getName();
        List<?> values = categoricalFeature.getValues();
        int index = ValueUtil.asInt(split);
        RIntegerVector c_split = c_splits.getIntegerValue(index);
        List<Integer> splitValues = c_split.getValues();
        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
        // selectValues(..., true) / (..., false) presumably pick the left / right side of the partition
        List<Object> leftValues = selectValues(values, valueFilter, splitValues, true);
        List<Object> rightValues = selectValues(values, valueFilter, splitValues, false);
        leftCategoryManager = leftCategoryManager.fork(name, leftValues);
        rightCategoryManager = rightCategoryManager.fork(name, rightValues);
        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        // Continuous split: values below the threshold go left, values at or above it go right
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, split);
    }
    Node result = new BranchNode(null, predicate).setId(id);
    List<Node> nodes = result.getNodes();
    // Missing child first, emitted only if it exists and the feature is not already known to be non-missing;
    // it keeps the parent's categoryManager since no category partition applies to missing values
    Integer missing = missingNode.getValue(i);
    if (missing != -1 && (isMissing == null || isMissing)) {
        Node missingChild = encodeNode(missing, missingPredicate, tree, c_splits, missingFlagManager, categoryManager, schema);
        nodes.add(missingChild);
    }
    // Left/right children are emitted only if they exist and the feature is not already known to be missing
    Integer left = leftNode.getValue(i);
    if (left != -1 && (isMissing == null || !isMissing)) {
        Node leftChild = encodeNode(left, leftPredicate, tree, c_splits, nonMissingFlagManager, leftCategoryManager, schema);
        nodes.add(leftChild);
    }
    Integer right = rightNode.getValue(i);
    if (right != -1 && (isMissing == null || !isMissing)) {
        Node rightChild = encodeNode(right, rightPredicate, tree, c_splits, nonMissingFlagManager, rightCategoryManager, schema);
        nodes.add(rightChild);
    }
    return result;
}
Also used: Node (org.dmg.pmml.tree.Node), BranchNode (org.dmg.pmml.tree.BranchNode), LeafNode (org.dmg.pmml.tree.LeafNode), FlagManager (org.jpmml.converter.FlagManager), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), CategoricalFeature (org.jpmml.converter.CategoricalFeature), Predicate (org.dmg.pmml.Predicate), SimplePredicate (org.dmg.pmml.SimplePredicate), CategoryManager (org.jpmml.converter.CategoryManager)

Example 4 with BranchNode

Use of org.dmg.pmml.tree.BranchNode in project jpmml-r by jpmml.

Class RangerConverter, method encodeNode.

/**
 * Converts the ranger tree node at position {@code index} into a PMML node, recursing into both children.
 *
 * <p>A node whose left and right child IDs are both 0 is terminal. Its leaf is handed to
 * {@code scoreEncoder} together with the node's split-value slot and, when available, its terminal class
 * counts. Every other node becomes a {@code BranchNode} with exactly two children: the left child
 * receives values at or below the threshold (or the leading categories), and the right child receives the rest.
 *
 * @param predicate predicate that routes records into this node
 * @param index position of the node in the per-node vectors
 * @param scoreEncoder encoder that attaches the score to terminal nodes
 * @param leftChildIDs per-node position of the left child (0 when terminal)
 * @param rightChildIDs per-node position of the right child (0 when terminal)
 * @param splitVarIDs per-node split variable ID
 * @param splitValues per-node split threshold (also passed to the score encoder for terminal nodes)
 * @param terminalClassCounts per-node class counts, or {@code null} if not available
 * @param categoryManager categorical value bookkeeping along the current path
 * @param schema schema supplying the split features
 * @return the encoded node, including all of its descendants
 */
private Node encodeNode(Predicate predicate, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIDs, RNumberVector<?> rightChildIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, CategoryManager categoryManager, Schema schema) {
    int leftPos = ValueUtil.asInt(leftChildIDs.getValue(index));
    int rightPos = ValueUtil.asInt(rightChildIDs.getValue(index));
    Number threshold = splitValues.getValue(index);

    RNumberVector<?> classCounts = null;
    if (terminalClassCounts != null) {
        classCounts = terminalClassCounts.getNumericValue(index);
    }

    boolean terminal = (leftPos == 0) && (rightPos == 0);
    if (terminal) {
        Node leaf = new LeafNode(null, predicate);
        return scoreEncoder.encode(leaf, threshold, classCounts);
    }

    int varId = ValueUtil.asInt(splitVarIDs.getValue(index));
    // With a dependent variable present, split variable IDs are offset by one relative to the schema features
    int featureIndex = this.hasDependentVar ? (varId - 1) : varId;
    Feature feature = schema.getFeature(featureIndex);

    Predicate leftPredicate;
    Predicate rightPredicate;
    CategoryManager leftCategoryManager;
    CategoryManager rightCategoryManager;

    if (!(feature instanceof CategoricalFeature)) {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, threshold);
        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;
    } else {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        // The integer part of the threshold is the number of leading categories sent to the left child
        int cut = ValueUtil.asInt(Math.floor(threshold.doubleValue()));
        String name = categoricalFeature.getName();
        List<?> levels = categoricalFeature.getValues();
        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
        List<Object> leftValues = filterValues(levels.subList(0, cut), valueFilter);
        List<Object> rightValues = filterValues(levels.subList(cut, levels.size()), valueFilter);
        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);
        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    }

    Node leftChild = encodeNode(leftPredicate, leftPos, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, leftCategoryManager, schema);
    Node rightChild = encodeNode(rightPredicate, rightPos, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, rightCategoryManager, schema);
    return new BranchNode(null, predicate).addNodes(leftChild, rightChild);
}
Also used: Node (org.dmg.pmml.tree.Node), ClassifierNode (org.dmg.pmml.tree.ClassifierNode), BranchNode (org.dmg.pmml.tree.BranchNode), LeafNode (org.dmg.pmml.tree.LeafNode), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), CategoricalFeature (org.jpmml.converter.CategoricalFeature), Predicate (org.dmg.pmml.Predicate), SimplePredicate (org.dmg.pmml.SimplePredicate), CategoryManager (org.jpmml.converter.CategoryManager)

Example 5 with BranchNode

Use of org.dmg.pmml.tree.BranchNode in project jpmml-r by jpmml.

Class RandomForestConverter, method encodeNode.

/**
 * Recursively encodes the randomForest tree node at row {@code i} (and its subtree) as a PMML node.
 *
 * <p>A {@code bestvar} of 0 marks a terminal node, whose {@code nodepred} value is encoded into the leaf
 * score. Otherwise {@code bestvar} is the 1-based feature number. Child rows in {@code leftDaughter} and
 * {@code rightDaughter} are also 1-based, and 0 means the child is absent. The PMML node id is {@code i + 1}.
 *
 * @param predicate predicate that routes records into this node
 * @param i 0-based row of the node
 * @param scoreEncoder converts a node prediction into a leaf score
 * @param leftDaughter per-node 1-based left child row (0 if none)
 * @param rightDaughter per-node 1-based right child row (0 if none)
 * @param bestvar per-node 1-based split feature number (0 for terminal nodes)
 * @param xbestsplit per-node split value (threshold, or categorical split code)
 * @param nodepred per-node prediction used for terminal nodes
 * @param categoryManager categorical value bookkeeping along the current path
 * @param schema schema supplying the split features
 * @return the encoded node, including all of its descendants
 * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
 */
private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
    Integer id = Integer.valueOf(i + 1);
    int var = ValueUtil.asInt(bestvar.get(i));
    // Terminal node: no split variable
    if (var == 0) {
        P prediction = nodepred.get(i);
        Node result = new LeafNode(scoreEncoder.encode(prediction), predicate).setId(id);
        return result;
    }
    CategoryManager leftCategoryManager = categoryManager;
    CategoryManager rightCategoryManager = categoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;
    Feature feature = schema.getFeature(var - 1);
    Double split = xbestsplit.get(i);
    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;
        // Only a midpoint split (presumably between 0/1 encodings) maps cleanly onto value(0) vs value(1)
        if (split != 0.5d) {
            throw new IllegalArgumentException("Expected split value 0.5 for boolean feature " + booleanFeature.getName() + ", got " + split);
        }
        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        // Categorical split: the split value encodes which categories go left; restrict to still-allowed values
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        String name = categoricalFeature.getName();
        List<?> values = categoricalFeature.getValues();
        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
        List<Object> leftValues = selectValues(values, valueFilter, split, true);
        List<Object> rightValues = selectValues(values, valueFilter, split, false);
        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);
        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        // Continuous split: values at or below the threshold go left, values above it go right
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split);
    }
    Node result = new BranchNode(null, predicate).setId(id);
    List<Node> nodes = result.getNodes();
    int left = ValueUtil.asInt(leftDaughter.get(i));
    if (left != 0) {
        Node leftChild = encodeNode(leftPredicate, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema);
        nodes.add(leftChild);
    }
    int right = ValueUtil.asInt(rightDaughter.get(i));
    if (right != 0) {
        Node rightChild = encodeNode(rightPredicate, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema);
        nodes.add(rightChild);
    }
    return result;
}
Also used: Node (org.dmg.pmml.tree.Node), BranchNode (org.dmg.pmml.tree.BranchNode), LeafNode (org.dmg.pmml.tree.LeafNode), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), BooleanFeature (org.jpmml.converter.BooleanFeature), CategoricalFeature (org.jpmml.converter.CategoricalFeature), Predicate (org.dmg.pmml.Predicate), SimplePredicate (org.dmg.pmml.SimplePredicate), ArrayList (java.util.ArrayList), List (java.util.List), CategoryManager (org.jpmml.converter.CategoryManager)

Aggregations

Predicate (org.dmg.pmml.Predicate): 6; SimplePredicate (org.dmg.pmml.SimplePredicate): 6; BranchNode (org.dmg.pmml.tree.BranchNode): 6; LeafNode (org.dmg.pmml.tree.LeafNode): 6; Node (org.dmg.pmml.tree.Node): 6; ContinuousFeature (org.jpmml.converter.ContinuousFeature): 6; CategoricalFeature (org.jpmml.converter.CategoricalFeature): 5; Feature (org.jpmml.converter.Feature): 5; ArrayList (java.util.ArrayList): 3; List (java.util.List): 3; ClassifierNode (org.dmg.pmml.tree.ClassifierNode): 3; CategoryManager (org.jpmml.converter.CategoryManager): 3; ScoreDistribution (org.dmg.pmml.ScoreDistribution): 1; BooleanFeature (org.jpmml.converter.BooleanFeature): 1; CategoricalLabel (org.jpmml.converter.CategoricalLabel): 1; FlagManager (org.jpmml.converter.FlagManager): 1; Label (org.jpmml.converter.Label): 1