Search in sources:

Example 11 with True

use of org.dmg.pmml.True in project jpmml-r by jpmml.

From the class ScorecardConverter, method encodeModel.

/**
 * Encodes an R {@code glm} object carrying an {@code sc.conf} scorecard configuration as a
 * PMML {@link Scorecard}.
 *
 * <p>Every schema feature must be a {@link BinaryFeature}. Each feature contributes one
 * attribute whose partial score is {@code -coefficient * factor}, where
 * {@code factor = pdo / ln(2)}. Features that share a field name are grouped into one
 * {@link Characteristic}. Each characteristic is closed by an always-true ({@link True})
 * attribute with a partial score of 0. The initial score is
 * {@code base_points - ln(odds) * factor - intercept * factor}; the intercept term is
 * omitted when the model has no intercept.
 *
 * @param schema the feature schema; must hold one feature per non-intercept coefficient
 * @return the encoded regression scorecard, with reason codes disabled
 * @throws IllegalArgumentException if the {@code sc.conf} element is missing, if the number of
 *     coefficients does not match the schema, or if a feature is not a {@link BinaryFeature}
 */
@Override
public Scorecard encodeModel(Schema schema) {
    RGenericVector glm = getObject();

    RDoubleVector coefficients = (RDoubleVector) glm.getValue("coefficients");
    // Not used below; still looked up so behaviour matches the original.
    // NOTE(review): possibly kept to validate that the object is a glm - confirm.
    RGenericVector family = (RGenericVector) glm.getValue("family");

    RGenericVector scConf;
    try {
        scConf = (RGenericVector) glm.getValue("sc.conf");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No scorecard configuration information. Please initialize the \'sc.conf\' element", iae);
    }

    Double intercept = coefficients.getValue(LMConverter.INTERCEPT, true);

    List<? extends Feature> features = schema.getFeatures();

    // One coefficient per feature, plus one for the intercept when present.
    int expectedSize = features.size() + (intercept != null ? 1 : 0);
    if (coefficients.size() != expectedSize) {
        throw new IllegalArgumentException("Expected " + expectedSize + " coefficient(s), got " + coefficients.size());
    }

    RNumberVector<?> odds = (RNumberVector<?>) scConf.getValue("odds");
    RNumberVector<?> basePoints = (RNumberVector<?>) scConf.getValue("base_points");
    RNumberVector<?> pdo = (RNumberVector<?>) scConf.getValue("pdo");

    // An increase of ln(2) in log-odds maps to `pdo` points
    // (presumably "points to double the odds").
    double factor = (pdo.asScalar()).doubleValue() / Math.log(2);

    // LinkedHashMap keeps characteristics in first-seen field order.
    Map<FieldName, Characteristic> fieldCharacteristics = new LinkedHashMap<>();

    for (Feature feature : features) {
        FieldName name = feature.getName();

        if (!(feature instanceof BinaryFeature)) {
            throw new IllegalArgumentException("Expected a binary feature for field " + name + ", got " + feature.getClass().getName());
        }

        Double coefficient = getFeatureCoefficient(feature, coefficients);

        Characteristic characteristic = fieldCharacteristics.get(name);
        if (characteristic == null) {
            characteristic = new Characteristic().setName(FeatureUtil.createName("score", feature));

            fieldCharacteristics.put(name, characteristic);
        }

        BinaryFeature binaryFeature = (BinaryFeature) feature;

        SimplePredicate simplePredicate = new SimplePredicate()
            .setField(binaryFeature.getName())
            .setOperator(SimplePredicate.Operator.EQUAL)
            .setValue(binaryFeature.getValue());

        Attribute attribute = new Attribute()
            .setPartialScore(formatScore(-1d * coefficient * factor))
            .setPredicate(simplePredicate);

        characteristic.addAttributes(attribute);
    }

    Characteristics characteristics = new Characteristics();

    for (Characteristic characteristic : fieldCharacteristics.values()) {
        // Appended after the binary attributes. Presumably the fallback (0 points)
        // when none of them matches.
        Attribute attribute = new Attribute()
            .setPartialScore(0d)
            .setPredicate(new True());

        characteristic.addAttributes(attribute);

        characteristics.addCharacteristics(characteristic);
    }

    double interceptScore = (intercept != null ? intercept * factor : 0d);
    double initialScore = (basePoints.asScalar()).doubleValue() - Math.log((odds.asScalar()).doubleValue()) * factor - interceptScore;

    Scorecard scorecard = new Scorecard(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), characteristics)
        .setInitialScore(formatScore(initialScore))
        .setUseReasonCodes(false);

    return scorecard;
}
Also used : Attribute(org.dmg.pmml.scorecard.Attribute) Characteristic(org.dmg.pmml.scorecard.Characteristic) True(org.dmg.pmml.True) BinaryFeature(org.jpmml.converter.BinaryFeature) Feature(org.jpmml.converter.Feature) SimplePredicate(org.dmg.pmml.SimplePredicate) LinkedHashMap(java.util.LinkedHashMap) Characteristics(org.dmg.pmml.scorecard.Characteristics) Scorecard(org.dmg.pmml.scorecard.Scorecard) FieldName(org.dmg.pmml.FieldName) Map(java.util.Map)

