Usage of org.dmg.pmml.True in project jpmml-r (by jpmml).
Class ScorecardConverter, method encodeModel.
/**
 * Encodes the wrapped R {@code glm} object as a PMML {@link Scorecard}.
 *
 * <p>The {@code glm} object must carry an {@code sc.conf} element holding the scorecard
 * parameters {@code odds}, {@code base_points} and {@code pdo}, each read via {@code asScalar()}.
 * The scaling factor is {@code pdo / ln(2)}. Every feature in the schema must be a
 * {@link BinaryFeature}. Features sharing a field name are grouped into one
 * {@link Characteristic}. Each feature contributes an attribute worth
 * {@code -coefficient * factor} points. Each characteristic then gets a final attribute
 * with a {@code True} predicate worth 0 points. The initial score is
 * {@code base_points - ln(odds) * factor - intercept * factor}, where the intercept term is
 * omitted when the model has no intercept.
 *
 * @param schema the schema whose label and features are encoded
 * @return the encoded regression scorecard, with reason codes disabled
 * @throws IllegalArgumentException if {@code sc.conf} is missing, if the number of
 *     coefficients does not equal the number of features (plus one when an intercept is
 *     present), or if a feature is not a {@link BinaryFeature}
 */
@Override
public Scorecard encodeModel(Schema schema) {
    RGenericVector glm = getObject();

    RDoubleVector coefficients = (RDoubleVector) glm.getValue("coefficients");
    // NOTE(review): "family" is looked up (and cast) but not used below; the lookup is kept so
    // that a malformed glm object fails at this point exactly as before.
    RGenericVector family = (RGenericVector) glm.getValue("family");

    RGenericVector scConf;
    try {
        scConf = (RGenericVector) glm.getValue("sc.conf");
    } catch (IllegalArgumentException iae) {
        throw new IllegalArgumentException("No scorecard configuration information. Please initialize the \'sc.conf\' element", iae);
    }

    Double intercept = coefficients.getValue(LMConverter.INTERCEPT, true);

    List<? extends Feature> features = schema.getFeatures();

    int expectedCoefficients = features.size() + (intercept != null ? 1 : 0);
    if (coefficients.size() != expectedCoefficients) {
        throw new IllegalArgumentException("Expected " + expectedCoefficients + " coefficients (" + features.size() + " features" + (intercept != null ? " plus intercept" : "") + "), got " + coefficients.size());
    }

    RNumberVector<?> odds = (RNumberVector<?>) scConf.getValue("odds");
    RNumberVector<?> basePoints = (RNumberVector<?>) scConf.getValue("base_points");
    RNumberVector<?> pdo = (RNumberVector<?>) scConf.getValue("pdo");

    // Points per unit of log-odds: since factor * ln(2) == pdo, doubling the odds moves the score by pdo points.
    double factor = (pdo.asScalar()).doubleValue() / Math.log(2);

    // Insertion-ordered, so characteristics appear in the order their fields are first seen.
    Map<FieldName, Characteristic> fieldCharacteristics = new LinkedHashMap<>();

    for (Feature feature : features) {
        FieldName name = feature.getName();

        if (!(feature instanceof BinaryFeature)) {
            throw new IllegalArgumentException("Expected a BinaryFeature for field " + name + ", got " + feature.getClass().getName());
        }

        Double coefficient = getFeatureCoefficient(feature, coefficients);

        // All binary features that share a field name are collected into one characteristic.
        Characteristic characteristic = fieldCharacteristics.get(name);
        if (characteristic == null) {
            characteristic = new Characteristic().setName(FeatureUtil.createName("score", feature));
            fieldCharacteristics.put(name, characteristic);
        }

        BinaryFeature binaryFeature = (BinaryFeature) feature;

        SimplePredicate simplePredicate = new SimplePredicate()
            .setField(binaryFeature.getName())
            .setOperator(SimplePredicate.Operator.EQUAL)
            .setValue(binaryFeature.getValue());

        Attribute attribute = new Attribute()
            .setPartialScore(formatScore(-1d * coefficient * factor))
            .setPredicate(simplePredicate);

        characteristic.addAttributes(attribute);
    }

    Characteristics characteristics = new Characteristics();

    for (Characteristic characteristic : fieldCharacteristics.values()) {
        // Fallback attribute (always-true predicate, 0 points), appended after the field's equality attributes.
        Attribute attribute = new Attribute()
            .setPartialScore(0d)
            .setPredicate(new True());

        characteristic.addAttributes(attribute);

        characteristics.addCharacteristics(characteristic);
    }

    double basePointsValue = (basePoints.asScalar()).doubleValue();
    double oddsValue = (odds.asScalar()).doubleValue();
    double interceptPoints = (intercept != null ? intercept * factor : 0);

    Scorecard scorecard = new Scorecard(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), characteristics)
        .setInitialScore(formatScore(basePointsValue - Math.log(oddsValue) * factor - interceptPoints))
        .setUseReasonCodes(false);

    return scorecard;
}
Usage of org.dmg.pmml.True in project jpmml-r (by jpmml).
Class RandomForestConverter, method encodeTreeModel.
/**
 * Builds a binary-split {@link TreeModel} from the per-node lists of a single tree.
 *
 * <p>A root node with id {@code "1"} and an always-true predicate is created first. It is then
 * handed to {@code encodeNode} (with a starting argument of {@code 0}), which fills in the rest
 * of the tree.
 *
 * @return a tree model whose split characteristic is {@code BINARY_SPLIT}
 */
