Search in sources:

Example 6 with Array

Use of org.dmg.pmml.Array in project shifu by ShifuML.

Class TreeNodePmmlElementCreator, method convert.

/**
 * Converts a Shifu tree {@link Node}, and recursively its children, into a PMML tree node.
 *
 * <p>The node's PMML predicate describes the branch of the parent's {@code split} that leads to it:
 * <ul>
 * <li>a numerical column gets a {@code lessThan} (left) or {@code greaterOrEqual} (right) {@link SimplePredicate}
 * on the split threshold;</li>
 * <li>a categorical column gets an {@code isIn} / {@code isNotIn} {@link SimpleSetPredicate} over the split's
 * categories;</li>
 * <li>any other column gets no predicate.</li>
 * </ul>
 *
 * @param node
 *            the Shifu node to convert
 * @param isLeft
 *            whether {@code node} is the left child of the node owning {@code split}
 * @param split
 *            the parent's split that leads to {@code node}
 * @return the converted PMML node; converted children are attached unless {@code node} has no split or is a real
 *         leaf
 */
public org.dmg.pmml.tree.Node convert(Node node, boolean isLeft, Split split) {
    org.dmg.pmml.tree.Node pmmlNode = new org.dmg.pmml.tree.Node();
    pmmlNode.setId(String.valueOf(node.getId()));
    if (node.getPredict() != null) {
        // NOTE(review): if getClassValue() and getPredict() have different numeric types, this conditional
        // expression promotes both operands to the wider type before String.valueOf (e.g. "1.0" instead of "1").
        pmmlNode.setScore(String.valueOf(treeModel.isClassification() ? node.getPredict().getClassValue()
                : node.getPredict().getPredict()));
    }
    pmmlNode.setDefaultChild(null);
    Predicate predicate = null;
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    if (columnConfig.isNumerical()) {
        SimplePredicate p = new SimplePredicate();
        p.setValue(String.valueOf(split.getThreshold()));
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        // Left branch: values below the threshold; right branch: everything else.
        p.setOperator(isLeft ? SimplePredicate.Operator.LESS_THAN : SimplePredicate.Operator.GREATER_OR_EQUAL);
        predicate = p;
    } else if (columnConfig.isCategorical()) {
        SimpleSetPredicate p = new SimpleSetPredicate();
        Set<Short> childCategories = split.getLeftOrRightCategories();
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum());
        StringBuilder arrayStr = new StringBuilder();
        for (Short sh : childCategories) {
            if (arrayStr.length() > 0) {
                arrayStr.append(' ');
            }
            // A category index beyond the known values is written as the empty string.
            arrayStr.append(sh >= valueList.size() ? "\"\"" : toPmmlArrayToken(valueList.get(sh)));
        }
        p.setArray(new Array(Array.Type.STRING, arrayStr.toString()));
        // The node is "in" childCategories exactly when its side (isLeft) matches split.isLeft().
        p.setBooleanOperator(isLeft == split.isLeft() ? SimpleSetPredicate.BooleanOperator.IS_IN
                : SimpleSetPredicate.BooleanOperator.IS_NOT_IN);
        predicate = p;
    }
    pmmlNode.setPredicate(predicate);
    if (node.getSplit() == null || node.isRealLeaf()) {
        return pmmlNode;
    }
    List<org.dmg.pmml.tree.Node> childList = pmmlNode.getNodes();
    childList.add(convert(node.getLeft(), true, node.getSplit()));
    childList.add(convert(node.getRight(), false, node.getSplit()));
    return pmmlNode;
}

/**
 * Formats one categorical value as a whitespace-separated token of a PMML string {@link Array}.
 *
 * <p>PMML requires values containing whitespace to be enclosed in double quotes. An embedded double quote must be
 * escaped with a backslash. Values containing a double quote are always quoted too, because array parsers such as
 * JPMML's only honour the {@code \"} escape inside a quoted value. The empty string is written as {@code ""} so it
 * is not lost.
 *
 * @param value
 *            the raw category value
 * @return the token to place in the array content
 */