Example 12 with True

use of org.dmg.pmml.True in project jpmml-r by jpmml.

From the class RandomForestConverter, method encodeTreeModel.

/**
 * Builds a PMML {@link TreeModel} for a single tree from its array representation.
 *
 * <p>The root node gets id {@code "1"} and an always-true ({@link True}) predicate. Its subtree
 * is populated by {@code encodeNode}, starting at index 0. The resulting model declares binary
 * splits.
 *
 * @param miningFunction the mining function of the resulting model
 * @param scoreEncoder converts node predictions into PMML scores
 * @param leftDaughter left-child indices, passed through to {@code encodeNode}
 * @param rightDaughter right-child indices, passed through to {@code encodeNode}
 * @param nodepred node predictions, passed through to {@code encodeNode}
 * @param bestvar split variable indices, passed through to {@code encodeNode}
 * @param xbestsplit split values, passed through to {@code encodeNode}
 * @param schema the feature schema; its label defines the mining schema
 * @return the encoded tree model
 */
private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema) {
    Node rootNode = new Node()
        .setId("1")
        .setPredicate(new True());

    // encodeNode takes bestvar/xbestsplit before nodepred, unlike this method's parameter order.
    encodeNode(rootNode, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True)

Example 13 with True

use of org.dmg.pmml.True in project jpmml-sparkml by jpmml.

From the class TreeModelUtil, method encodeTreeModel.

/**
 * Converts a Spark ML decision tree into a PMML {@link TreeModel} with binary splits.
 *
 * <p>The root node returned by {@code encodeNode} gets an always-true ({@link True}) predicate.
 * When {@code TreeModelOptions.COMPACT} parses as {@code true}, the model is then rewritten in
 * place by a {@link TreeModelCompactor}.
 *
 * @param node the root of the Spark ML tree
 * @param predicateManager passed through to {@code encodeNode}
 * @param miningFunction the mining function of the resulting model
 * @param schema the feature schema; its label defines the mining schema
 * @return the encoded (and possibly compacted) tree model
 */
public static TreeModel encodeTreeModel(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema) {
    Node rootNode = encodeNode(node, predicateManager, Collections.<FieldName, Set<String>>emptyMap(), miningFunction, schema)
        .setPredicate(new True());

    TreeModel result = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    // Boolean.parseBoolean returns false for null, so no separate null check is needed.
    // NOTE(review): this parses the constant TreeModelOptions.COMPACT itself as the flag value.
    // If that constant is an option *key* rather than a value, compaction never runs - confirm.
    if (Boolean.parseBoolean(TreeModelOptions.COMPACT)) {
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(result);
    }

    return result;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) DecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel) HashSet(java.util.HashSet) Set(java.util.Set) Visitor(org.dmg.pmml.Visitor) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) True(org.dmg.pmml.True) TreeModelCompactor(org.jpmml.sparkml.visitors.TreeModelCompactor) FieldName(org.dmg.pmml.FieldName)

Example 14 with True

use of org.dmg.pmml.True in project jpmml-sparkml by jpmml.

From the class TreeModelCompactor, method handleNodePop.

