Usage of org.dmg.pmml.Array in project shifu by ShifuML.
Class TreeNodePmmlElementCreator, method convert.
/**
 * Converts a Shifu tree node, and recursively its children, into a PMML tree node.
 *
 * @param node   the Shifu tree node to convert
 * @param isLeft {@code true} if {@code node} is the left child of its parent
 * @param split  the parent's split; it determines the predicate attached to {@code node}
 * @return the PMML node; child nodes are attached unless {@code node} has no split or is a real leaf
 */
public org.dmg.pmml.tree.Node convert(Node node, boolean isLeft, Split split) {
    org.dmg.pmml.tree.Node pmmlNode = new org.dmg.pmml.tree.Node();
    pmmlNode.setId(String.valueOf(node.getId()));
    if (node.getPredict() != null) {
        // classification trees score with the class value, otherwise the raw prediction is used
        pmmlNode.setScore(String.valueOf(treeModel.isClassification() ? node.getPredict().getClassValue() : node.getPredict().getPredict()));
    }
    pmmlNode.setDefaultChild(null);

    Predicate predicate = null;
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    if (columnConfig.isNumerical()) {
        SimplePredicate p = new SimplePredicate();
        p.setValue(String.valueOf(split.getThreshold()));
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        // left child takes values below the threshold, right child the rest
        p.setOperator(SimplePredicate.Operator.fromValue(isLeft ? "lessThan" : "greaterOrEqual"));
        predicate = p;
    } else if (columnConfig.isCategorical()) {
        SimpleSetPredicate p = new SimpleSetPredicate();
        Set<Short> childCategories = split.getLeftOrRightCategories();
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum());
        StringBuilder arrayStr = new StringBuilder();
        for (Short sh : childCategories) {
            if (arrayStr.length() > 0) {
                arrayStr.append(' ');
            }
            // a category index beyond the known names is written as an empty-string token
            arrayStr.append(sh >= valueList.size() ? "\"\"" : toPmmlStringArrayToken(valueList.get(sh)));
        }
        Array array = new Array(Array.Type.fromValue("string"), arrayStr.toString());
        p.setArray(array);
        // the category set belongs to the side indicated by split.isLeft(); that child uses isIn,
        // its sibling uses isNotIn
        boolean nodeOwnsCategories = (isLeft == split.isLeft());
        p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue(nodeOwnsCategories ? "isIn" : "isNotIn"));
        predicate = p;
    }
    pmmlNode.setPredicate(predicate);

    if (node.getSplit() == null || node.isRealLeaf()) {
        return pmmlNode;
    }
    List<org.dmg.pmml.tree.Node> childList = pmmlNode.getNodes();
    org.dmg.pmml.tree.Node left = convert(node.getLeft(), true, node.getSplit());
    org.dmg.pmml.tree.Node right = convert(node.getRight(), false, node.getSplit());
    childList.add(left);
    childList.add(right);
    return pmmlNode;
}

/**
 * Formats one category value as a token of a PMML string {@code Array}.
 * A value that is empty, contains any whitespace, or contains a double quote is enclosed in double
 * quotes, with embedded double quotes escaped as backslash-quote. Left unquoted, such a value would be
 * dropped (empty), split into several tokens (whitespace) or mis-tokenized (quote). Any other value is
 * returned verbatim.
 */
private String toPmmlStringArrayToken(String value) {
    boolean needsQuotes = value.isEmpty() || value.indexOf('"') >= 0;
    for (int i = 0; !needsQuotes && i < value.length(); i++) {
        needsQuotes = Character.isWhitespace(value.charAt(i));
    }
    return needsQuotes ? "\"" + value.replace("\"", "\\\"") + "\"" : value;
}
Usage of org.dmg.pmml.Array in project shifu by ShifuML.
Class ModelStatsCreator, method build.
/**
 * Builds PMML {@code ModelStats} containing one {@code UnivariateStats} entry per final-selected column.
 * When {@code basicML} is a {@code BasicFloatNetwork} with a non-empty feature set, only columns whose
 * column number is in that set are included.
 *
 * @param basicML the trained model
 * @return the model statistics
 */