private static String toPmmlArrayToken(String value) {
    boolean needsQuotes = value.isEmpty() || value.indexOf('"') >= 0;
    for (int i = 0; i < value.length() && !needsQuotes; i++) {
        needsQuotes = Character.isWhitespace(value.charAt(i));
    }
    if (!needsQuotes) {
        return value;
    }
    return "\"" + value.replace("\"", "\\\"") + "\"";
}
Also used: Set (java.util.Set), ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), Node (ml.shifu.shifu.core.dtrain.dt.Node), SimplePredicate (org.dmg.pmml.SimplePredicate), Predicate (org.dmg.pmml.Predicate), SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate), Array (org.dmg.pmml.Array), List (java.util.List), FieldName (org.dmg.pmml.FieldName)

Example 7 with Array

Use of org.dmg.pmml.Array in project shifu by ShifuML.

Class ModelStatsCreator, method build.

/**
 * Builds the PMML {@link ModelStats} holding one {@link UnivariateStats} entry per exported column.
 *
 * <p>A column is exported when it is final-selected. For a {@link BasicFloatNetwork} whose feature set is
 * non-empty, the column must also belong to that feature set.
 *
 * @param basicML
 *            the trained model whose columns are described
 * @return the populated {@link ModelStats}
 */
@Override
public ModelStats build(BasicML basicML) {
    ModelStats modelStats = new ModelStats();
    // Only a BasicFloatNetwork's feature set is consulted. For any other model the set stays null, and the
    // null-safe isEmpty check below then accepts every final-selected column.
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isFinalSelect()
                && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum()))) {
            modelStats.addUnivariateStats(createColumnUnivariateStats(columnConfig));
        }
    }
    return modelStats;
}

/**
 * Builds the {@link UnivariateStats} for a single column.
 *
 * <p>Categorical columns get {@link DiscrStats}: the per-bin count array, plus extensions unless concise output is
 * requested. Other columns get numeric info, plus continuous stats unless concise output is requested.
 *
 * @param columnConfig
 *            the column to describe
 * @return the column's univariate stats
 */
private UnivariateStats createColumnUnivariateStats(ColumnConfig columnConfig) {
    UnivariateStats univariateStats = new UnivariateStats();
    // here, no need to consider if column is in segment expansion as we need to address new stats variable;
    // set simple column name in PMML
    univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
    if (columnConfig.isCategorical()) {
        DiscrStats discrStats = new DiscrStats();
        discrStats.addArrays(createCountArray(columnConfig));
        if (!isConcise) {
            List<Extension> extensions = createExtensions(columnConfig);
            discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
        }
        univariateStats.setDiscrStats(discrStats);
    } else {
        // numerical column
        univariateStats.setNumericInfo(createNumericInfo(columnConfig));
        if (!isConcise) {
            univariateStats.setContStats(createConStats(columnConfig));
        }
    }
    return univariateStats;
}
Also used: Array (org.dmg.pmml.Array), Extension (org.dmg.pmml.Extension), DiscrStats (org.dmg.pmml.DiscrStats), ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), UnivariateStats (org.dmg.pmml.UnivariateStats), ModelStats (org.dmg.pmml.ModelStats), BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 8 with Array

Use of org.dmg.pmml.Array in project drools by kiegroup.

Class PMMLModelTestUtils, method getRandomSimpleSetPredicate (no-argument overload).

/**
 * Creates a fully random {@link SimpleSetPredicate}.
 *
 * <p>The predicate has a six-letter random field name and a random boolean operator. Its array holds three values
 * of a randomly chosen {@link Array.Type}.
 *
 * @return a new, fully populated {@link SimpleSetPredicate}
 */
public static SimpleSetPredicate getRandomSimpleSetPredicate() {
    final SimpleSetPredicate predicate = new SimpleSetPredicate();
    predicate.setField(FieldName.create(RandomStringUtils.random(6, true, false)));
    predicate.setBooleanOperator(getRandomSimpleSetPredicateOperator());
    final Array.Type randomType = getRandomArrayType();
    predicate.setArray(getArray(randomType, getStringObjects(randomType, 3)));
    return predicate;
}
Also used: Array (org.dmg.pmml.Array), FieldName (org.dmg.pmml.FieldName), SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate)