/**
 * Merges a node with an always-true ({@link True}) predicate into its parent when the node is
 * popped during traversal.
 *
 * <ul>
 *   <li><b>Regression:</b> the node's score moves to the parent. The node is removed from the
 *       parent's children, and the node's own children (if any) are appended there.</li>
 *   <li><b>Classification:</b> only a node that has children is removed; those children are
 *       appended to the parent's children. A leaf node is kept.</li>
 * </ul>
 *
 * <p>Nodes with any other predicate, and a node without a parent, are left untouched.
 *
 * @param node the node being popped
 * @throws IllegalArgumentException if the parent already has a score, if the node is not found
 *     among the parent's children, or if the mining function is neither regression nor
 *     classification
 */
private void handleNodePop(Node node) {
    String nodeScore = node.getScore();
    Predicate nodePredicate = node.getPredicate();

    // Only always-true nodes are candidates for merging.
    if (!(nodePredicate instanceof True)) {
        return;
    }

    Node parent = getParentNode();
    // Nothing to merge into (e.g. the root).
    if (parent == null) {
        return;
    }

    if (parent.getScore() != null) {
        throw new IllegalArgumentException();
    }

    boolean isRegression = (MiningFunction.REGRESSION).equals(this.miningFunction);
    boolean isClassification = (MiningFunction.CLASSIFICATION).equals(this.miningFunction);

    if (isRegression) {
        parent.setScore(nodeScore);

        List<Node> siblings = parent.getNodes();
        if (!siblings.remove(node)) {
            throw new IllegalArgumentException();
        }

        if (node.hasNodes()) {
            siblings.addAll(node.getNodes());
        }
    } else if (isClassification) {
        if (!node.hasNodes()) {
            return;
        }

        List<Node> siblings = parent.getNodes();
        if (!siblings.remove(node)) {
            throw new IllegalArgumentException();
        }

        siblings.addAll(node.getNodes());
    } else {
        throw new IllegalArgumentException();
    }
}
Also used : Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True) List(java.util.List) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate)

Example 15 with True

use of org.dmg.pmml.True in project jpmml-sparkml by jpmml.

From the class TreeModelCompactor, method handleNodePush.

/**
 * Validates a node when it is pushed during traversal. For an internal node, also rewrites the
 * second child's predicate so the two-way split can be expressed more compactly.
 *
 * <p>Requirements (otherwise {@link IllegalArgumentException} is thrown):
 * <ul>
 *   <li>the node has no id;</li>
 *   <li>an internal node has exactly two children and no score;</li>
 *   <li>a leaf node has a score;</li>
 *   <li>both child predicates refer to the same field and form one of the supported pairs.</li>
 * </ul>
 *
 * <p>Supported child predicate pairs (first / second):
 * <ul>
 *   <li>EQUAL / EQUAL: on a categorical field the second child becomes {@link True}. On any
 *       other field both predicates are left unchanged.</li>
 *   <li>EQUAL / NOT_EQUAL, NOT_EQUAL / EQUAL (children swapped first), or
 *       LESS_OR_EQUAL / GREATER_THAN, all with identical values: the second child becomes
 *       {@link True}.</li>
 *   <li>EQUAL / IS_IN, IS_IN / EQUAL (children swapped first), or IS_IN / IS_IN: the second
 *       child's set predicate is replaced by the result of {@code addCategoricalField}.</li>
 * </ul>
 *
 * <p>NOTE(review): replacing the second predicate with {@link True} presumably relies on the
 * children being evaluated in order, so the second child acts as an "else" branch - confirm.
 *
 * @param node the node being pushed
 * @throws IllegalArgumentException if any of the above requirements is violated
 */
