use of org.dmg.pmml.Array in project shifu by ShifuML.
the class TreeNodePmmlElementCreator method convert.
/**
 * Converts a tree node into a PMML tree node and recursively converts its children.
 *
 * @param node
 *            the tree node to convert
 * @param isLeft
 *            whether {@code node} is the left child produced by {@code split}
 * @param split
 *            the split that leads to {@code node}; it supplies the column, threshold and categories
 *            used for the predicate
 * @return the PMML node with id, score (only when the node has a prediction), predicate and, unless the
 *         node has no split or is a real leaf, its two converted children
 */
public org.dmg.pmml.tree.Node convert(Node node, boolean isLeft, Split split) {
    org.dmg.pmml.tree.Node result = new org.dmg.pmml.tree.Node();
    result.setId(String.valueOf(node.getId()));
    if (node.getPredict() != null) {
        // Keep this as a single conditional expression so the value handed to String.valueOf is unchanged.
        result.setScore(String.valueOf(treeModel.isClassification() ? node.getPredict().getClassValue() : node.getPredict().getPredict()));
    }
    result.setDefaultChild(null);

    Predicate predicate = null;
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    if (columnConfig.isNumerical()) {
        SimplePredicate thresholdPredicate = new SimplePredicate();
        thresholdPredicate.setValue(String.valueOf(split.getThreshold()));
        // TODO, how to support segment variable in tree model, here should be changed
        thresholdPredicate.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        // Left child takes values below the threshold, right child takes the rest.
        thresholdPredicate.setOperator(SimplePredicate.Operator.fromValue(isLeft ? "lessThan" : "greaterOrEqual"));
        predicate = thresholdPredicate;
    } else if (columnConfig.isCategorical()) {
        SimpleSetPredicate setPredicate = new SimpleSetPredicate();
        Set<Short> categoryIndexes = split.getLeftOrRightCategories();
        // TODO, how to support segment variable in tree model, here should be changed
        setPredicate.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));

        StringBuilder joined = new StringBuilder();
        List<String> categoryNames = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum());
        for (Short categoryIndex : categoryIndexes) {
            if (categoryIndex >= categoryNames.size()) {
                // Index beyond the known categories is written as an empty quoted string.
                joined.append(" \"\"");
                continue;
            }
            String categoryName = categoryNames.get(categoryIndex);
            // Values containing a space are wrapped in double quotes; embedded quotes are always
            // backslash-escaped.
            boolean wrapInQuotes = categoryName.contains(" ");
            joined.append(" ");
            if (wrapInQuotes) {
                joined.append("\"");
            }
            joined.append(categoryName.replace("\"", "\\\""));
            if (wrapInQuotes) {
                joined.append("\"");
            }
        }
        Array array = new Array(Array.Type.fromValue("string"), joined.toString().trim());
        setPredicate.setArray(array);

        // The category set belongs to the side reported by split.isLeft(); the child on that side
        // matches "isIn", the opposite child matches "isNotIn".
        boolean childOwnsCategories = (isLeft == split.isLeft());
        setPredicate.setBooleanOperator(
                SimpleSetPredicate.BooleanOperator.fromValue(childOwnsCategories ? "isIn" : "isNotIn"));
        predicate = setPredicate;
    }
    result.setPredicate(predicate);

    boolean hasChildren = node.getSplit() != null && !node.isRealLeaf();
    if (hasChildren) {
        List<org.dmg.pmml.tree.Node> children = result.getNodes();
        org.dmg.pmml.tree.Node leftChild = convert(node.getLeft(), true, node.getSplit());
        org.dmg.pmml.tree.Node rightChild = convert(node.getRight(), false, node.getSplit());
        children.add(leftChild);
        children.add(rightChild);
    }
    return result;
}
use of org.dmg.pmml.Array in project shifu by ShifuML.
the class ModelStatsCreator method build.
/**
 * Builds the PMML {@link ModelStats} from the final-selected columns.
 *
 * <p>
 * When {@code basicML} is a {@link BasicFloatNetwork} with a non-empty feature set, only final-selected
 * columns whose column number is in that feature set are included. In every other case all final-selected
 * columns are included.
 *
 * @param basicML
 *            the trained model; only inspected for its feature set when it is a {@link BasicFloatNetwork}
 * @return model stats holding one {@link UnivariateStats} entry per included column
 */