Example 9 with Array

Use of org.dmg.pmml.Array in project drools by kiegroup.

Class PMMLModelTestUtils, method getRandomSimpleSetPredicate (DataField overload).

/**
 * Creates a random {@link SimpleSetPredicate} bound to the given {@link DataField}.
 *
 * <p>The predicate's field is the data field's name and its boolean operator is random. Its array holds three
 * values of the {@link Array.Type} that matches the data field's data type.
 *
 * @param dataField
 *            the field the predicate refers to
 * @return a new, fully populated {@link SimpleSetPredicate}
 */
public static SimpleSetPredicate getRandomSimpleSetPredicate(DataField dataField) {
    // Build directly rather than starting from getRandomSimpleSetPredicate(): every property that method sets
    // (field, boolean operator, array) would be overwritten here anyway.
    SimpleSetPredicate toReturn = new SimpleSetPredicate();
    toReturn.setField(dataField.getName());
    toReturn.setBooleanOperator(getRandomSimpleSetPredicateOperator());
    Array.Type arrayType = getArrayType(dataField.getDataType());
    List<String> values = getStringObjects(arrayType, 3);
    toReturn.setArray(getArray(arrayType, values));
    return toReturn;
}
Also used: Array (org.dmg.pmml.Array), SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate)

Example 10 with Array

Use of org.dmg.pmml.Array in project shifu by ShifuML.

Class ModelStatsCreator, method createCountArray.

/**
 * Create the per-bin count @Array for a column (used for the DiscrStats of categorical variables).
 *
 * <p>Each entry is the bin's positive count plus its negative count, in bin order.
 *
 * @param columnConfig
 *            - ColumnConfig of the variable
 * @return INT Array of ( positive count + negative count ) per bin, with n set to the number of bins
 */
private Array createCountArray(ColumnConfig columnConfig) {
    // Size the loop from the source lists. A freshly constructed ArrayList only has capacity, not size, so
    // iterating over the result list's own size() would never add anything.
    // NOTE(review): positive and negative counts are paired by index. If their lengths ever differ, only the
    // common prefix is exported; confirm that is acceptable.
    int binCount = Math.min(columnConfig.getBinCountPos().size(), columnConfig.getBinCountNeg().size());
    List<Integer> binCountAll = new ArrayList<Integer>(binCount);
    for (int i = 0; i < binCount; i++) {
        binCountAll.add(columnConfig.getBinCountPos().get(i) + columnConfig.getBinCountNeg().get(i));
    }
    Array countAllArray = new Array();
    countAllArray.setType(Array.Type.INT);
    countAllArray.setN(binCountAll.size());
    countAllArray.setValue(StringUtils.join(binCountAll, ' '));
    return countAllArray;
}
Also used: Array (org.dmg.pmml.Array), ArrayList (java.util.ArrayList)

Aggregations

Array (org.dmg.pmml.Array): 10, SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate): 5, FieldName (org.dmg.pmml.FieldName): 3, ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 2, PMMLModelTestUtils.getArray (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getArray): 2, ArrayList (java.util.ArrayList): 1, List (java.util.List): 1, Set (java.util.Set): 1, BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 1, Node (ml.shifu.shifu.core.dtrain.dt.Node): 1, DiscrStats (org.dmg.pmml.DiscrStats): 1, Extension (org.dmg.pmml.Extension): 1, ModelStats (org.dmg.pmml.ModelStats): 1, Predicate (org.dmg.pmml.Predicate): 1, SimplePredicate (org.dmg.pmml.SimplePredicate): 1, UnivariateStats (org.dmg.pmml.UnivariateStats): 1, Test (org.junit.Test): 1, KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException): 1, KiePMMLSimpleSetPredicate (org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate): 1