private void handleNodePush(Node node) {
    String id = node.getId();
    String score = node.getScore();
    // Compacted trees are expected to carry no node ids.
    if (id != null) {
        throw new IllegalArgumentException();
    }
    if (node.hasNodes()) {
        List<Node> children = node.getNodes();
        // Only strictly binary splits without an intermediate score are supported.
        if (children.size() != 2 || score != null) {
            throw new IllegalArgumentException();
        }
        Node firstChild = children.get(0);
        Node secondChild = children.get(1);
        Predicate firstPredicate = firstChild.getPredicate();
        Predicate secondPredicate = secondChild.getPredicate();
        // Labeled so the numeric EQUAL/EQUAL case can exit without rewriting anything.
        predicate: if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimplePredicate) {
            SimplePredicate firstSimplePredicate = (SimplePredicate) firstPredicate;
            SimplePredicate secondSimplePredicate = (SimplePredicate) secondPredicate;
            SimplePredicate.Operator firstOperator = firstSimplePredicate.getOperator();
            SimplePredicate.Operator secondOperator = secondSimplePredicate.getOperator();
            if (!(firstSimplePredicate.getField()).equals(secondSimplePredicate.getField())) {
                throw new IllegalArgumentException();
            }
            if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
                // Non-categorical field: leave both predicates unchanged.
                if (!isCategoricalField(firstSimplePredicate.getField())) {
                    break predicate;
                }
                secondChild.setPredicate(new True());
            } else {
                // Complementary operator pairs must compare against the same value.
                if (!(firstSimplePredicate.getValue()).equals(secondSimplePredicate.getValue())) {
                    throw new IllegalArgumentException();
                }
                if ((SimplePredicate.Operator.NOT_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
                    // Reorder so the EQUAL child comes first; swapChildren is expected to reorder the list in place.
                    swapChildren(children);
                    firstChild = children.get(0);
                    secondChild = children.get(1);
                } else if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.NOT_EQUAL).equals(secondOperator)) {
                // Already in the expected order; nothing to do.
                } else if ((SimplePredicate.Operator.LESS_OR_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.GREATER_THAN).equals(secondOperator)) {
                // Already in the expected order; nothing to do.
                } else {
                    throw new IllegalArgumentException();
                }
                secondChild.setPredicate(new True());
            }
        } else if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimpleSetPredicate) {
            SimplePredicate simplePredicate = (SimplePredicate) firstPredicate;
            SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) secondPredicate;
            if (!(simplePredicate.getField()).equals(simpleSetPredicate.getField()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator())) {
                throw new IllegalArgumentException();
            }
            secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
        } else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimplePredicate) {
            SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) firstPredicate;
            SimplePredicate simplePredicate = (SimplePredicate) secondPredicate;
            if (!(simpleSetPredicate.getField()).equals(simplePredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator())) {
                throw new IllegalArgumentException();
            }
            // Put the EQUAL child first; the IS_IN child (now second) then gets the rewritten predicate.
            swapChildren(children);
            firstChild = children.get(0);
            secondChild = children.get(1);
            secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
        } else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimpleSetPredicate) {
            SimpleSetPredicate firstSimpleSetPredicate = (SimpleSetPredicate) firstPredicate;
            SimpleSetPredicate secondSimpleSetPredicate = (SimpleSetPredicate) secondPredicate;
            if (!(firstSimpleSetPredicate.getField()).equals(secondSimpleSetPredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(firstSimpleSetPredicate.getBooleanOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(secondSimpleSetPredicate.getBooleanOperator())) {
                throw new IllegalArgumentException();
            }
            secondChild.setPredicate(addCategoricalField(secondSimpleSetPredicate));
        } else {
            throw new IllegalArgumentException();
        }
    } else {
        // A leaf must carry a score.
        if (score == null) {
            throw new IllegalArgumentException();
        }
    }
}
Also used : Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True) SimplePredicate(org.dmg.pmml.SimplePredicate) Predicate(org.dmg.pmml.Predicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate)

Aggregations

True (org.dmg.pmml.True): 21, Node (org.dmg.pmml.tree.Node): 9, TreeModel (org.dmg.pmml.tree.TreeModel): 7, SimplePredicate (org.dmg.pmml.SimplePredicate): 5, SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate): 4, Test (org.junit.Test): 4, ArrayList (java.util.ArrayList): 3, FieldName (org.dmg.pmml.FieldName): 3, Predicate (org.dmg.pmml.Predicate): 3, MiningModel (org.dmg.pmml.mining.MiningModel): 3, Segment (org.dmg.pmml.mining.Segment): 3, Segmentation (org.dmg.pmml.mining.Segmentation): 3, List (java.util.List): 2, Model (org.dmg.pmml.Model): 2, BlockStmt (com.github.javaparser.ast.stmt.BlockStmt): 1, Statement (com.github.javaparser.ast.stmt.Statement): 1, IOException (java.io.IOException): 1, HashSet (java.util.HashSet): 1, LinkedHashMap (java.util.LinkedHashMap): 1, Map (java.util.Map): 1