@Override
public ModelStats build(BasicML basicML) {
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    // Loop-invariant: a missing or empty feature set means "do not filter by feature".
    boolean acceptAllFeatures = featureSet == null || CollectionUtils.isEmpty(featureSet);

    ModelStats modelStats = new ModelStats();
    for (ColumnConfig columnConfig : columnConfigList) {
        if (!columnConfig.isFinalSelect()) {
            continue;
        }
        if (acceptAllFeatures || featureSet.contains(columnConfig.getColumnNum())) {
            modelStats.addUnivariateStats(createUnivariateStats(columnConfig));
        }
    }
    return modelStats;
}

/**
 * Creates the univariate stats element for a single column: discrete stats for a categorical column,
 * numeric info (plus continuous stats unless concise output is requested) otherwise.
 */
private UnivariateStats createUnivariateStats(ColumnConfig columnConfig) {
    UnivariateStats univariateStats = new UnivariateStats();
    // here, no need to consider if column is in segment expansion
    // as we need to address new stats variable
    // set simple column name in PMML
    univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
    if (columnConfig.isCategorical()) {
        DiscrStats discrStats = new DiscrStats();
        Array countArray = createCountArray(columnConfig);
        discrStats.addArrays(countArray);
        if (!isConcise) {
            List<Extension> extensions = createExtensions(columnConfig);
            discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
        }
        univariateStats.setDiscrStats(discrStats);
    } else {
        // numerical column
        univariateStats.setNumericInfo(createNumericInfo(columnConfig));
        if (!isConcise) {
            univariateStats.setContStats(createConStats(columnConfig));
        }
    }
    return univariateStats;
}
use of org.dmg.pmml.Array in project drools by kiegroup.
the class PMMLModelTestUtils method getRandomSimpleSetPredicate.
/**
 * Creates a {@link SimpleSetPredicate} whose field name, boolean operator and array type are chosen at
 * random; the array values come from {@code getStringObjects(type, 3)}.
 *
 * @return the randomly populated predicate
 */
public static SimpleSetPredicate getRandomSimpleSetPredicate() {
    SimpleSetPredicate predicate = new SimpleSetPredicate();
    // Six random alphabetic characters serve as the field name.
    predicate.setField(FieldName.create(RandomStringUtils.random(6, true, false)));
    predicate.setBooleanOperator(getRandomSimpleSetPredicateOperator());
    Array.Type type = getRandomArrayType();
    List<String> arrayValues = getStringObjects(type, 3);
    predicate.setArray(getArray(type, arrayValues));
    return predicate;
}
use of org.dmg.pmml.Array in project drools by kiegroup.
the class PMMLModelTestUtils method getRandomSimpleSetPredicate.
/**
 * Creates a {@link SimpleSetPredicate} bound to the given data field: the field name is taken from
 * {@code dataField}, the array type is derived from its data type, and the boolean operator is random.
 *
 * @param dataField
 *            the data field supplying the predicate's field name and data type
 * @return the populated predicate
 */
public static SimpleSetPredicate getRandomSimpleSetPredicate(DataField dataField) {
    // Start from a fully random predicate, then overwrite field, operator and array.
    SimpleSetPredicate predicate = getRandomSimpleSetPredicate();
    predicate.setField(dataField.getName());
    predicate.setBooleanOperator(getRandomSimpleSetPredicateOperator());
    Array.Type type = getArrayType(dataField.getDataType());
    List<String> arrayValues = getStringObjects(type, 3);
    predicate.setArray(getArray(type, arrayValues));
    return predicate;
}
use of org.dmg.pmml.Array in project shifu by ShifuML.
the class ModelStatsCreator method createCountArray.
/**
 * Create the count {@link Array} for a column: one integer per bin, holding that bin's positive count
 * plus its negative count.
 *
 * @param columnConfig
 *            - ColumnConfig supplying the per-bin positive and negative counts
 * @return INT-typed Array whose n is the number of bins and whose value is the space-separated totals
 *         ( positive count + negative count )
 */
private Array createCountArray(ColumnConfig columnConfig) {
    List<Integer> binCountPos = columnConfig.getBinCountPos();
    List<Integer> binCountNeg = columnConfig.getBinCountNeg();
    // Iterate over the source list: a new ArrayList has size 0 whatever its initial capacity,
    // so using the target list's size as the loop bound would never add any element.
    int binCount = binCountPos.size();
    List<Integer> binCountAll = new ArrayList<Integer>(binCount);
    for (int i = 0; i < binCount; i++) {
        binCountAll.add(binCountPos.get(i) + binCountNeg.get(i));
    }
    Array countAllArray = new Array();
    countAllArray.setType(Array.Type.INT);
    countAllArray.setN(binCountAll.size());
    countAllArray.setValue(StringUtils.join(binCountAll, ' '));
    return countAllArray;
}
Aggregations