private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema) {
    Node rootNode = new Node()
        .setId("1")
        .setPredicate(new True());

    encodeNode(rootNode, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Usage of org.dmg.pmml.True in project jpmml-sparkml (by jpmml).
Class TreeModelUtil, method encodeTreeModel.
/**
 * Converts a Spark ML tree, rooted at {@code node}, into a PMML {@link TreeModel} with binary splits.
 *
 * <p>The encoded root node gets an always-true predicate. When {@code TreeModelOptions.COMPACT}
 * parses as {@code true}, the model is additionally rewritten by a {@code TreeModelCompactor}.
 *
 * @return the (possibly compacted) tree model
 */
public static TreeModel encodeTreeModel(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema) {
    Node root = encodeNode(node, predicateManager, Collections.<FieldName, Set<String>>emptyMap(), miningFunction, schema);
    root.setPredicate(new True());

    TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    // Boolean.parseBoolean(null) is false, so no separate null check is needed.
    // NOTE(review): the value of TreeModelOptions.COMPACT itself is treated as the flag; if COMPACT
    // is an option *name* constant rather than a value, compaction never runs. Confirm.
    if (Boolean.parseBoolean(TreeModelOptions.COMPACT)) {
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(treeModel);
    }

    return treeModel;
}
Usage of org.dmg.pmml.True in project jpmml-sparkml (by jpmml).
Class TreeModelCompactor, method handleNodePop.
/**
 * On leaving {@code node}, merges it into its parent when its predicate is {@code True}.
 *
 * <p>Nothing happens for nodes with any other predicate, or for a node with no parent.
 * Otherwise the parent must not yet have a score, and the merge depends on the mining function:
 * <ul>
 *   <li>regression: the node's score is copied to the parent. The node is then removed from the
 *       parent's children and its own children (if any) are appended to them;</li>
 *   <li>classification: only a node that has children is removed, with its children appended
 *       to the parent's children;</li>
 *   <li>any other mining function is rejected.</li>
 * </ul>
 *
 * @param node the node being left
 * @throws IllegalArgumentException if the parent already has a score, if the node cannot be
 *     removed from the parent's children, or if the mining function is unsupported
 */
private void handleNodePop(Node node) {
    if (!(node.getPredicate() instanceof True)) {
        return;
    }

    Node parentNode = getParentNode();
    if (parentNode == null) {
        return;
    }

    if (parentNode.getScore() != null) {
        throw new IllegalArgumentException();
    }

    if ((MiningFunction.REGRESSION).equals(this.miningFunction)) {
        parentNode.setScore(node.getScore());
        liftChildrenIntoParent(parentNode, node);
    } else if ((MiningFunction.CLASSIFICATION).equals(this.miningFunction)) {
        if (node.hasNodes()) {
            liftChildrenIntoParent(parentNode, node);
        }
    } else {
        throw new IllegalArgumentException();
    }
}

/**
 * Removes {@code node} from {@code parentNode}'s children and appends the node's own children
 * (if it has any) to the end of that list.
 */
private void liftChildrenIntoParent(Node parentNode, Node node) {
    List<Node> siblings = parentNode.getNodes();

    if (!siblings.remove(node)) {
        throw new IllegalArgumentException();
    }

    // hasNodes() is checked first so that getNodes() is only consulted on nodes that already have children.
    if (node.hasNodes()) {
        siblings.addAll(node.getNodes());
    }
}
Usage of org.dmg.pmml.True in project jpmml-sparkml (by jpmml).
Class TreeModelCompactor, method handleNodePush.
/**
 * On entering {@code node}, validates its shape and normalizes the predicates of a two-way split.
 *
 * <p>The node must not carry an id. An internal node must have exactly two children and no score;
 * a leaf node must have a score. For an internal node, the children may first be swapped. The
 * second child's predicate is then replaced:
 * <ul>
 *   <li>two SimplePredicates on the same field: with {@code True}. The exception is when both are
 *       EQUAL and {@code isCategoricalField(...)} rejects the field; then both are left unchanged.
 *       Otherwise the values must match and the operators must form an EQUAL/NOT_EQUAL pair or a
 *       LESS_OR_EQUAL/GREATER_THAN pair;</li>
 *   <li>any combination involving a SimpleSetPredicate (IS_IN, same field): with the result of
 *       {@code addCategoricalField(...)} applied to the set predicate that ends up second.</li>
 * </ul>
 *
 * @param node the node being entered
 * @throws IllegalArgumentException if any of the structural expectations above is violated, or if
 *     the children's predicates form an unsupported combination
 */
private void handleNodePush(Node node) {
String id = node.getId();
String score = node.getScore();
if (id != null) {
throw new IllegalArgumentException();
}
if (node.hasNodes()) {
List<Node> children = node.getNodes();
// Only binary splits without an intermediate score are supported.
if (children.size() != 2 || score != null) {
throw new IllegalArgumentException();
}
Node firstChild = children.get(0);
Node secondChild = children.get(1);
Predicate firstPredicate = firstChild.getPredicate();
Predicate secondPredicate = secondChild.getPredicate();
// Labeled so that "break predicate" can exit this branch and leave both predicates untouched.
predicate: if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimplePredicate) {
SimplePredicate firstSimplePredicate = (SimplePredicate) firstPredicate;
SimplePredicate secondSimplePredicate = (SimplePredicate) secondPredicate;
SimplePredicate.Operator firstOperator = firstSimplePredicate.getOperator();
SimplePredicate.Operator secondOperator = secondSimplePredicate.getOperator();
if (!(firstSimplePredicate.getField()).equals(secondSimplePredicate.getField())) {
throw new IllegalArgumentException();
}
if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
// Two equality tests: only rewritten when isCategoricalField(...) accepts the field.
if (!isCategoricalField(firstSimplePredicate.getField())) {
break predicate;
}
secondChild.setPredicate(new True());
} else {
// A complementary pair must compare against the same value.
if (!(firstSimplePredicate.getValue()).equals(secondSimplePredicate.getValue())) {
throw new IllegalArgumentException();
}
if ((SimplePredicate.Operator.NOT_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
// Normalize so that the EQUAL child comes first.
swapChildren(children);
// NOTE(review): firstChild is reassigned here but not read again in this method.
firstChild = children.get(0);
secondChild = children.get(1);
} else if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.NOT_EQUAL).equals(secondOperator)) {
// Ignored
} else if ((SimplePredicate.Operator.LESS_OR_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.GREATER_THAN).equals(secondOperator)) {
// Ignored
} else {
throw new IllegalArgumentException();
}
secondChild.setPredicate(new True());
}
} else if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimpleSetPredicate) {
SimplePredicate simplePredicate = (SimplePredicate) firstPredicate;
SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) secondPredicate;
if (!(simplePredicate.getField()).equals(simpleSetPredicate.getField()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator())) {
throw new IllegalArgumentException();
}
secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
} else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimplePredicate) {
SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) firstPredicate;
SimplePredicate simplePredicate = (SimplePredicate) secondPredicate;
if (!(simpleSetPredicate.getField()).equals(simplePredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator())) {
throw new IllegalArgumentException();
}
// Normalize so that the set-predicate child comes second.
swapChildren(children);
// NOTE(review): firstChild is reassigned here but not read again in this method.
firstChild = children.get(0);
secondChild = children.get(1);
secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
} else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimpleSetPredicate) {
SimpleSetPredicate firstSimpleSetPredicate = (SimpleSetPredicate) firstPredicate;
SimpleSetPredicate secondSimpleSetPredicate = (SimpleSetPredicate) secondPredicate;
if (!(firstSimpleSetPredicate.getField()).equals(secondSimpleSetPredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(firstSimpleSetPredicate.getBooleanOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(secondSimpleSetPredicate.getBooleanOperator())) {
throw new IllegalArgumentException();
}
secondChild.setPredicate(addCategoricalField(secondSimpleSetPredicate));
} else {
throw new IllegalArgumentException();
}
} else {
// A leaf node must carry a score.
if (score == null) {
throw new IllegalArgumentException();
}
}
}
Aggregations