@Override
public ModelStats build(BasicML basicML) {
    // Only BasicFloatNetwork carries a feature subset; null (like an empty set) means "no restriction".
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }

    ModelStats modelStats = new ModelStats();
    for (ColumnConfig columnConfig : columnConfigList) {
        boolean selected = columnConfig.isFinalSelect()
                && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum()));
        if (!selected) {
            continue;
        }

        UnivariateStats univariateStats = new UnivariateStats();
        // here, no need to consider if column is in segment expansion as we need to address new stats
        // variable; set simple column name in PMML
        univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        if (columnConfig.isCategorical()) {
            DiscrStats discrStats = new DiscrStats();
            Array countArray = createCountArray(columnConfig);
            discrStats.addArrays(countArray);
            if (!isConcise) {
                // extensions are only emitted when not in concise mode
                List<Extension> extensions = createExtensions(columnConfig);
                discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
            }
            univariateStats.setDiscrStats(discrStats);
        } else {
            // numerical column
            univariateStats.setNumericInfo(createNumericInfo(columnConfig));
            if (!isConcise) {
                univariateStats.setContStats(createConStats(columnConfig));
            }
        }
        modelStats.addUnivariateStats(univariateStats);
    }
    return modelStats;
}
Usage of org.dmg.pmml.Array in project drools by kiegroup.
Class ModelUtilsTest, method getObjectsFromArray.
/**
 * Checks that {@code ModelUtils.getObjectsFromArray} turns INT, STRING and REAL arrays into
 * {@code Integer}, {@code String} and {@code Double} elements respectively, in their original order.
 */
@Test
public void getObjectsFromArray() {
    verifyObjectsFromArray(Array.Type.INT,
            Arrays.asList("32", "11", "43"),
            Arrays.asList(32, 11, 43));
    verifyObjectsFromArray(Array.Type.STRING,
            Arrays.asList("just", "11", "fun"),
            Arrays.asList("just", "11", "fun"));
    verifyObjectsFromArray(Array.Type.REAL,
            Arrays.asList("23.11", "11", "123.123"),
            Arrays.asList(23.11, 11.0, 123.123));
}

/**
 * Builds an array of {@code type} from {@code values}, converts it with
 * {@code ModelUtils.getObjectsFromArray}, and asserts that the result has as many elements as
 * {@code values} and that each element is an instance of, and equal to, the matching {@code expected} entry.
 */
private void verifyObjectsFromArray(Array.Type type, List<String> values, List<?> expected) {
    List<Object> retrieved = ModelUtils.getObjectsFromArray(getArray(type, values));
    assertEquals(values.size(), retrieved.size());
    int index = 0;
    for (Object expectedValue : expected) {
        Object actual = retrieved.get(index++);
        assertTrue(expectedValue.getClass().isInstance(actual));
        assertEquals(expectedValue, actual);
    }
}
Usage of org.dmg.pmml.Array in project drools by kiegroup.
Class PMMLModelTestUtils, method getArray.
/**
 * Creates a PMML {@link Array} of {@code arrayType} whose value is {@code values} separated by single
 * spaces, with its {@code n} attribute set to the number of values. Values are not quoted.
 */
public static Array getArray(Array.Type arrayType, final List<String> values) {
    final Array array = new Array(arrayType, String.join(" ", values));
    array.setN(values.size());
    return array;
}
Usage of org.dmg.pmml.Array in project drools by kiegroup.
Class PMMLModelTestUtils, method getSimpleSetPredicate.
/**
 * Creates a {@link SimpleSetPredicate} on the field named {@code predicateName}, using
 * {@code booleanOperator} against an array of {@code arrayType} built from {@code values}.
 */
public static SimpleSetPredicate getSimpleSetPredicate(final String predicateName, final Array.Type arrayType, final List<String> values, final SimpleSetPredicate.BooleanOperator booleanOperator) {
    final FieldName field = FieldName.create(predicateName);
    final SimpleSetPredicate predicate = new SimpleSetPredicate();
    predicate.setField(field);
    predicate.setBooleanOperator(booleanOperator);
    predicate.setArray(getArray(arrayType, values));
    return predicate;
}
Aggregations